#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
                       IRMapping &map) const final {

                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));

  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  Attribute readAttribute(DialectBytecodeReader &reader) const override {

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);

  Type readType(DialectBytecodeReader &reader) const override {

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);

  void writeVersion(DialectBytecodeWriter &writer) const final {

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

  Type elementType = variableOp.getType();

  return RankedTensorType::get(shape, elementType);
void TosaDialect::initialize() {

#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"

#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));

  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
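// Constant materialization above picks the op by attribute/type: a
// DenseIntElementsAttr paired with a !tosa.shape type becomes tosa.const_shape,
// while any other ElementsAttr becomes tosa.const. A hedged, illustrative
// sketch of the resulting IR (exact printed assembly may differ by dialect
// version):
//
//   %shape = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>}
//              : !tosa.shape<3>
//   %cst = "tosa.const"() <{values = dense<0.0> : tensor<f32>}>
//              : () -> tensor<f32>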
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
          << "expected ranked type";

    auto elementType = shapedType.getElementType();
    typeAttr = TypeAttr::get(elementType);

        << "expected shaped type";

        << "expected attribute";

  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,

        << "expected Typed attr";

    initialValueAttr = nullptr;

        << "expected type after colon";

  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);

                                        TypeAttr typeAttr,
                                        Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {

    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =

    auto tensorTypeAttr = TypeAttr::get(tensorType);

  if (initialValueAttr) {
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
  llvm::StringRef name;

  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
      auto sym = symbolizeRoundingMode(kw);
             << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());

  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
      auto sym = symbolizeResizeMode(kw);
             << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());

  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
      auto sym = symbolizeNanPropagationMode(kw);
             << "invalid nan_mode value: " << kw;
      auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());

  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
      auto sym = symbolizeBlockSize(kw);
             << "invalid block_size value: " << kw;
      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());

template <typename EnumType>
          [&]() { return parser.parseOperand(operands.emplace_back()); }))

    if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))

  result.addTypes(fnTy.getResults());
  result.addAttributes(attrs);
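// The helpers above let ops whose custom assembly carries one of these enum
// attributes accept a bare keyword (e.g. a rounding_mode of SINGLE_ROUND
// written without quotes) and symbolize it into the typed enum attribute,
// while all other attribute entries fall back to generic attribute parsing.
// An illustrative sketch only; the exact operand and type lists depend on the
// particular op:
//
//   ... {rounding_mode = SINGLE_ROUND, ...} : (tensor<1x8xi8>, ...) -> ...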
  parser << namedAttr.getName().strref() << " = ";

  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();
  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
    parser << blockSizeAttr.getValue();

  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;

    if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
      if (nanAttr.getValue() == kDefaultNanValue) {
        toPrint.erase(attr.getName());

  if (!toPrint.empty()) {

    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);

  llvm::interleaveComma(op->getAttrs(), parser,
                          printNamedAttr(parser, namedAttr);
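// When printing, a nan_mode attribute equal to the default (PROPAGATE) is
// removed from the list so round-tripped IR stays terse; the remaining enum
// attributes are printed as bare keywords by printNamedAttr above.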
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);

  printWithEnumHandling(parser, *this);

ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);

void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);

ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);

  printWithEnumHandling(parser, *this);
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();

                                   Value valZp, StringRef name) {

      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

                                  Value src, int32_t val) {
  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
  if (dyn_cast<tosa::mxint8Type>(type))

  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
            const int64_t dilation, const llvm::StringRef dimName,
            const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
          if (inputSize == ShapedType::kDynamic ||
              kernelSize == ShapedType::kDynamic)

          const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
              inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,

          if (!calculatedOutSizeMinusOne.has_value())
            return op.emitOpError("expected input_")
                   << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
                   << padAfterName << " - (kernel_" << dimName
                   << " - 1) * dilation_" << dimAxis
                   << " to be wholly divisible by stride_" << dimAxis << ", got ("
                   << inputSize << " - 1 + " << padBefore << " + " << padAfter
                   << " - (" << kernelSize << " - 1) * " << dilation << ") / "

          const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
          if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
            return op.emitOpError("calculated output ")
                   << dimName << " did not match expected: "
                   << "calculated=" << calculatedOutSize
                   << ", expected=" << outputSize;
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                        StringRef name1, Type type2,

  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

      op->template getParentWithTrait<OpTrait::SymbolTable>();

  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

    return op->emitOpError("'")
           << op.getName() << "' has not been declared by 'tosa.variable'";
                                           StringRef aName = "input",
                                           StringRef bName = "output") {
  auto aTType = llvm::dyn_cast<TensorType>(aType);
  auto bTType = llvm::dyn_cast<TensorType>(bType);

    op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;

    op.emitOpError("expect shaped tensor for") << bName << ", got " << bType;

  auto aElementType = aTType.getElementType();
  auto bElementType = bTType.getElementType();
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
      aElementType != bElementType) {

    op.emitOpError("expect ")
        << aName << " and " << bName << " to have same element type, got "
        << aElementType << " and " << bElementType;

LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
         << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>

  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
          const llvm::StringRef dimName, const llvm::StringRef dimAxis,
          const llvm::StringRef padBeforeName,
          const llvm::StringRef padAfterName) -> LogicalResult {
        if (ShapedType::isDynamic(inputSize))

        const std::optional<int64_t> calculatedOutSizeMinusOne =
            idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
        if (!calculatedOutSizeMinusOne.has_value())
          return op.emitOpError("expected input_")
                 << dimName << " + pad_" << padBeforeName << " + pad_"
                 << padAfterName << " - kernel_" << dimAxis
                 << " to be wholly divisible by stride_" << dimAxis << ", got ("
                 << inputSize << " + " << padBefore << " + " << padAfter << " - "
                 << kernelSize << ") / " << strideSize;

        const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
        if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
          return op.emitOpError("calculated output ")
                 << dimName << " did not match expected: "
                 << "calculated=" << calculatedOutSize
                 << ", expected=" << outputSize;
  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
LogicalResult tosa::AvgPool2dOp::verify() {

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
LogicalResult tosa::ClampOp::verify() {
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {

    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;

  result.addTypes(finalOutputType);

  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;

  result.addTypes(finalOutputType);

  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};

  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  if (inputBits == 16)

  finalOutputType = outputShapedType.clone(accElementType);

  result.addTypes(finalOutputType);
                                     DenseArrayAttr kernel, DenseArrayAttr stride,
                                     DenseArrayAttr pad, TypeAttr accType) {

  if (auto quantAttr =

    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =

        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =

    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

    result.addOperands({input});

  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  result.types.push_back(outputType);
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =

        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =

        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

    result.addOperands({input});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);

                            StringRef name, Type variableType,

  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");

  auto elementType = shapedType.getElementType();
  auto elementTypeAttr = TypeAttr::get(elementType);

  result.addAttribute("sym_name", nameAttr);
  result.addAttribute("var_shape", varShapeAttr);
  result.addAttribute("type", elementTypeAttr);
  result.addAttribute("initial_value", initialValue);
  for (int i = 0, e = operands.size(); i != e; ++i) {

    if (!shape.hasRank()) {

    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {

    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {

      } else if (dim1 != dim2) {

      outShape[i + rankDiff] = resolvedDim;
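// Broadcast resolution as implemented above: shapes are right-aligned, each
// output dimension starts at 1, and a dimension of 1 is overridden by the
// other operand's size. Illustrative example: operand shapes [3, 1, 5] and
// [4, 5] resolve to an output shape of [3, 4, 5]; mismatched non-1 sizes are
// rejected.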
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,

  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {

    outShape.push_back(inputShape.getDimSize(i));

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;
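// A real FFT only produces the non-redundant half of the spectrum, so the
// inferred output width is input_width / 2 + 1 (e.g. an input width of 8
// infers an output width of 5); a dynamic input width stays dynamic.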
                                         const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;

LogicalResult tosa::RFFT2dOp::verify() {
  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                         outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

LogicalResult tosa::FFT2dOp::verify() {
  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,

  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {

    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))

      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
            "Cannot concat tensors with different sizes"
            " on the non-axis dimension ",

    hasRankedInput = true;

  if (adaptor.getInput1().empty())

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  for (auto operand : adaptor.getOperands()) {

    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;

    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
LogicalResult tosa::ConcatOp::verify() {

  auto outType = getOutput().getType();

  if (inputList.empty())

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();

    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;

      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {

      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
            "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);

        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))

        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

  if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
    return emitOpError("expect output rank to match inputs rank, got ")
           << outputShape.getRank() << " vs " << firstInputRank;

  for (const auto &input : inputList) {

    if (inputShape.isDynamicDim(axis)) {

    axisSum += inputShape.getDimSize(axis);

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  auto elementType = IntegerType::get(context, 1);

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,

  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);
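// MatMul shape inference: for operands a of shape [N, H, C] and b of shape
// [N, C, W], the inferred output shape is [N, H, W]; a batch dimension that is
// dynamic on one operand is taken from the other where possible.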
LogicalResult MatMulOp::verify() {
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatmulTBlockScaledOp::Adaptor adaptor,

  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
  if (aDataShape.hasRank()) {
    outShape[0] = aDataShape.getDimSize(0);
    outShape[1] = aDataShape.getDimSize(1);

  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
  if (aScaleShape.hasRank()) {
    outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
    outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)

  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
  if (bDataShape.hasRank()) {
    const int64_t bDataBatchSize = bDataShape.getDimSize(0);
    if (bDataBatchSize != 1)
          ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
    outShape[2] = bDataShape.getDimSize(1);

  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
  if (bScaleShape.hasRank()) {
    const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
    if (bScaleBatchSize != 1)
          ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
    outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
LogicalResult MatmulTBlockScaledOp::verify() {

  const Type aDataType = getAData().getType();
  const Type bDataType = getBData().getType();

  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
                                   const StringRef operandName,
                                   const StringRef dimName) -> LogicalResult {
    if (ShapedType::isDynamic(currDim)) {

    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
             << dimName << " of " << operandName << " to match size " << currDim
             << ", got " << newDim;

  int64_t N = ShapedType::kDynamic;
  int64_t D = ShapedType::kDynamic;
  int64_t H = ShapedType::kDynamic;

  int64_t multiplesOfC = ShapedType::kDynamic;

                                     "b_scale", "C/block_size")))

  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
    return emitOpError("expect B matrix batch size to be broadcast compatible "
           << D << " vs N=" << N;

  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(C) && C % blockSize != 0)
    return emitOpError("expect C to be a multiple of block size, got C=")
           << C << ", block_size=" << blockSize;

  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
      multiplesOfC != C / blockSize)
        "expect scale operands dimension 2 to equal C/block_size (")
        << C << "/" << blockSize << ")"
        << ", got " << multiplesOfC;

  N = ShapedType::isDynamic(N) ? D : N;

  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&

  auto stringifyDim = [&](int64_t d) {
    if (ShapedType::isDynamic(d))

  llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
  opError << " to be compatible with expected output shape ";
  llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {

      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
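// Each static output dimension is input_dim + pad_front + pad_back; e.g. a
// dimension of size 4 padded with [1, 2] infers size 7. A negative padding
// value (the -1 used to mark dynamic padding) makes that dimension dynamic.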
LogicalResult tosa::PadOp::verify() {

  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
          << "invalid padding values at dimension " << i
          << ": values must be non-negative or -1 for dynamic padding, got ["
          << padStart << ", " << padEnd << "]";

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {

        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {

          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {

          outputShape[i] = size[i];

LogicalResult tosa::SliceOp::verify() {

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
          "expect input1 and output to have the same ranks, got ")
          << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

LogicalResult tosa::MulOp::verify() {
  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";

  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

LogicalResult tosa::TableOp::verify() {
  const TensorType inputType = getInput1().getType();
  const TensorType outputType = getOutput().getType();

  if (inputType.getRank() != outputType.getRank())
        << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {

    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,

      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(

  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);

      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
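// Each inferred output dimension is input_dim * multiples[i] when both are
// static (e.g. an input dim of 3 tiled by 4 gives 12); a dynamic input dim or
// a dynamic multiple keeps that output dim dynamic.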
LogicalResult tosa::TileOp::verify() {

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();

  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(
llvm::LogicalResult tosa::ReshapeOp::verify() {

  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

          return (dim > 0) ? acc * dim : acc;

    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  return mlir::success();
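// A single -1 entry in the new shape is resolved during shape inference by
// dividing the input element count by the product of the static entries; e.g.
// reshaping 24 elements to [2, -1, 3] yields 24 / (2 * 3) = 4 for the -1 slot.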
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
      return zpAttr.getValues<APInt>()[0].getZExtValue();

template <typename T>
                                    const std::string &operand) {

  if (!zpElemType.isInteger(8) && zp != 0) {

    std::string lower = operand;
    llvm::transform(lower, lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";
2647 bool isInputZp = (operand ==
"Input");
2649 bool tensorUnsigned =
2650 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2651 StringRef tensorName = isInputZp ?
"input" :
"output";
2657 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
2658 return op.emitOpError()
2659 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
2661 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
2662 return op.emitOpError() <<
"expect " << tensorName
2663 <<
"_zp of 0 or 32768 for unsigned int16 "
2664 << tensorName <<
", got " << zp;
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \

  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
LogicalResult tosa::TransposeOp::verify() {

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
        << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                          static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
            << "expected output tensor dim " << i << " to match "
            << "input dim " << constantPerms[i] << " with value of "
            << inputShape.getDimSize(constantPerms[i]);
LogicalResult TransposeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  const llvm::ArrayRef<int32_t> transposePerms = getPerms();

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
      returnedDims[dim] =
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
              .getResult();
    else
      returnedDims[dim] =
          builder.getIndexAttr(inputType.getDimSize(dimInInput));
  }

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
  return success();
}
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::GatherOp::verify() {
  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;
    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        W != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
  }
  return success();
}
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  SmallVector<int64_t> scaleInt, offsetInt, borderInt;

  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ", outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
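// Illustrative sketch (not part of the upstream file): with the resize scale
// encoded as [scale_y_n, scale_y_d, scale_x_n, scale_x_d], the output
// spatial size computed above follows
//
//   OH = ((IH - 1) * scale_y_n - offset_y + border_y) / scale_y_d + 1
//   OW = ((IW - 1) * scale_x_n - offset_x + border_x) / scale_x_d + 1
//
// e.g. IH = 4, scale_y_n/scale_y_d = 2/1, offset_y = 0, border_y = 0 gives
// OH = (3 * 2 - 0 + 0) / 1 + 1 = 7.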
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  SmallVector<int64_t> scaleValues;
  SmallVector<int64_t> offsetValues;
  SmallVector<int64_t> borderValues;

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ");

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult tosa::ScatterOp::verify() {
  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  int64_t N = ShapedType::kDynamic;
  int64_t K = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    N = valuesInShape.getDimSize(0);
    K = valuesInShape.getDimSize(1);
    C = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (N == ShapedType::kDynamic)
      N = inputN;
    else if (inputN != ShapedType::kDynamic && N != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << N
                           << ", got " << inputN;
    if (W == ShapedType::kDynamic)
      W = inputW;
    else if (inputW != ShapedType::kDynamic && W != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << W
                           << ", got " << inputW;
    if (C == ShapedType::kDynamic)
      C = inputC;
    else if (inputC != ShapedType::kDynamic && C != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << C
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << N << ", got " << outputN;
    if (K == ShapedType::kDynamic)
      K = outputK;
    else if (outputK != ShapedType::kDynamic && K != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << K << ", got " << outputK;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << C << ", got " << outputC;
  }
  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
    return emitOpError() << "requires dimensions K >= W, got K=" << K
                         << " and W=" << W;

  return success();
}
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {

  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;

#define COMPATIBLE_RETURN_TYPES(OP)                                           \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                \
    if (l.size() != r.size() || l.size() != 1)                                \
      return false;                                                           \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))             \
      return false;                                                           \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                      \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER

#undef COMPATIBLE_RETURN_TYPES
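// Illustrative note (not part of the upstream file): the per-op invocations
// of REDUCE_SHAPE_INFER are elided from this listing. A hypothetical
// invocation such as REDUCE_SHAPE_INFER(tosa::ReduceAllOp) would expand to a
// tosa::ReduceAllOp::inferReturnTypeComponents definition that forwards to
// ReduceInferReturnTypes with the op's axis property, plus the matching
// isCompatibleReturnTypes definition produced by COMPATIBLE_RETURN_TYPES.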
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output shape has rank 0 and the
  // axis is also 0.
  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
    op.emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
    return failure();
  }
  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
      return failure();
    }
  }

  return success();
}

LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      OpaqueProperties properties, RegionRange regions,                       \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    return NAryInferReturnTypes(operands, inferredReturnShapes);              \
  }

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  // Verify same element type.
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

  // Verify same shape.
  const SmallVector<Type, 2> types = {input1Type, outputType};
  if (failed(verifyCompatibleShapes(types)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
  outputShape.resize(4, ShapedType::kDynamic);

  if (ShapedType::isStatic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (ShapedType::isStatic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
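// Illustrative sketch (not part of the upstream file): the output spatial
// size computed above corresponds to the usual convolution relation
//
//   out = floor((in + pad_before + pad_after - ((kernel - 1) * dilation + 1))
//               / stride) + 1
//
// e.g. in = 8, pads = 1 + 1, kernel = 3, dilation = 1, stride = 2 gives
// out = floor((8 + 2 - 3) / 2) + 1 = 4.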
LogicalResult Conv2DOp::verify() {

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input width/height/depth and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/height/width and output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();

  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() {
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::verify() {
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the depth multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the number of output channels.
  if (ShapedType::isStatic(inputChannels) &&
      ShapedType::isStatic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
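// Illustrative note (not part of the upstream file): for depthwise conv2d the
// output channel count is the product computed above, i.e.
// output_channels = input_channels * depth_multiplier; e.g. an input with
// C = 8 and a weight whose last dimension (the depth multiplier) is 2 yields
// 16 output channels.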
LogicalResult DepthwiseConv2DOp::verify() {
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getInput().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
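// Illustrative note (not part of the upstream file): the transpose_conv2d
// output size computed above is
//
//   OH = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH
//   OW = (IW - 1) * stride_x + out_pad_left + out_pad_right + KW
//
// e.g. IH = 4, stride_y = 2, out_pad_top = out_pad_bottom = 0, KH = 3 gives
// OH = 3 * 2 + 0 + 0 + 3 = 9.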
LogicalResult TransposeConv2DOp::verify() {

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [");

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  const int64_t kernelHeight = weightType.getDimSize(1);
  if (ShapedType::isStatic(kernelHeight)) {
    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                        "out_pad_top", "KH")))
      return failure();

    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                        "out_pad_bottom", "KH")))
      return failure();
  }

  const int64_t kernelWidth = weightType.getDimSize(2);
  if (ShapedType::isStatic(kernelWidth)) {
    if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                        "out_pad_left", "KW")))
      return failure();

    if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                        "out_pad_right", "KW")))
      return failure();
  }

  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType = getStorageElementTypeOrSelf(inputType);
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType = getStorageElementTypeOrSelf(outputType);
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // Rank checks only apply when per_channel is set.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    CastFromBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}
LogicalResult CastFromBlockScaledOp::verify() {
  const Type inputDataType = getInputData().getType();
  const Type outputDataType = getResult().getType();
  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
    return emitOpError() << "require compatible shapes for input_data ("
                         << inputDataType << ") and "
                         << "output_data (" << outputDataType << ")";

  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);

  if (inputDataShape.hasRank()) {
    const unsigned int blockSize =
        BlockSizeAttr::getBlockSizeValue(getBlockSize());
    const int64_t inputDataLastDim =
        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
    if (inputDataLastDim % blockSize != 0)
      return emitOpError() << "expect last dimension of input_data ("
                           << inputDataLastDim
                           << ") to be divisible by block_size (" << blockSize
                           << ")";

    const Type inputScaleType = getInputScale().getType();
    const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);

    if (inputScaleShape.hasRank()) {
      SmallVector<int64_t> inputDataDims, inputScaleDims;
      inputDataShape.getDims(inputDataDims);
      inputScaleShape.getDims(inputScaleDims);

      if (inputDataDims.size() != inputScaleDims.size() ||
          failed(verifyCompatibleShape(
              ArrayRef<int64_t>(inputDataDims).drop_back(1),
              ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
        return emitOpError() << "require compatible shapes for input_data ("
                             << inputDataType << ") and "
                             << "input_scale (" << inputScaleType
                             << ") except for the last dimension";

      const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
                                                inputScaleDims.back()};
      if (ShapedType::isStatic(inputDataLastDim) &&
          failed(verifyCompatibleDims(dimsToCheck)))
        return emitOpError()
               << "expect last dimension of input_scale ("
               << inputScaleDims.back()
               << ") to be equal to last dimension of input_data / block_size ("
               << inputDataDims.back() / blockSize << ")";
    }
  }

  return success();
}
LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    CastToBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  if (!inputShape.hasRank())
    return success();

  // The output scale shape matches the input shape except that its last
  // dimension is divided by the block size.
  SmallVector<int64_t> outputScaleShape;
  inputShape.getDims(outputScaleShape);
  const int64_t lastDimLoc = inputShape.getRank() - 1;
  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
  if (ShapedType::isStatic(lastDimSize)) {
    const unsigned int blockSize =
        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
  }
  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
  return success();
}
LogicalResult CastToBlockScaledOp::verify() {
  const Type inputDataType = getInputData().getType();
  const Type outputDataType = getResult(0).getType();
  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
    return emitOpError() << "require compatible shapes for input_data ("
                         << inputDataType << ") and "
                         << "output_data (" << outputDataType << ")";

  const unsigned int blockSize =
      BlockSizeAttr::getBlockSizeValue(getBlockSize());
  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
  if (inputDataShape.hasRank()) {
    const int64_t inputDataLastDim =
        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
    if (ShapedType::isStatic(inputDataLastDim) &&
        inputDataLastDim % blockSize != 0)
      return emitOpError() << "expect last dimension of input_data ("
                           << inputDataLastDim
                           << ") to be divisible by block_size (" << blockSize
                           << ")";
  }

  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
  const Type outputScaleType = getResult(1).getType();
  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);

  SmallVector<int64_t> outputDataDims, outputScaleDims;
  outputDataShape.getDims(outputDataDims);
  outputScaleShape.getDims(outputScaleDims);

  if (outputDataDims.size() != outputScaleDims.size() ||
      failed(verifyCompatibleShape(
          ArrayRef<int64_t>(outputDataDims).drop_back(1),
          ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
    return emitOpError() << "require compatible shapes for output_data ("
                         << outputDataType << ") and "
                         << "output_scale (" << outputScaleType
                         << ") except for the last dimension";

  const int64_t outputDataLastDim = outputDataDims.back();
  const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
                                            outputScaleDims.back()};
  if (ShapedType::isStatic(outputDataLastDim) &&
      failed(verifyCompatibleDims(dimsToCheck)))
    return emitOpError()
           << "expect last dimension of output_scale ("
           << outputScaleDims.back()
           << ") to be equal to last dimension of output_data / block_size ("
           << outputDataDims.back() / blockSize << ")";

  return success();
}
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

// Prints the initialization list in the form of
//   <prefix>(inner = outer, inner2 = outer2, ...)
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  OpAsmParser::UnresolvedOperand cond;

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Parse the optional assignment list and the condition operand; both
  // failure paths report "expected type for condition operand".
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);

  FunctionType functionType;
  if (parser.parseColonType(functionType))
    return parser.emitError(parser.getCurrentLocation())
           << "expected list of types for block arguments "
           << "followed by arrow type and list of return types";

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }
void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();

  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");

  p << getCondition().getType();

  if (!getInputList().empty()) {
    p << ", ";
    llvm::interleaveComma(getInputList().getTypes(), p);
  }

  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
LogicalResult IfOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getThenGraph().front().mightHaveTerminator()) {
    auto thenYield =
        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
    if (thenYield && errorIfTypeOrShapeMismatch(
                         *this, thenYield.getInputs(), "'then_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  if (getElseGraph().front().mightHaveTerminator()) {
    auto elseYield =
        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
    if (elseYield && errorIfTypeOrShapeMismatch(
                         *this, elseYield.getInputs(), "'else_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;

  return success();
}
LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                                "'body_graph' results",
                                                getInputList(), "'input_list'")
                         .failed())
      return failure();
  }

  // The condition graph must yield a single boolean value of size 1.
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output shape has rank 0 and the
  // axis is also 0.
  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
LogicalResult tosa::SelectOp::verify() {
  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }

  return success();
}

LogicalResult tosa::VariableReadOp::verify() {

LogicalResult tosa::VariableWriteOp::verify() {
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);

  FunctionType functionType;

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the parsed input types onto the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseRegion(*body, regionArgs));
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
  return mlir::isa<tosa::shapeType>(t);

  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;

  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
    Operation *definingOp = v.getDefiningOp();
    if (!definingOp)
      return op->emitOpError("shape operand is not compile time resolvable");
  }

  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
    return op->emitOpError("must have operands with tosa shape type");
  }

  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
    return op->emitOpError("must have result with tosa shape type");
  }

  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }

LogicalResult tosa::ConstShapeOp::verify() {
  // Check that the attribute values are one dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // Check that the element count matches the rank of the result shape type.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (count != rank && (count != 1 || rank != 0)) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
true
Given two iterators into the same block, return "true" if a is before `b.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
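A sketch of operand parsing and resolution using the hooks above; the two-operand, single-type grammar is illustrative.

#include "mlir/IR/OpImplementation.h"
using namespace mlir;

// Hypothetical grammar: `%lhs, %rhs : type`
static ParseResult parseBinaryLike(OpAsmParser &parser, OperationState &state) {
  OpAsmParser::UnresolvedOperand lhs, rhs;
  Type type;
  if (parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs) || parser.parseColonType(type))
    return failure();
  // Turn the textual operands into SSA values of the parsed type.
  if (parser.resolveOperand(lhs, type, state.operands) ||
      parser.resolveOperand(rhs, type, state.operands))
    return failure();
  state.addTypes(type);
  return success();
}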
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
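A sketch of the corresponding printer side, driving the OpAsmPrinter hooks above from a generic Operation; the function is illustrative.

#include "mlir/IR/OpImplementation.h"
using namespace mlir;

static void printGeneric(OpAsmPrinter &p, Operation *op) {
  p << " ";
  p.printOperands(op->getOperands());
  p.printOptionalAttrDict(op->getAttrs());
  p << " : ";
  p.printFunctionalType(op);
}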
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
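A sketch of a free-standing, verifier-style check built on the Operation accessors above; the constraint itself is an illustrative assumption.

#include "mlir/IR/Operation.h"
using namespace mlir;

static LogicalResult verifySameOperandAndResultCount(Operation *op) {
  if (op->getNumOperands() != op->getNumResults())
    return op->emitOpError("expected matching operand and result counts");
  return success();
}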
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
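A sketch of the usual way an OptionalParseResult is consumed; parseOptionalInteger stands in for any parseOptional* hook and is not specific to this file.

#include "mlir/IR/OpImplementation.h"
using namespace mlir;

static ParseResult parseOptionalCount(OpAsmParser &parser, int64_t &count) {
  OptionalParseResult maybe = parser.parseOptionalInteger(count);
  if (!maybe.has_value())
    return success();          // construct was absent
  return maybe.value();        // present: propagate success/failure
}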
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
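A sketch of querying a ShapeAdaptor in the style of shape inference code; the helper name is illustrative.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
using namespace mlir;

// Returns the static element count, or ShapedType::kDynamic when the shape
// is unranked or has dynamic dimensions.
static int64_t staticElementCount(ShapeAdaptor shape) {
  if (!shape.hasRank() || !shape.hasStaticShape())
    return ShapedType::kDynamic;
  return shape.getNumElements();
}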
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
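A sketch of a lookup through the SymbolTable API above; the wrapper function is illustrative.

#include "mlir/IR/SymbolTable.h"
using namespace mlir;

// Look up an operation by symbol name inside an op that defines a symbol
// table (e.g. a module); returns null if the name is not found.
static Operation *lookupSymbol(Operation *symbolTableOp, StringRef name) {
  SymbolTable table(symbolTableOp);
  return table.lookup(name);
}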
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
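A sketch of reading a tensor shape defensively with the TensorType accessors above; the helper is illustrative.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
using namespace mlir;

// Returns the shape of `v` only when its type is a ranked tensor.
static std::optional<SmallVector<int64_t>> getRankedShape(Value v) {
  auto tensorTy = dyn_cast<TensorType>(v.getType());
  if (!tensorTy || !tensorTy.hasRank())
    return std::nullopt;
  return llvm::to_vector(tensorTy.getShape());
}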
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
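A sketch of width-based dispatch using the Type queries above; the predicate and the 32-bit threshold are illustrative.

#include "mlir/IR/Types.h"
using namespace mlir;

// True for integer types wider than 32 bits.
static bool isWideInteger(Type type) {
  return type.isInteger() && type.getIntOrFloatBitWidth() > 32;
}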
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperator(Operation *op)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
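A sketch of the broadcast helper, assuming this entry refers to OpTrait::util::getBroadcastedShape from mlir/Dialect/Traits.h; the shapes are illustrative.

#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;

// {1, 3} and {2, 1} broadcast to {2, 3}; incompatible static dimensions
// (e.g. 3 vs 4) make the function return false.
static bool broadcastExample() {
  SmallVector<int64_t> result;
  return OpTrait::util::getBroadcastedShape({1, 3}, {2, 1}, result);
}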
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with the correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
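A sketch of the dynamic-dimension convention this helper presumably encapsulates (the inverse of the convertToMlirShape mapping defined earlier in this file): TOSA encodes a dynamic dimension as -1 where MLIR's ShapedType uses kDynamic. The standalone function below is illustrative, not the actual implementation.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;

static SmallVector<int64_t> fromMlirShape(ArrayRef<int64_t> shape) {
  return llvm::to_vector(llvm::map_range(shape, [](int64_t dim) {
    return ShapedType::isDynamic(dim) ? int64_t(-1) : dim;
  }));
}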
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
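A short sketch of the compatibility helpers above, assuming the declarations in mlir/IR/TypeUtilities.h; the example shapes are illustrative.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include <cassert>
using namespace mlir;

static void compatibilityExamples(Type tensorOrScalar) {
  // Element type of a shaped type, or the type itself for scalars.
  Type elemTy = getElementTypeOrSelf(tensorOrScalar);
  (void)elemTy;
  // A dynamic dimension is compatible with any static extent.
  bool ok  = succeeded(verifyCompatibleShape({2, ShapedType::kDynamic}, {2, 4}));
  bool bad = succeeded(verifyCompatibleShape({2, 3}, {2, 4}));
  assert(ok && !bad);
  (void)ok; (void)bad;
}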
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
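A sketch of the constant-matching idiom built from matchPattern and m_Constant; the wrapper function is illustrative.

#include "mlir/IR/Matchers.h"
using namespace mlir;

// Returns true and binds `elements` if `v` is produced by a constant-foldable
// op whose folded value is a DenseElementsAttr.
static bool getConstantElements(Value v, DenseElementsAttr &elements) {
  return matchPattern(v, m_Constant(&elements));
}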
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)