29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
39#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
49#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
52#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
57struct TosaInlinerInterface :
public DialectInlinerInterface {
58 using DialectInlinerInterface::DialectInlinerInterface;
66 IRMapping &map)
const final {
72 IRMapping &map)
const final {
73 return (isa<tosa::IfOp>(dest->getParentOp()) ||
74 isa<tosa::WhileOp>(dest->getParentOp()));
79struct TosaDialectBytecodeInterface :
public BytecodeDialectInterface {
80 TosaDialectBytecodeInterface(Dialect *dialect)
81 : BytecodeDialectInterface(dialect) {}
86 Attribute readAttribute(DialectBytecodeReader &reader)
const override {
90 LogicalResult writeAttribute(Attribute attr,
91 DialectBytecodeWriter &writer)
const override {
92 return ::writeAttribute(attr, writer);
98 Type readType(DialectBytecodeReader &reader)
const override {
102 LogicalResult writeType(Type type,
103 DialectBytecodeWriter &writer)
const override {
104 return ::writeType(type, writer);
107 void writeVersion(DialectBytecodeWriter &writer)
const final {
111 std::unique_ptr<DialectVersion>
112 readVersion(DialectBytecodeReader &reader)
const final {
114 reader.
emitError(
"Dialect does not support versioning");
118 LogicalResult upgradeFromVersion(Operation *topLevelOp,
119 const DialectVersion &version)
const final {
132 return {&getBodyGraph()};
141 return dim == -1 ? ShapedType::kDynamic : dim;
147 Type elementType = variableOp.getType();
150 return RankedTensorType::get(
shape, elementType);
157void TosaDialect::initialize() {
159#define GET_TYPEDEF_LIST
160#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
164#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
167#define GET_ATTRDEF_LIST
168#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
170 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
171 declarePromisedInterfaces<
172 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
173 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
174 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
175 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
176 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
177 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
178 GreaterEqualOp, MatMulOp>();
185 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
186 return tosa::ConstShapeOp::create(builder, loc, type,
187 llvm::cast<DenseIntElementsAttr>(value));
189 if (llvm::isa<ElementsAttr>(value))
190 return tosa::ConstOp::create(builder, loc, type,
191 llvm::cast<ElementsAttr>(value));
201ParseResult getShapeAndElementType(
OpAsmParser &parser,
Type parsedType,
203 TypeAttr &typeAttr) {
204 if (
auto shapedType = dyn_cast<ShapedType>(parsedType)) {
205 if (!shapedType.hasRank())
207 <<
"expected ranked type";
209 auto elementType = shapedType.getElementType();
210 typeAttr = TypeAttr::get(elementType);
217 <<
"expected shaped type";
234 <<
"expected attribute";
236 if (
auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
237 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
241 <<
"expected Typed attr";
244 initialValueAttr =
nullptr;
248 <<
"expected type after colon";
250 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
255 TypeAttr typeAttr,
Attribute initialValueAttr) {
256 bool needsSpace =
false;
257 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
260 Type elementType = typeAttr.getValue();
261 RankedTensorType tensorType =
263 auto tensorTypeAttr = TypeAttr::get(tensorType);
268 if (initialValueAttr) {
279template <
typename EnumType>
280ParseResult parseAttrEntryWithEnumHandling(
OpAsmParser &parser,
282 llvm::StringRef name;
289 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
290 if (name ==
"rounding_mode" &&
292 auto sym = symbolizeRoundingMode(kw);
295 <<
"invalid rounding_mode value: " << kw;
296 auto attr = RoundingModeAttr::get(parser.
getContext(), sym.value());
302 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
304 auto sym = symbolizeResizeMode(kw);
307 <<
"invalid resize mode value: " << kw;
308 auto attr = ResizeModeAttr::get(parser.
getContext(), sym.value());
315 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
317 auto sym = symbolizeNanPropagationMode(kw);
320 <<
"invalid nan_mode value: " << kw;
321 auto attr = NanPropagationModeAttr::get(parser.
getContext(), sym.value());
328 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
330 auto sym = symbolizeBlockSize(kw);
333 <<
"invalid block_size value: " << kw;
334 auto attr = BlockSizeAttr::get(parser.
getContext(), sym.value());
346template <
typename EnumType>
351 [&]() { return parser.parseOperand(operands.emplace_back()); }))
359 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
376 result.addTypes(fnTy.getResults());
377 result.addAttributes(attrs);
383 parser << namedAttr.
getName().strref() <<
" = ";
385 if (
auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
386 parser << roundingModeAttr.getValue();
387 }
else if (
auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
388 parser << resizeModeAttr.getValue();
389 }
else if (
auto nanPropagationModeAttr =
390 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
391 parser << nanPropagationModeAttr.getValue();
392 }
else if (
auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
393 parser << blockSizeAttr.getValue();
406 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
408 if (
auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
409 if (nanAttr.getValue() == kDefaultNanValue) {
411 toPrint.erase(attr.getName());
417 if (!toPrint.empty()) {
419 llvm::interleaveComma(toPrint, parser, [&](
const NamedAttribute namedAttr) {
420 printNamedAttr(parser, namedAttr);
436 llvm::interleaveComma(op->
getAttrs(), parser,
438 printNamedAttr(parser, namedAttr);
450 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
454 printWithEnumHandling(parser, *
this);
458 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
462 printWithEnumHandling(parser, *
this);
466 return parseWithEnumHandling<tosa::ResizeMode>(parser,
result);
470 printWithEnumHandling(parser, *
this);
474 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
478 printWithNanPropagationHandling(parser, *
this);
482 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
486 printWithNanPropagationHandling(parser, *
this);
489ParseResult MaxPool2dAdaptiveOp::parse(
OpAsmParser &parser,
491 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
495 printWithNanPropagationHandling(parser, *
this);
499 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
503 printWithNanPropagationHandling(parser, *
this);
507 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
511 printWithNanPropagationHandling(parser, *
this);
515 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
519 printWithNanPropagationHandling(parser, *
this);
523 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
527 printWithNanPropagationHandling(parser, *
this);
531 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
535 printWithNanPropagationHandling(parser, *
this);
538ParseResult MatmulTBlockScaledOp::parse(
OpAsmParser &parser,
540 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
544 printWithEnumHandling(parser, *
this);
547ParseResult CastFromBlockScaledOp::parse(
OpAsmParser &parser,
549 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
552void CastFromBlockScaledOp::print(
OpAsmPrinter &parser) {
553 printWithEnumHandling(parser, *
this);
556ParseResult CastToBlockScaledOp::parse(
OpAsmParser &parser,
558 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
562 printWithEnumHandling(parser, *
this);
565ParseResult Conv2DBlockScaledOp::parse(
OpAsmParser &parser,
567 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
571 printWithEnumHandling(parser, *
this);
586 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
596 Value valZp, StringRef name) {
601 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
605 if (!bothInts || !sameBitWidth) {
607 <<
"expected " << name <<
" and " << name
608 <<
"_zp to both be integer of the same bitwidth, but got " << eType
609 <<
" vs. " << eZpType;
616 Value src, int32_t val) {
619 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
620 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
621 const auto padConstAttr{
622 llvm::isa<FloatType>(srcElemType)
627 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
631 if (dyn_cast<tosa::mxint8Type>(type))
640 const StringRef operandName,
641 const StringRef dimName) {
642 if (ShapedType::isDynamic(currDim)) {
645 }
else if (ShapedType::isStatic(newDim) && currDim != newDim) {
647 << dimName <<
" of " << operandName <<
" to match size " << currDim
648 <<
", got " << newDim;
655 auto printDim = [&](
int64_t dim) {
656 if (ShapedType::isDynamic(dim))
662 llvm::interleaveComma(
shape,
diag, printDim);
668 StringRef outputName =
"output") {
669 assert(outputType.hasRank() &&
"expected output type to be ranked");
675 diag << outputName <<
" shape ";
677 diag <<
" to be compatible with inferred shape ";
685 const int64_t stride,
const int64_t dilation,
const llvm::StringRef dimName,
686 const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
687 const llvm::StringRef padAfterName) {
688 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
693 const std::optional<int64_t> calculatedOutSizeMinusOne =
idivCheck(
694 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
696 if (!calculatedOutSizeMinusOne.has_value())
698 << dimName <<
" - 1 + pad_" << padBeforeName <<
" + pad_"
699 << padAfterName <<
" - (kernel_" << dimName <<
" - 1) * dilation_"
700 << dimAxis <<
" to be wholly divisible by stride_" << dimAxis
701 <<
", got (" << inputSize <<
" - 1 + " << padBefore <<
" + "
702 << padAfter <<
" - (" << kernelSize <<
" - 1) * " << dilation
705 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
706 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
708 << dimName <<
" did not match expected: "
709 <<
"calculated=" << calculatedOutSize <<
", expected=" << outputSize;
720 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
721 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
723 auto inputEType = inputType.getElementType();
724 auto weightEType = weightType.getElementType();
726 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
728 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
729 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
730 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
732 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
735 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
738 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
741 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
744 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
748 "expect both bias and result to have same element type, got ")
749 << biasEType <<
" and " << resultEType;
753 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
754 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
755 if (inputEType != weightEType) {
757 "expect both input and weight to have same element type, got ")
758 << inputEType <<
" and " << weightEType;
763 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
764 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
767 if (inputIsFloat != weightIsFloat) {
769 "expect both input and weight to be float or not together, got ")
770 << inputEType <<
" and " << weightEType;
775 if (inputEType != inputZpEType) {
776 return op.emitOpError(
"expect both input and its zero point are the same "
777 "element type, got ")
778 << inputEType <<
" and " << inputZpEType;
782 if (weightEType != weightZpEType) {
783 return op.emitOpError(
"expect both weight and its zero point are the same "
784 "element type, got ")
785 << weightEType <<
" and " << weightZpEType;
788 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
789 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
792 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
793 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
799LogicalResult tosa::ConstOp::verify() {
801 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().
getType());
802 auto outputType = llvm::dyn_cast<TensorType>(getOutput().
getType());
804 if (!attrType || !outputType) {
805 emitOpError(
"expected tensors for attr/result type");
809 if (
auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
810 outputType.getElementType())) {
815 if (attrType.getElementType() != outputType.getElementType()) {
816 emitOpError(
"expected same attr/result element types");
826 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
828 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
831 auto accType = op.getAccType();
832 if (inputEType.isInteger(8) && !accType.isInteger(32))
833 return op.emitOpError(
"accumulator type for i8 tensor is not i32, got ")
836 if (inputEType.isInteger(16) && !accType.isInteger(48))
837 return op.emitOpError(
"accumulator type for i16 tensor is not i48, got ")
840 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
841 !(accType.isF16() || accType.isF32()))
842 return op.emitOpError(
"accumulator type for f8 tensor is not f16/f32, got ")
845 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
846 return op.emitOpError(
847 "accumulator type for f16 tensor is not f16/f32, got ")
850 if (inputEType.isBF16() && !accType.isF32())
851 return op.emitOpError(
"accumulator type for bf16 tensor is not f32, got ")
854 if (inputEType.isF32() && !accType.isF32())
855 return op.emitOpError(
"accumulator type for f32 tensor is not f32, got ")
859 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
861 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
875 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
876 return op.emitOpError(
"expect all padding values to be >= 0, got ")
880 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
881 return op.emitOpError(
"expect all stride values to be >= 1, got ")
885 if (llvm::any_of(dilations, [](
int64_t d) {
return d < 1; }))
886 return op.emitOpError(
"expect all dilation values to be >= 1, got ")
889 const RankedTensorType outputType =
890 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
895 const RankedTensorType inputType =
896 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
897 const RankedTensorType weightType =
898 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
900 if (inputType && weightType) {
902 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
904 op, inputType.getDimSize(1), weightType.getDimSize(1),
905 outputType.getDimSize(1), padding[0], padding[1], strides[0],
906 dilations[0],
"height",
"y",
"top",
"bottom")))
910 op, inputType.getDimSize(2), weightType.getDimSize(2),
911 outputType.getDimSize(2), padding[2], padding[3], strides[1],
912 dilations[1],
"width",
"x",
"left",
"right")))
917 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
919 op, inputType.getDimSize(1), weightType.getDimSize(0),
920 outputType.getDimSize(1), padding[0], padding[1], strides[0],
921 dilations[0],
"height",
"y",
"top",
"bottom")))
925 op, inputType.getDimSize(2), weightType.getDimSize(1),
926 outputType.getDimSize(2), padding[2], padding[3], strides[1],
927 dilations[1],
"width",
"x",
"left",
"right")))
932 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
934 op, inputType.getDimSize(1), weightType.getDimSize(1),
935 outputType.getDimSize(1), padding[0], padding[1], strides[0],
936 dilations[0],
"depth",
"d",
"front",
"back")))
940 op, inputType.getDimSize(2), weightType.getDimSize(2),
941 outputType.getDimSize(2), padding[2], padding[3], strides[1],
942 dilations[1],
"height",
"y",
"top",
"bottom")))
946 op, inputType.getDimSize(3), weightType.getDimSize(3),
947 outputType.getDimSize(3), padding[4], padding[5], strides[2],
948 dilations[2],
"width",
"x",
"left",
"right")))
953 const RankedTensorType biasType =
954 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
959 const int64_t biasChannels = biasType.getDimSize(0);
961 outputType.getDimSize(outputType.getRank() - 1);
962 if (biasChannels == ShapedType::kDynamic ||
963 outputChannels == ShapedType::kDynamic)
967 if (biasChannels != outputChannels && biasChannels != 1)
968 return op.emitOpError(
969 "bias channels expected to be equal to output channels (")
970 << outputChannels <<
") or 1, got " << biasChannels;
977 StringRef name1,
Type type2,
979 auto shapeType1 = dyn_cast<ShapedType>(type1);
980 auto shapeType2 = dyn_cast<ShapedType>(type2);
981 if (!shapeType1 || !shapeType2)
984 auto elemType1 = shapeType1.getElementType();
985 auto elemType2 = shapeType2.getElementType();
986 if (elemType1 != elemType2)
988 <<
"require same element type for " << name1 <<
" (" << elemType1
989 <<
") and " << name2 <<
" (" << elemType2 <<
")";
993 <<
"require same shapes for " << name1 <<
" (" << type1 <<
") and "
994 << name2 <<
" (" << type2 <<
")";
1004 if (list1.size() != list2.size())
1006 <<
"require same number of values in " << name1 <<
" ("
1007 << list1.size() <<
") and " << name2 <<
" (" << list2.size() <<
")";
1009 for (
auto [type1, type2] :
1026template <
typename T>
1029 op->template getParentWithTrait<OpTrait::SymbolTable>();
1036 const auto varOp = symTable.
lookup<tosa::VariableOp>(op.getName());
1040 return op->emitOpError(
"'")
1041 << op.getName() <<
"' has not been declared by 'tosa.variable'";
1053template <
typename T>
1055 StringRef aName =
"input",
1056 StringRef bName =
"output") {
1057 auto aTType = llvm::dyn_cast<TensorType>(aType);
1058 auto bTType = llvm::dyn_cast<TensorType>(bType);
1060 op.emitOpError(
"expect shaped tensor for") << aName <<
", got " << aType;
1064 op.emitOpError(
"expect shaped tensor for") << bName <<
", got" << bType;
1067 auto aElementType = aTType.getElementType();
1068 auto bElementType = bTType.getElementType();
1070 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1072 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1073 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1074 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1075 aElementType != bElementType) {
1080 op.emitOpError(
"expect ")
1081 << aName <<
" and " << bName <<
" to have same element type, got "
1082 << aElementType <<
" and " << bElementType;
1088LogicalResult tosa::ArgMaxOp::verify() {
1089 const ShapedType resultType = llvm::cast<ShapedType>(
getType());
1092 if (
const auto resultETy = resultType.getElementType();
1093 !resultETy.isIntOrIndex())
1094 return emitOpError(
"result tensor is not of integer type");
1096 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
1097 if (!inputType.hasRank())
1101 const int64_t axis = getAxisAttr().getInt();
1102 if (((axis < 0) || axis >= inputType.getRank()))
1103 return emitOpError(
"specified axis is outside the rank of the tensor");
1105 if (!resultType.hasRank())
1111 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1114 << expectedOutputShape <<
"', got '" << outputShape <<
"'";
1124 const bool hasKernel = kernel.size() > 0;
1125 const bool hasStrides = strides.size() > 0;
1126 const bool hasPad = padding.size() > 0;
1128 if (hasKernel && llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1129 return op->
emitOpError(
"expect all kernel values to be >= 1, got ")
1132 if (hasStrides && llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1133 return op->
emitOpError(
"expect all stride values to be >= 1, got ")
1136 if (hasPad && llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1137 return op->
emitOpError(
"expect all padding values to be >= 0, got ")
1140 if (hasKernel && hasPad) {
1142 const int64_t kernelX = kernel[1];
1143 const int64_t padLeft = padding[2];
1144 const int64_t padRight = padding[3];
1145 if (padRight >= kernelX || padLeft >= kernelX)
1146 return op->
emitOpError(
"expected left/right padding to be less than the "
1147 "width of the kernel, got pad_left=")
1148 << padLeft <<
", pad_right=" << padRight
1149 <<
", kernel_x=" << kernelX;
1151 const int64_t kernelY = kernel[0];
1152 const int64_t padTop = padding[0];
1153 const int64_t padBottom = padding[1];
1154 if (padTop >= kernelY || padBottom >= kernelY)
1155 return op->
emitOpError(
"expected top/bottom padding to be less than the "
1156 "height of the kernel, got pad_top=")
1157 << padTop <<
", pad_bottom=" << padBottom
1158 <<
", kernel_y=" << kernelY;
1161 const auto inputType = llvm::dyn_cast<RankedTensorType>(input.
getType());
1162 const auto outputType = llvm::dyn_cast<RankedTensorType>(output.
getType());
1163 if (!inputType || !outputType)
1166 if (hasKernel && hasStrides && hasPad) {
1167 const auto verifyOutputSize =
1171 const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
1172 const llvm::StringRef padBeforeName,
1173 const llvm::StringRef padAfterName) -> LogicalResult {
1174 if (ShapedType::isDynamic(inputSize))
1177 const std::optional<int64_t> calculatedOutSizeMinusOne =
1178 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1179 if (!calculatedOutSizeMinusOne.has_value())
1181 << dimName <<
" + pad_" << padBeforeName <<
" + pad_"
1182 << padAfterName <<
" - kernel_" << dimAxis
1183 <<
" to be wholly divisible by stride_" << dimAxis <<
", got ("
1184 << inputSize <<
" + " << padBefore <<
" + " << padAfter <<
" - "
1185 << kernelSize <<
") / " << strideSize;
1187 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1188 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1190 << dimName <<
" did not match expected: " <<
"calculated="
1191 << calculatedOutSize <<
", expected=" << outputSize;
1196 if (failed(verifyOutputSize(inputType.getDimSize(1),
1197 outputType.getDimSize(1), kernel[0], strides[0],
1198 padding[0], padding[1],
"height",
"y",
"top",
1202 if (failed(verifyOutputSize(
1203 inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
1204 strides[1], padding[2], padding[3],
"width",
"x",
"left",
"right")))
1210template <
typename T>
1213 op.getPad(), op.getInput(), op.getOutput());
1216template <
typename T>
1220 const Type inputZpETy =
1222 const Type outputZpETy =
1225 auto accType = op.getAccType();
1226 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1227 return op.emitOpError(
"accumulator type for integer tensor is not i32");
1229 if (inputETy.
isF16() && !(accType.isF16() || accType.isF32()))
1230 return op.emitOpError(
"accumulator type for f16 tensor is not f16/f32");
1232 if (inputETy.
isBF16() && !accType.isF32())
1233 return op.emitOpError(
"accumulator type for bf16 tensor is not f32");
1235 if (inputETy.
isF32() && !accType.isF32())
1236 return op.emitOpError(
"accumulator type for f32 tensor is not f32");
1238 if (inputETy != inputZpETy)
1239 return op.emitOpError(
"expect both input and its zero point are the same "
1240 "element type, got ")
1241 << inputETy <<
" and " << inputZpETy;
1243 if (resultETy != outputZpETy)
1244 return op.emitOpError(
"expect both output and its zero point are the same "
1245 "element type, got ")
1246 << resultETy <<
" and " << outputZpETy;
1248 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1249 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
1252 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1253 if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
1260struct AdaptivePoolingConstShapeValues {
1261 llvm::SmallVector<int64_t> kernel;
1262 llvm::SmallVector<int64_t> stride;
1263 llvm::SmallVector<int64_t> pad;
1267template <
typename T>
1269 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
1270 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
1272template <
typename T,
1273 typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
1276 T op, AdaptivePoolingConstShapeValues &values) {
1282LogicalResult tosa::AvgPool2dOp::verify() {
1290LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
1291 AdaptivePoolingConstShapeValues values;
1300 values.pad, getInput(), getOutput())))
1309LogicalResult tosa::ClampOp::verify() {
1311 llvm::cast<ShapedType>(getInput().
getType()).getElementType();
1312 if (
auto quantType =
1313 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1317 llvm::cast<ShapedType>(getOutput().
getType()).getElementType();
1318 if (
auto quantType =
1319 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1322 if (inputETy != outputETy)
1323 return emitOpError(
"input/output element types are incompatible.");
1325 auto maxValAttr = getMaxValAttr();
1326 auto minValAttr = getMinValAttr();
1330 if (inputETy.
isInteger(dataTypeBitWidth)) {
1334 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1335 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1336 if (!intMaxValAttr || !intMinValAttr ||
1337 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1338 (intMaxValAttr.getType() != inputETy))
1339 return emitOpError(
"min/max attributes types are incompatible with "
1340 "input/output element types.");
1343 const bool isBoolean = inputETy.
isInteger(1);
1344 const APInt minVal = intMinValAttr.getValue();
1345 const APInt maxVal = intMaxValAttr.getValue();
1346 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1347 return emitOpError(
"expected min_val <= max_val, got min_val=")
1348 << minValAttr <<
", max_val=" << maxValAttr;
1353 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1354 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1355 if (!floatMaxValAttr || !floatMinValAttr ||
1356 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1357 (floatMaxValAttr.getType() != inputETy))
1358 return emitOpError(
"min/max attributes types are incompatible with "
1359 "input/output element types.");
1361 const APFloat minVal = floatMinValAttr.getValue();
1362 const APFloat maxVal = floatMaxValAttr.getValue();
1363 if (minVal.isNaN() || maxVal.isNaN())
1364 return emitOpError(
"min/max attributes should not be 'NaN', got min_val=")
1365 << minValAttr <<
", max_val=" << maxValAttr;
1367 if (maxVal < minVal)
1368 return emitOpError(
"expected min_val <= max_val, got min_val=")
1369 << minValAttr <<
", max_val=" << maxValAttr;
1389 result.addOperands({input, weight, bias, zps.first, zps.second});
1390 result.addAttribute(
"pad", pad);
1391 result.addAttribute(
"stride", stride);
1392 result.addAttribute(
"dilation", dilation);
1393 result.addAttribute(
"acc_type", accType);
1394 Type finalOutputType = outputType;
1400 result.addTypes(finalOutputType);
1411 result.addOperands({input, weight, bias, zps.first, zps.second});
1412 result.addAttribute(
"out_pad", outpad);
1413 result.addAttribute(
"stride", stride);
1414 result.addAttribute(
"acc_type", accType);
1415 Type finalOutputType = outputType;
1421 result.addTypes(finalOutputType);
1432 result.addOperands({a,
b, zps.first, zps.second});
1434 Type finalOutputType{outputType};
1437 auto inputBits = eType.getIntOrFloatBitWidth();
1439 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1440 assert(outputShapedType &&
"Output must be a shaped type");
1442 IntegerType accElementType;
1443 if (inputBits == 16)
1448 finalOutputType = outputShapedType.clone(accElementType);
1450 result.addTypes(finalOutputType);
1459 DenseArrayAttr kernel, DenseArrayAttr stride,
1460 DenseArrayAttr pad, TypeAttr accType) {
1465 if (
auto quantAttr =
1467 inputZp = quantAttr.getInputZp();
1468 outputZp = quantAttr.getOutputZp();
1470 const std::optional<Value> inputZpOp =
1475 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1477 const std::optional<Value> outputZpOp =
1480 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1481 "quantized AVG_POOL2D op");
1484 if (inputZpOp && outputZpOp) {
1485 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1490 result.addOperands({input});
1492 result.addAttribute(
"kernel", kernel);
1493 result.addAttribute(
"stride", stride);
1494 result.addAttribute(
"pad", pad);
1495 result.addAttribute(
"acc_type", accType);
1496 result.types.push_back(outputType);
1509 if (
auto quantAttr =
1511 inputZp = quantAttr.getInputZp();
1512 outputZp = quantAttr.getOutputZp();
1514 const std::optional<Value> inputZpOp =
1518 "Failed to create input zero point tensor for quantized "
1519 "AVG_POOL2D_ADAPTIVE op");
1521 const std::optional<Value> outputZpOp =
1524 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1525 "quantized AVG_POOL2D_ADAPTIVE op");
1528 if (inputZpOp && outputZpOp) {
1533 result.addOperands({input, inputZpOp.value(), outputZpOp.value(),
1534 kernelShape, strideShape, padShape});
1539 result.addOperands({input});
1541 result.addAttribute(
"acc_type", accType);
1542 result.types.push_back(outputType);
1556 input1Zp = quantAttr.getInputZp();
1557 outputZp = quantAttr.getOutputZp();
1559 const std::optional<Value> input1ZpOp =
1563 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1566 const std::optional<Value> outputZpOp =
1570 loc,
"Failed to create output zero point for quantized NEGATE op");
1573 if (input1ZpOp && outputZpOp) {
1574 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1579 result.addOperands({input});
1582 result.types.push_back(outputType);
1595 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1598 result.addOperands({input, paddings, padConstOp});
1599 result.types.push_back(outputType);
1603 StringRef name,
Type variableType,
1608 auto shapedType = dyn_cast<ShapedType>(variableType);
1610 (
void)
emitError(loc,
"variable type must be a shaped type");
1613 if (!shapedType.hasRank()) {
1614 (
void)
emitError(loc,
"variable type must be a ranked type");
1618 auto elementType = shapedType.getElementType();
1619 auto elementTypeAttr = TypeAttr::get(elementType);
1623 result.addAttribute(
"sym_name", nameAttr);
1624 result.addAttribute(
"var_shape", varShapeAttr);
1625 result.addAttribute(
"type", elementTypeAttr);
1626 result.addAttribute(
"initial_value", initialValue);
1639 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1643 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1649 for (
int i = 0, e = operands.size(); i != e; ++i) {
1651 if (!
shape.hasRank()) {
1656 outRank = std::max<int64_t>(outRank,
shape.getRank());
1659 outShape.resize(outRank, 1);
1661 for (
int i = 0, e = operands.size(); i != e; ++i) {
1663 auto rankDiff = outShape.size() -
shape.getRank();
1665 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1666 auto dim1 = outShape[i + rankDiff];
1667 auto dim2 =
shape.getDimSize(i);
1669 const FailureOr<int64_t> maybeResolvedDim =
1671 if (failed(maybeResolvedDim))
1673 const int64_t resolvedDim = *maybeResolvedDim;
1674 outShape[i + rankDiff] = resolvedDim;
1681LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1682 MLIRContext *context, ::std::optional<Location> location,
1683 ArgMaxOp::Adaptor adaptor,
1686 IntegerAttr axis = adaptor.getProperties().axis;
1687 int32_t axisVal = axis.getValue().getSExtValue();
1689 if (!inputShape.hasRank()) {
1695 outShape.reserve(inputShape.getRank() - 1);
1696 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
1699 outShape.push_back(inputShape.getDimSize(i));
1706LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1707 MLIRContext *context, ::std::optional<Location> location,
1708 RFFT2dOp::Adaptor adaptor,
1710 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1712 if (!inputShape.hasRank())
1716 outputShape.resize(3, ShapedType::kDynamic);
1717 outputShape[0] = inputShape.getDimSize(0);
1718 outputShape[1] = inputShape.getDimSize(1);
1719 int64_t inWidth = inputShape.getDimSize(2);
1723 if (inWidth != ShapedType::kDynamic)
1724 outputShape[2] = inWidth / 2 + 1;
1733 const llvm::StringRef dimName) {
1734 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1737 << dimName <<
" to be a power of two, got " << dimSize;
1742LogicalResult tosa::RFFT2dOp::verify() {
1743 const auto outputTypes = getResultTypes();
1745 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
1747 const auto inputType =
1748 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1752 const int64_t height = inputType.getDimSize(1);
1753 if (ShapedType::isStatic(height) &&
1757 const int64_t width = inputType.getDimSize(2);
1758 if (ShapedType::isStatic(width) &&
1762 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1768 outputType.getShape().drop_back())))
1769 return emitOpError(
"expected batch and height dimensions of input/output "
1770 "to match, got input=")
1771 << inputType <<
" output=" << outputType;
1774 const int64_t outputWidth = outputType.getDimSize(2);
1775 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1776 (outputWidth != (width / 2) + 1))
1778 "expected output width to be equal to input_width / 2 + 1, got ")
1784LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1785 MLIRContext *context, ::std::optional<Location> location,
1786 FFT2dOp::Adaptor adaptor,
1788 inferredReturnShapes.push_back(
1790 inferredReturnShapes.push_back(
1795LogicalResult tosa::FFT2dOp::verify() {
1796 const auto inputRealType =
1797 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1798 const auto inputImagType =
1799 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1800 if (!inputRealType || !inputImagType)
1803 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1804 return ShapedType::isDynamic(a) ? a :
b;
1807 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1808 inputImagType.getDimSize(1));
1809 if (ShapedType::isStatic(height) &&
1813 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1814 inputImagType.getDimSize(2));
1815 if (ShapedType::isStatic(width) &&
1822LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1823 MLIRContext *context, ::std::optional<Location> location,
1824 ConcatOp::Adaptor adaptor,
1827 const Properties &prop = adaptor.getProperties();
1828 int32_t axis = prop.axis.getValue().getSExtValue();
1830 bool hasRankedInput =
false;
1831 for (
auto operand : adaptor.getOperands()) {
1833 if (!operandShape.hasRank())
1837 if (!hasRankedInput)
1838 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1841 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
1842 if (i == axis || operandShape.isDynamicDim(i))
1844 if (outputShape[i] == ShapedType::kDynamic)
1845 outputShape[i] = operandShape.getDimSize(i);
1846 if (outputShape[i] != operandShape.getDimSize(i))
1848 "Cannot concat tensors with different sizes"
1849 " on the non-axis dimension ",
1853 hasRankedInput =
true;
1856 if (adaptor.getInput1().empty())
1860 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1861 if (!hasRankedInput) {
1868 for (
auto operand : adaptor.getOperands()) {
1873 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1874 concatDimSize = ShapedType::kDynamic;
1878 concatDimSize += operandShape.getDimSize(axis);
1881 outputShape[axis] = concatDimSize;
1887LogicalResult tosa::ConcatOp::verify() {
1889 auto outType = getOutput().getType();
1893 if (inputList.empty())
1896 if (!llvm::all_of(inputList, [&](
auto input) {
1898 *
this, input.getType(), outType));
1903 const int32_t axis = getAxis();
1905 for (
const auto &input : inputList) {
1906 const Type inputType = input.getType();
1908 if (currShape.hasRank()) {
1909 firstRankedInputShape = currShape;
1911 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1912 return emitOpError(
"expect axis to be within range 0 < axis < "
1913 "rank(input1[firstRankedTensorIdx]), got ")
1919 const auto allOperandsHasRank = [](
const Value input) {
1922 if (llvm::all_of(inputList, allOperandsHasRank)) {
1925 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1927 const int64_t inputRank = inputShape.getRank();
1928 const size_t operandNum =
index + 1;
1931 if (inputRank != firstInputRank)
1933 "expect all operands to have the same rank, but got ")
1934 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1938 for (
int i = 0; i < inputRank; i++) {
1939 const int64_t inputDim = inputShape.getDimSize(i);
1941 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1942 inputShape.isDynamicDim(i))
1944 if (inputDim != firstInputDim)
1945 return emitOpError(
"expect all operand shapes to have the same sizes "
1946 "on non-axis dimensions, but got ")
1947 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1948 <<
" on operands 0 and " << operandNum;
1953 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1954 return emitOpError(
"expect output rank to match inputs rank, got ")
1955 << outputShape.getRank() <<
" vs " << firstInputRank;
1959 for (
const auto &input : inputList) {
1961 if (inputShape.isDynamicDim(axis)) {
1966 axisSum += inputShape.getDimSize(axis);
1969 if (axisSum >= 0 && outputShape.hasRank() &&
1970 !outputShape.isDynamicDim(axis) &&
1971 axisSum != outputShape.getDimSize(axis))
1972 return emitOpError(
"requires sum of axis dimensions of input1 "
1973 "equal to output axis dimension, got ")
1974 << axisSum <<
" and " << outputShape.getDimSize(axis);
1980LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1981 MLIRContext *context, ::std::optional<Location> location,
1985 auto elementType = IntegerType::get(context, 1);
1998 if (l.size() != r.size() || l.size() != 1)
2003LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
2004 MLIRContext *context, ::std::optional<Location> location,
2005 MatMulOp::Adaptor adaptor,
2012 outShape.resize(3, ShapedType::kDynamic);
2014 if (lhsShape.hasRank()) {
2015 outShape[0] = lhsShape.getDimSize(0);
2016 outShape[1] = lhsShape.getDimSize(1);
2019 if (rhsShape.hasRank()) {
2020 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
2022 outShape[2] = rhsShape.getDimSize(2);
2029LogicalResult MatMulOp::verify() {
2032 const Type aElementType = aShape.getElementType();
2033 const Type bElementType = bShape.getElementType();
2035 const auto aQuantizedEType =
2036 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
2037 const auto bQuantizedEType =
2038 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
2040 if (aQuantizedEType || bQuantizedEType) {
2041 if (!aQuantizedEType || !bQuantizedEType) {
2042 return emitOpError(
"expect operands to be both quantized or both not "
2044 << aElementType <<
" and " << bElementType;
2047 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
2048 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
2049 if (aQuantWidth != bQuantWidth) {
2050 return emitOpError(
"expect quantized operands to have same widths, got ")
2051 << aQuantWidth <<
" and " << bQuantWidth;
2058 if (aEType != aZpEType)
2059 return emitOpError(
"expect input a and a_zp have the same "
2060 "element type, got ")
2061 << aEType <<
" and " << aZpEType;
2065 if (bEType != bZpEType)
2066 return emitOpError(
"expect input b and b_zp have the same "
2067 "element type, got ")
2068 << bEType <<
" and " << bZpEType;
2070 FailureOr<int64_t> maybeAZp = getAZeroPoint();
2071 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
2074 FailureOr<int64_t> maybeBZp = getBZeroPoint();
2075 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
2079 int64_t N = ShapedType::kDynamic;
2080 int64_t H = ShapedType::kDynamic;
2084 if (aShape.hasRank()) {
2085 N = aShape.getDimSize(0);
2086 H = aShape.getDimSize(1);
2087 C = aShape.getDimSize(2);
2090 if (bShape.hasRank()) {
2096 W = bShape.getDimSize(2);
2100 const auto outputType = cast<ShapedType>(getResult().
getType());
2101 if (outputType.hasRank() &&
2106 opError <<
" to be compatible with expected output shape ";
2114LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
2115 MLIRContext *context, ::std::optional<Location> location,
2116 MatmulTBlockScaledOp::Adaptor adaptor,
2120 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
2121 if (aDataShape.hasRank()) {
2122 outShape[0] = aDataShape.getDimSize(0);
2123 outShape[1] = aDataShape.getDimSize(1);
2126 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
2127 if (aScaleShape.hasRank()) {
2128 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
2130 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
2135 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
2136 if (bDataShape.hasRank()) {
2137 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
2138 if (bDataBatchSize != 1)
2140 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
2141 outShape[2] = bDataShape.getDimSize(1);
2144 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
2145 if (bScaleShape.hasRank()) {
2146 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
2147 if (bScaleBatchSize != 1)
2149 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
2150 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
2158LogicalResult MatmulTBlockScaledOp::verify() {
2160 const Type aDataType = getAData().getType();
2161 const Type bDataType = getBData().getType();
2167 int64_t N = ShapedType::kDynamic;
2168 int64_t D = ShapedType::kDynamic;
2169 int64_t H = ShapedType::kDynamic;
2172 int64_t multiplesOfC = ShapedType::kDynamic;
2184 "a_scale",
"batch")) ||
2186 "a_scale",
"height")))
2194 "b_data",
"batch")) ||
2196 "b_data",
"channels")))
2204 "b_scale",
"batch")) ||
2206 "b_scale",
"width")) ||
2214 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2215 return emitOpError(
"expect B matrix batch size to be broadcast compatible "
2217 << D <<
" vs N=" << N;
2220 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2221 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2222 return emitOpError(
"expect block size to be 32, got ") << blockSize;
2223 if (ShapedType::isStatic(C) && C % blockSize != 0)
2224 return emitOpError(
"expect C to be a multiple of block size, got C=")
2225 <<
C <<
", block_size=" << blockSize;
2228 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2229 multiplesOfC != C / blockSize)
2231 "expect scale operands dimension 2 to equal C/block_size (")
2232 <<
C <<
"/" << blockSize <<
")" <<
", got " << multiplesOfC;
2235 N = ShapedType::isDynamic(N) ? D : N;
2237 const auto outputType = cast<ShapedType>(getResult().
getType());
2238 if (outputType.hasRank() &&
2243 opError <<
" to be compatible with expected output shape ";
2251LogicalResult tosa::PadOp::inferReturnTypeComponents(
2252 MLIRContext *context, ::std::optional<Location> location,
2253 PadOp::Adaptor adaptor,
2255 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2257 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2262 if (!inputShape.hasRank()) {
2263 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2272 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2277 outputShape.reserve(inputShape.getRank());
2278 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2279 if (inputShape.isDynamicDim(i)) {
2280 outputShape.push_back(ShapedType::kDynamic);
2283 auto padFront = paddingValues[i * 2];
2284 auto padBack = paddingValues[i * 2 + 1];
2285 if (padFront < 0 || padBack < 0) {
2287 outputShape.push_back(ShapedType::kDynamic);
2291 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2298LogicalResult tosa::PadOp::verify() {
2305 if (
auto padConst = getPadConst()) {
2313 RankedTensorType inputType =
2314 llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2315 RankedTensorType outputType =
2316 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
2317 if (!inputType || !outputType)
2324 auto inputRank = inputType.getRank();
2329 auto paddingValues = paddingAttr.getValues<APInt>();
2330 if (paddingValues.size() !=
static_cast<size_t>(inputRank * 2))
2331 return emitOpError() <<
"padding tensor must have " << inputRank
2332 <<
" * 2 = " << inputRank * 2 <<
" elements, but got "
2333 << paddingValues.size();
2335 auto inputShape = inputType.getShape();
2336 auto outputShape = outputType.getShape();
2338 for (
int64_t i = 0; i < inputRank; ++i) {
2339 int64_t padStart = paddingValues[i * 2].getSExtValue();
2340 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2342 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2344 <<
"invalid padding values at dimension " << i
2345 <<
": values must be non-negative or -1 for dynamic padding, got ["
2346 << padStart <<
", " << padEnd <<
"]";
2350 if (inputShape[i] == ShapedType::kDynamic ||
2351 outputShape[i] == ShapedType::kDynamic)
2354 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2355 return emitOpError() <<
"mismatch in output shape at dimension " << i
2356 <<
": expected " << inputShape[i] <<
" + "
2357 << padStart <<
" + " << padEnd <<
" = "
2358 << (inputShape[i] + padStart + padEnd)
2359 <<
", but got " << outputShape[i];
2366LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2367 MLIRContext *context, ::std::optional<Location> location,
2368 SliceOp::Adaptor adaptor,
2377 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2385 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2388 if (inputShape.hasRank()) {
2389 for (
size_t i = 0; i < size.size(); i++) {
2390 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2391 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2392 start[i] < inputShape.getDimSize(i))) {
2394 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2397 outputShape[i] = size[i];
2401 if (size[i] == -1) {
2402 outputShape[i] = inputShape.getDimSize(i) - start[i];
2403 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2405 outputShape[i] = size[i];
2417LogicalResult tosa::SliceOp::verify() {
2418 const Value input = getInput1();
2419 const Value output = getOutput();
2425 const Value start = getStart();
2426 const Value size = getSize();
2430 if (inputShape.hasRank()) {
2431 const auto inputRank = inputShape.getRank();
2432 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2434 "expect input1 and output to have the same ranks, got ")
2435 << inputRank <<
" and " << outputShape.getRank();
2437 const auto startShapeRank =
2438 llvm::cast<tosa::shapeType>(start.
getType()).getRank();
2439 if (inputRank != startShapeRank)
2440 return emitOpError(
"length of start is not equal to rank of input shape");
2442 const auto sizeShapeRank =
2443 llvm::cast<tosa::shapeType>(size.
getType()).getRank();
2444 if (inputRank != sizeShapeRank)
2445 return emitOpError(
"length of size is not equal to rank of input shape");
2450 if (startValues.size()) {
2451 if (llvm::any_of(startValues, [](
const int64_t v) {
2454 return emitOpError(
"start values must be non-negative, got [")
2455 << startValues <<
"]";
2462 if (llvm::any_of(sizeValues, [](
const int64_t v) {
2465 return emitOpError(
"size values must be > 0, got [") << sizeValues <<
"]";
2466 if (outputShape.hasRank()) {
2468 outputShape.getDims(outputDims);
2469 const bool hasNoInferableDims = llvm::all_of(
2471 if (hasNoInferableDims &&
2473 return emitOpError(
"expected output shape to match size values, got ")
2474 << output.
getType() <<
" vs [" << sizeValues <<
"]";
2477 if (inputShape.hasRank() && startValues.size()) {
2479 inputShape.getDims(inputDims);
2480 for (
const auto &[
index, vals] :
2481 llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
2482 const auto &[start, size, inputDim] = vals;
2484 ShapedType::isDynamic(inputDim))
2486 if (start + size > inputDim)
2487 return emitOpError(
"start + size must be less than or equal to input "
2488 "dimension size, got start=")
2489 << start <<
", size=" << size
2490 <<
" vs input dim size=" << inputDim <<
" at dimension "
2498LogicalResult tosa::MulOp::inferReturnTypeComponents(
2499 MLIRContext *context, ::std::optional<Location> location,
2514LogicalResult tosa::MulOp::verify() {
2515 const Value output = getOutput();
2520 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2521 IntegerType lhsIntType =
2523 IntegerType rhsIntType =
2525 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2526 return emitOpError(
"requires the same element type for all operands");
2531 if (lhsIntType.getWidth() > resIntType.getWidth())
2532 return emitOpError(
"invalid data type size for operands or result");
2537 for (
int i = 0; i < 2; ++i) {
2540 "requires the same element type for all operands and results");
2544 ElementsAttr shiftElem;
2546 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2548 return emitOpError() <<
"require shift to be 0 for float type";
2556 TypeRange operandTypes = getOperandTypes();
2557 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2558 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2560 const bool aHasRank = aType.hasRank();
2561 const bool bHasRank = bType.hasRank();
2563 bool hasExpectedOutputShape =
false;
2566 if (aHasRank && bHasRank) {
2567 const int64_t aRank = aType.getRank();
2568 const int64_t bRank = bType.getRank();
2570 return emitOpError(
"a and b operands don't have matching ranks, got ")
2571 << aRank <<
" and " << bRank;
2575 aType.getShape(), bType.getShape(), expectedOutputShape))
2576 return emitOpError(
"a and b operands don't have broadcast-compatible "
2578 << aType <<
" and " << bType;
2579 hasExpectedOutputShape =
true;
2582 ShapedType resultType = cast<ShapedType>(output.
getType());
2583 if (!resultType.hasRank())
2586 const int64_t resultRank = resultType.getRank();
2587 if (aHasRank && resultRank != aType.getRank())
2588 return emitOpError(
"result type has different rank than a, got ")
2589 << resultRank <<
" vs " << aType.getRank();
2590 if (bHasRank && resultRank != bType.getRank())
2591 return emitOpError(
"result type has different rank than b, got ")
2592 << resultRank <<
" vs " << bType.getRank();
2594 if (hasExpectedOutputShape &&
2596 expectedOutputShape)))
2602LogicalResult tosa::TableOp::inferReturnTypeComponents(
2603 MLIRContext *context, ::std::optional<Location> location,
2604 TableOp::Adaptor adaptor,
2606 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2608 if (!inputShape.hasRank()) {
2613 inferredReturnShapes.resize(1);
2614 inputShape.getDims(inferredReturnShapes[0]);
2618LogicalResult tosa::TableOp::verify() {
2619 const TensorType inputType = getInput1().getType();
2620 const TensorType outputType = getOutput().getType();
2629 auto inputDims = inputType.
getShape();
2630 auto outputDims = outputType.
getShape();
2631 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2633 auto [inputDim, outputDim] = it.value();
2634 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2635 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2636 <<
" doesn't match dim(input, " << dim
2637 <<
") = " << inputDim;
2650 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2651 [](
const APInt &val) { return val.getSExtValue(); });
2655LogicalResult tosa::TileOp::inferReturnTypeComponents(
2656 MLIRContext *context, ::std::optional<Location> location,
2657 TileOp::Adaptor adaptor,
2664 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2671 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2673 if (!inputShape.hasRank()) {
2674 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2675 inferredReturnShapes.push_back(
2679 if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2683 outputShape.reserve(multiples.size());
2684 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2685 if (multiples[i] == ShapedType::kDynamic) {
2686 outputShape.push_back(ShapedType::kDynamic);
2688 int64_t dim = inputShape.getDimSize(i);
2689 if (dim != ShapedType::kDynamic)
2690 dim *= multiples[i];
2691 outputShape.push_back(dim);
2699LogicalResult tosa::TileOp::verify() {
2705 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2706 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2708 shapeType multiplesType =
2709 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2711 auto multiplesRank = multiplesType.getRank();
2713 if (inputType.hasRank()) {
2714 if (inputType.getRank() != multiplesRank)
2715 return emitOpError(
"expect 'multiples' to have rank ")
2716 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2717 if (outputType.hasRank() &&
2721 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2722 return emitOpError(
"expect 'multiples' array to have length ")
2723 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
2726 if (getConstantMultiples(multiples).succeeded() &&
2727 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2729 "expect element of 'multiples' to be positive integer or -1.");
2735 if (l.size() != r.size() || l.size() != 1)
2740LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2741 MLIRContext *context, ::std::optional<Location> location,
2742 ReshapeOp::Adaptor adaptor,
2744 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2749 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2758 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2759 inferredReturnShapes.push_back(
2767 int64_t numElements = inputShape.getNumElements();
2769 for (
auto val : newShapeValue) {
2770 if (ShapedType::isStatic(val)) {
2776 for (
auto &val : newShapeValue) {
2777 if (ShapedType::isDynamic(val))
2778 val = numElements / staticMul;
2781 inferredReturnShapes.push_back(
2786llvm::LogicalResult tosa::ReshapeOp::verify() {
2792 TensorType inputType = getInput1().getType();
2797 return mlir::success();
2801 if (missingDims > 1)
2802 return emitOpError() <<
"expected at most one target dimension to be "
2805 const auto outputType = dyn_cast<RankedTensorType>(
getType());
2809 if ((
int64_t)shapeValues.size() != outputType.getRank())
2810 return emitOpError() <<
"new shape does not match result rank";
2812 for (
auto [newShapeDim, outputShapeDim] :
2813 zip(shapeValues, outputType.getShape())) {
2815 newShapeDim != ShapedType::kDynamic &&
2816 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2817 return emitOpError() <<
"new shape is inconsistent with result shape";
2820 return emitOpError() <<
"new shape has invalid tensor dimension size "
2824 if (inputType.hasStaticShape()) {
2825 int64_t inputElementsNum = inputType.getNumElements();
2826 if (outputType.hasStaticShape()) {
2827 int64_t outputElementsNum = outputType.getNumElements();
2828 if (inputElementsNum != outputElementsNum) {
2829 return emitOpError() <<
"cannot reshape " << inputElementsNum
2830 <<
" elements into " << outputElementsNum;
2836 return (dim > 0) ?
acc * dim :
acc;
2838 bool isStaticNewShape =
2839 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
2840 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2841 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2842 return emitOpError() <<
"cannot reshape " << inputElementsNum
2843 <<
" elements into " << newShapeElementsNum;
2847 return mlir::success();
2850bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(
TypeRange l,
2852 if (l.size() != r.size() || l.size() < 1 || l.size() > 2)
2860LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
2861 MLIRContext *context, ::std::optional<Location> location,
2862 ReshapeBlockScaledOp::Adaptor adaptor,
2865 const auto numInputs = adaptor.getInput().size();
2866 ShapeAdaptor inputShape(adaptor.getInput()[0].getType());
2869 const auto newShape = adaptor.getNewValueShape();
2871 auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
2880 const uint32_t blockSize =
2881 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
2884 if (numInputs == 2) {
2885 newScaleShapeValue.assign(newShapeValue.begin(), newShapeValue.end());
2886 if (ShapedType::isStatic(newScaleShapeValue.back()))
2887 newScaleShapeValue.back() /= blockSize;
2890 inferredReturnShapes.push_back(
2892 if (numInputs == 2) {
2894 for (
size_t idx = 0; idx < newShapeValue.size(); idx++) {
2895 if (ShapedType::isDynamic(newScaleShapeValue[idx])) {
2896 newScaleShapeValue[idx] = newShapeValue[idx];
2897 if (idx == (newShapeValue.size() - 1))
2898 newScaleShapeValue[idx] /= blockSize;
2909llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
2913 if (inputList.size() == 0)
2914 return emitOpError(
"requires at least one input");
2916 if (inputList.size() > 2)
2917 return emitOpError(
"requires at most two inputs");
2919 if (inputList.size() != outputList.size())
2920 return emitOpError(
"requires number of results to match inputs");
2928 const auto inputType = llvm::cast<ShapedType>(inputList[0].
getType());
2929 if (!inputType.hasRank())
2931 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2933 if (inputList.size() == 2) {
2934 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2935 return emitOpError(
"expect block size to be 32, got ") << blockSize;
2936 if (llvm::any_of(inputList, [](
Value v) {
2937 const auto input = cast<ShapedType>(v.
getType());
2938 return input.hasRank() && input.getRank() == 0;
2941 "requires all input shapes have a rank greater than 0");
2942 if (llvm::any_of(outputList, [](
Value v) {
2943 const auto output = cast<ShapedType>(v.
getType());
2944 return output.hasRank() && output.getRank() == 0;
2947 "requires all result shapes have a rank greater than 0");
2955 const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].
getType());
2956 if (inputScaleType.hasRank()) {
2957 if (inputType.getRank() != inputScaleType.getRank())
2958 return emitOpError(
"input shapes do not have same rank");
2961 for (
auto dimIdx = 0; dimIdx < inputType.getRank() - 1; dimIdx++) {
2962 const int64_t inputValueDim = inputType.getDimSize(dimIdx);
2963 const int64_t inputScaleDim = inputScaleType.getShape()[dimIdx];
2964 if (ShapedType::isStatic(inputValueDim) &&
2965 ShapedType::isStatic(inputScaleDim) &&
2966 inputValueDim != inputScaleDim)
2967 return emitOpError(
"input shapes for data and scale do not match on "
2974 inputType.getDimSize(inputType.getRank() - 1);
2975 if (ShapedType::isStatic(lastValueDim)) {
2976 if (lastValueDim % blockSize != 0)
2977 return emitOpError(
"expect last dimension of input_data (")
2978 << lastValueDim <<
") to be divisible by block_size ("
2979 << blockSize <<
")";
2982 inputScaleType.getDimSize(inputScaleType.getRank() - 1);
2984 if (ShapedType::isStatic(lastScaleDim) &&
2985 lastScaleDim != lastValueDim / blockSize)
2986 return emitOpError(
"expect last dimension of scale_data (")
2987 << lastScaleDim <<
") to be " << lastValueDim <<
"/"
2992 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
2993 return emitOpError(
"expect block size to be 1, got ") << blockSize;
3001 return mlir::success();
3004 if (inputList.size() == 2) {
3005 if (
static_cast<int64_t>(shapeValues.size()) == 0)
3006 return emitOpError(
"requires new shape to have a rank greater than 0");
3008 const int64_t lastShapeDim = shapeValues.back();
3009 if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
3010 return emitOpError(
"expect last dimension of new shape (")
3011 << lastShapeDim <<
") to be divisible by block_size (" << blockSize
3015 const auto outputType = llvm::cast<ShapedType>(outputList[0].
getType());
3016 if (!outputType.hasRank())
3019 if (
static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
3020 return emitOpError() <<
"result does not match new shape rank";
3022 for (
auto [newShapeDim, outputShapeDim] :
3023 zip(shapeValues, outputType.getShape())) {
3024 if (ShapedType::isStatic(newShapeDim) &&
3025 ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
3026 return emitOpError() <<
"result shape is inconsistent with new shape";
3029 if (outputList.size() == 2) {
3033 scaleShapeValues.back() /= blockSize;
3035 const auto outputScaleType =
3036 llvm::cast<ShapedType>(outputList[1].
getType());
3037 if (outputScaleType.hasRank()) {
3038 if ((
int64_t)scaleShapeValues.size() != outputScaleType.getRank())
3039 return emitOpError() <<
"result scale does not match new shape rank";
3041 for (
auto [newScaleShapeDim, outputScaleShapeDim] :
3042 zip(scaleShapeValues, outputScaleType.getShape())) {
3043 if (ShapedType::isStatic(newScaleShapeDim) &&
3044 ShapedType::isStatic(outputScaleShapeDim) &&
3045 newScaleShapeDim != outputScaleShapeDim)
3047 <<
"result scale shape is inconsistent with new shape";
3052 if (inputType.hasStaticShape()) {
3053 int64_t inputElementsNum = inputType.getNumElements();
3054 if (outputType.hasStaticShape()) {
3055 int64_t outputElementsNum = outputType.getNumElements();
3056 if (inputElementsNum != outputElementsNum) {
3057 return emitOpError() <<
"cannot reshape " << inputElementsNum
3058 <<
" elements into " << outputElementsNum;
3064 return (dim > 0) ?
acc * dim :
acc;
3066 bool isStaticNewShape =
3067 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
3068 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
3069 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
3070 return emitOpError() <<
"cannot reshape " << inputElementsNum
3071 <<
" elements into " << newShapeElementsNum;
3075 return mlir::success();
3082 ElementsAttr zpAttr;
3087 Type zpElemType = zpAttr.getElementType();
3089 if (llvm::isa<FloatType>(zpElemType)) {
3090 if (zpAttr.getValues<APFloat>()[0].isZero()) {
3097 if (llvm::isa<IntegerType>(zpElemType)) {
3099 return zpAttr.getValues<APInt>()[0].getSExtValue();
3100 return zpAttr.getValues<APInt>()[0].getZExtValue();
3112 if (!llvm::isa<IntegerType>(attr.getElementType()) ||
3113 attr.getNumElements() != 1)
3116 return attr.getValues<APInt>()[0].getSExtValue();
3119template <
typename T>
3121 const std::string &operand) {
3124 if (!zpElemType.
isInteger(8) && zp != 0) {
3126 std::string lower = operand;
3127 llvm::transform(lower, lower.begin(), ::tolower);
3128 return op.emitOpError()
3129 << lower <<
" zero point must be zero for non-int8 integer types";
3137 const std::string &operand) {
3138 bool isInputZp = (operand ==
"Input");
3140 bool tensorUnsigned =
3141 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
3142 StringRef tensorName = isInputZp ?
"input" :
"output";
3148 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
3149 return op.emitOpError()
3150 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
3152 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
3153 return op.emitOpError() <<
"expect " << tensorName
3154 <<
"_zp of 0 or 32768 for unsigned int16 "
3155 << tensorName <<
", got " << zp;
3162#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
3163 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
3164 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
3166 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
3167 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
3188#undef ZERO_POINT_HELPER
3190LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
3191 MLIRContext *context, ::std::optional<Location> location,
3192 TransposeOp::Adaptor adaptor,
3194 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3203 const auto inputRank = inputShape.
getRank();
3207 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
3213 if (inputRank == 0) {
3219 bool allTheSame =
true;
3220 for (
int i = 1, s = inputRank; i < s; i++) {
3230 outputShape.resize(inputRank, inputShape.
getDimSize(0));
3235 outputShape.resize(inputRank, ShapedType::kDynamic);
3238 if (llvm::any_of(adaptor.getPerms(),
3239 [inputRank](
const auto i) { return i >= inputRank; }))
3242 outputShape.reserve(inputRank);
3243 for (
int i = 0, s = inputRank; i < s; i++) {
3244 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
3251LogicalResult tosa::TransposeOp::verify() {
3263 if (inputShape.hasRank() &&
3264 constantPerms.size() !=
static_cast<size_t>(inputShape.getRank()))
3265 return emitOpError() <<
"expected perms attribute to have size "
3266 << inputShape.getRank()
3267 <<
" (input rank) but got size "
3268 << constantPerms.size();
3270 if (inputShape.hasRank() && outputShape.hasRank() &&
3271 inputShape.getRank() != outputShape.getRank())
3273 <<
"expected input tensor rank to equal result tensor rank";
3275 if (outputShape.hasRank() &&
3276 constantPerms.size() !=
static_cast<size_t>(outputShape.getRank()))
3277 return emitOpError() <<
"expected perms attribute to have size "
3278 << outputShape.getRank()
3279 <<
" (output rank) but got size "
3280 << constantPerms.size();
3282 if (!llvm::all_of(constantPerms,
3283 [&constantPerms](int32_t s) {
3285 static_cast<size_t>(s) < constantPerms.size();
3288 constantPerms, [](int32_t v) ->
int64_t {
return v; })))
3289 return emitOpError() <<
"expected valid permutation indices";
3292 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
3293 inputShape.getNumElements() != outputShape.getNumElements())
3294 return emitOpError() <<
"expected input1 and output to have same numbers "
3296 << inputShape.getNumElements() <<
" and "
3297 << outputShape.getNumElements();
3301 if (inputShape.hasRank() && outputShape.hasRank()) {
3302 for (
auto i = 0; i < outputShape.getRank(); i++) {
3303 if (inputShape.isDynamicDim(constantPerms[i]) ||
3304 outputShape.isDynamicDim(i))
3307 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
3309 <<
"expected output tensor dim " << i <<
" to match "
3310 <<
"input dim " << constantPerms[i] <<
" with value of "
3311 << inputShape.getDimSize(constantPerms[i]);
3318LogicalResult TransposeOp::reifyResultShapes(
3321 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
3323 Value input = getInput1();
3324 auto inputType = cast<TensorType>(input.
getType());
3326 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
3327 for (
auto dim : transposePerms) {
3328 int32_t dimInInput = transposePerms[dim];
3329 if (inputType.isDynamicDim(dimInInput))
3331 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
3335 builder.
getIndexAttr(inputType.getDimSize(dimInInput));
3338 reifiedReturnShapes.emplace_back(std::move(returnedDims));
3342LogicalResult tosa::GatherOp::inferReturnTypeComponents(
3343 MLIRContext *context, ::std::optional<Location> location,
3344 GatherOp::Adaptor adaptor,
3345 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3346 llvm::SmallVector<int64_t> outputShape;
3347 outputShape.resize(3, ShapedType::kDynamic);
3349 ShapeAdaptor valuesShape(adaptor.getValues().getType());
3350 if (valuesShape.hasRank()) {
3351 outputShape[0] = valuesShape.getDimSize(0);
3352 outputShape[2] = valuesShape.getDimSize(2);
3355 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3356 if (indicesShape.hasRank()) {
3357 if (outputShape[0] == ShapedType::kDynamic)
3358 outputShape[0] = indicesShape.getDimSize(0);
3359 if (outputShape[1] == ShapedType::kDynamic)
3360 outputShape[1] = indicesShape.getDimSize(1);
3363 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3367LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
3368 MLIRContext *context, ::std::optional<Location> location,
3369 RowGatherBlockScaledOp::Adaptor adaptor,
3370 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3371 const auto values = adaptor.getValues();
3375 SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
3376 const ShapeAdaptor valuesShape(values.front().getType());
3377 if (valuesShape.hasRank()) {
3378 dataShape[0] = valuesShape.getDimSize(0);
3379 dataShape[2] = valuesShape.getDimSize(2);
3382 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3383 if (indicesShape.hasRank()) {
3384 if (dataShape[0] == ShapedType::kDynamic)
3385 dataShape[0] = indicesShape.getDimSize(0);
3388 succeeded(rowCount) && rowCount.value() > 0) {
3389 const int64_t indicesW = indicesShape.getDimSize(1);
3390 if (ShapedType::isStatic(indicesW))
3391 dataShape[1] = indicesW * rowCount.value();
3395 inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
3396 if (values.size() == 1)
3399 SmallVector<int64_t> scaleShape = dataShape;
3400 const uint32_t blockSize =
3401 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
3402 if (ShapedType::isStatic(dataShape[2]))
3403 scaleShape[2] = dataShape[2] / blockSize;
3405 inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
3409LogicalResult tosa::GatherOp::verify() {
3416 const ShapeAdaptor valuesShape(getValues().
getType());
3418 const ShapeAdaptor outputShape(getOutput().
getType());
3420 int64_t n = ShapedType::kDynamic;
3421 int64_t w = ShapedType::kDynamic;
3422 int64_t c = ShapedType::kDynamic;
3424 if (valuesShape.hasRank()) {
3425 n = valuesShape.getDimSize(0);
3426 c = valuesShape.getDimSize(2);
3428 if (indicesShape.hasRank()) {
3429 const int64_t indicesN = indicesShape.getDimSize(0);
3430 w = indicesShape.getDimSize(1);
3431 if (n == ShapedType::kDynamic)
3433 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3434 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3435 <<
", got " << indicesN;
3437 if (outputShape.hasRank()) {
3438 const int64_t outputN = outputShape.getDimSize(0);
3439 const int64_t outputW = outputShape.getDimSize(1);
3440 const int64_t outputC = outputShape.getDimSize(2);
3441 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3443 return emitOpError() <<
"requires output dimension 0 to have size " << n
3444 <<
", got " << outputN;
3446 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
3448 return emitOpError() <<
"requires output dimension 1 to have size " << w
3449 <<
", got " << outputW;
3450 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3452 return emitOpError() <<
"requires output dimension 2 to have size " << c
3453 <<
", got " << outputC;
3458LogicalResult tosa::RowGatherBlockScaledOp::verify() {
3459 const OperandRange values = getValues();
3460 const ResultRange output = getOutput();
3461 if (values.empty() || values.size() > 2)
3463 <<
"expects values tensor list length to be 1 or 2, got "
3465 if (output.size() != values.size())
3467 <<
"expects output tensor list length to match values tensor list "
3469 << output.size() <<
" results for " << values.size()
3470 <<
" input tensors";
3472 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
3473 if (values.size() == 1 && blockSize != 1)
3475 <<
"requires block_size to be BLOCK_SIZE_1 when values tensor list "
3477 if (values.size() == 2 && blockSize == 1)
3479 <<
"requires block_size to not be BLOCK_SIZE_1 when values tensor "
3483 output[0].
getType(),
"values[0]",
3488 "values[1]",
"output[1]")))
3492 succeeded(rowCount) && rowCount.value() <= 0)
3493 return emitOpError() <<
"requires row_count to be > 0, got "
3494 << rowCount.value();
3496 int64_t n = ShapedType::kDynamic;
3497 int64_t k = ShapedType::kDynamic;
3498 int64_t c = ShapedType::kDynamic;
3499 int64_t w = ShapedType::kDynamic;
3500 int64_t multiplesOfC = ShapedType::kDynamic;
3502 const ShapeAdaptor valuesDataShape(values[0].
getType());
3503 if (valuesDataShape.hasRank()) {
3504 n = valuesDataShape.getDimSize(0);
3505 k = valuesDataShape.getDimSize(1);
3506 c = valuesDataShape.getDimSize(2);
3509 if (ShapedType::isStatic(c) && c % blockSize != 0)
3510 return emitOpError() <<
"expects channels of values[0] (" << c
3511 <<
") to be divisible by block_size (" << blockSize
3515 if (indicesShape.hasRank()) {
3517 "indices",
"batch")))
3519 w = indicesShape.getDimSize(1);
3522 const ShapeAdaptor outputDataShape(output[0].
getType());
3523 if (outputDataShape.hasRank()) {
3525 "output[0]",
"batch")) ||
3527 "output[0]",
"channels")))
3531 succeeded(rowCount) && rowCount.value() > 0 &&
3532 ShapedType::isStatic(w)) {
3533 const int64_t expectedOutputRows = w * rowCount.value();
3534 if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
3535 outputDataShape.getDimSize(1) != expectedOutputRows)
3536 return emitOpError() <<
"requires output[0] dimension 1 to have size "
3537 << expectedOutputRows <<
", got "
3538 << outputDataShape.getDimSize(1);
3542 if (values.size() == 2) {
3543 const ShapeAdaptor valuesScaleShape(values[1].
getType());
3544 if (valuesScaleShape.hasRank()) {
3546 "values[1]",
"batch")) ||
3548 "values[1]",
"rows")))
3550 multiplesOfC = valuesScaleShape.getDimSize(2);
3553 const ShapeAdaptor outputScaleShape(output[1].
getType());
3554 if (outputScaleShape.hasRank()) {
3556 "output[1]",
"batch")))
3560 succeeded(rowCount) && rowCount.value() > 0 &&
3561 ShapedType::isStatic(w)) {
3562 const int64_t expectedOutputRows = w * rowCount.value();
3563 if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
3564 outputScaleShape.getDimSize(1) != expectedOutputRows)
3565 return emitOpError() <<
"requires output[1] dimension 1 to have size "
3566 << expectedOutputRows <<
", got "
3567 << outputScaleShape.getDimSize(1);
3570 if (ShapedType::isDynamic(multiplesOfC))
3571 multiplesOfC = outputScaleShape.getDimSize(2);
3572 else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
3573 multiplesOfC != outputScaleShape.getDimSize(2))
3575 <<
"expected channels of output[1] to match size "
3576 << multiplesOfC <<
", got " << outputScaleShape.getDimSize(2);
3579 if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
3580 multiplesOfC != c / blockSize)
3582 <<
"expects channels of scale tensors to equal C/block_size (" << c
3583 <<
"/" << blockSize <<
"), got " << multiplesOfC;
3589LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3590 MLIRContext *context, ::std::optional<Location> location,
3591 ResizeOp::Adaptor adaptor,
3592 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3593 llvm::SmallVector<int64_t, 4> outputShape;
3594 outputShape.resize(4, ShapedType::kDynamic);
3596 ShapeAdaptor inputShape(adaptor.getInput().getType());
3597 if (!inputShape.hasRank())
3600 outputShape[0] = inputShape.getDimSize(0);
3601 outputShape[3] = inputShape.getDimSize(3);
3602 int64_t inputHeight = inputShape.getDimSize(1);
3603 int64_t inputWidth = inputShape.getDimSize(2);
3605 if ((inputHeight == ShapedType::kDynamic) ||
3606 (inputWidth == ShapedType::kDynamic))
3609 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
3620 const int64_t outputHeight =
3621 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3625 const int64_t outputWidth =
3626 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
3630 if (outputHeight < 0 || outputWidth < 0) {
3633 "calculated output height and width must be non-negative, "
3635 outputHeight,
", width = ", outputWidth);
3638 outputShape[1] = outputHeight;
3639 outputShape[2] = outputWidth;
3640 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3644LogicalResult tosa::ResizeOp::verify() {
3645 const Value input = getInput();
3646 const Value output = getOutput();
3647 const RankedTensorType inputType =
3648 llvm::dyn_cast<RankedTensorType>(input.
getType());
3649 const RankedTensorType outputType =
3650 llvm::dyn_cast<RankedTensorType>(output.
getType());
3652 SmallVector<int64_t> scaleValues;
3653 SmallVector<int64_t> offsetValues;
3654 SmallVector<int64_t> borderValues;
3662 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
3663 return emitOpError(
"expect all scale values to be > 0, got ")
3666 const int64_t scaleYN = scaleValues[0];
3667 const int64_t scaleYD = scaleValues[1];
3668 const int64_t scaleXN = scaleValues[2];
3669 const int64_t scaleXD = scaleValues[3];
3671 const int64_t offsetY = offsetValues[0];
3672 const int64_t offsetX = offsetValues[1];
3674 const int64_t borderY = borderValues[0];
3675 const int64_t borderX = borderValues[1];
3682 const int64_t oh = outputType.getDimSize(1);
3683 const int64_t ow = outputType.getDimSize(2);
3684 const int64_t ih = inputType.getDimSize(1);
3685 const int64_t iw = inputType.getDimSize(2);
3691 if (ih != ShapedType::kDynamic && ih != 1) {
3692 const std::optional<int64_t> calculatedOutHeightMinusOne =
3693 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3694 if (!calculatedOutHeightMinusOne.has_value())
3695 return emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
3697 <<
"to be wholly divisible by scale_y_d, got ((" << ih
3698 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
3699 <<
") / " << scaleYD;
3700 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3701 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3702 return emitOpError(
"calculated output height did not match expected: ")
3703 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
3710 if (iw != ShapedType::kDynamic && iw != 1) {
3711 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3712 const std::optional<int64_t> calculatedOutWidthMinusOne =
3714 if (!calculatedOutWidthMinusOne.has_value())
3715 return emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
3717 <<
"to be wholly divisible by scale_x_d, got ((" << iw
3718 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
3719 <<
") / " << scaleXD;
3720 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3721 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3722 return emitOpError(
"calculated output width did not match expected: ")
3723 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
3729LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3730 MLIRContext *context, ::std::optional<Location> location,
3731 ScatterOp::Adaptor adaptor,
3732 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3733 llvm::SmallVector<int64_t> outputShape;
3734 outputShape.resize(3, ShapedType::kDynamic);
3736 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3737 if (valuesInShape.hasRank()) {
3738 outputShape[0] = valuesInShape.getDimSize(0);
3739 outputShape[1] = valuesInShape.getDimSize(1);
3740 outputShape[2] = valuesInShape.getDimSize(2);
3743 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3744 if (indicesShape.hasRank()) {
3745 if (outputShape[0] == ShapedType::kDynamic)
3746 outputShape[0] = indicesShape.getDimSize(0);
3749 ShapeAdaptor inputShape(adaptor.getInput().getType());
3750 if (inputShape.hasRank()) {
3751 if (outputShape[0] == ShapedType::kDynamic)
3752 outputShape[0] = inputShape.getDimSize(0);
3753 if (outputShape[2] == ShapedType::kDynamic)
3754 outputShape[2] = inputShape.getDimSize(2);
3757 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3761LogicalResult tosa::ScatterOp::verify() {
3771 const ShapeAdaptor valuesInShape(getValuesIn().
getType());
3773 const ShapeAdaptor inputShape(getInput().
getType());
3774 const ShapeAdaptor outputShape(getValuesOut().
getType());
3776 int64_t n = ShapedType::kDynamic;
3777 int64_t k = ShapedType::kDynamic;
3778 int64_t w = ShapedType::kDynamic;
3779 int64_t c = ShapedType::kDynamic;
3780 if (valuesInShape.hasRank()) {
3781 n = valuesInShape.getDimSize(0);
3782 k = valuesInShape.getDimSize(1);
3783 c = valuesInShape.getDimSize(2);
3785 if (indicesShape.hasRank()) {
3786 const int64_t indicesN = indicesShape.getDimSize(0);
3787 w = indicesShape.getDimSize(1);
3788 if (n == ShapedType::kDynamic)
3790 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3791 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3792 <<
", got " << indicesN;
3794 if (inputShape.hasRank()) {
3795 const int64_t inputN = inputShape.getDimSize(0);
3796 const int64_t inputW = inputShape.getDimSize(1);
3797 const int64_t inputC = inputShape.getDimSize(2);
3798 if (n == ShapedType::kDynamic)
3800 else if (inputN != ShapedType::kDynamic && n != inputN)
3801 return emitOpError() <<
"requires input dimension 0 to have size " << n
3802 <<
", got " << inputN;
3803 if (w == ShapedType::kDynamic)
3805 else if (inputW != ShapedType::kDynamic && w != inputW)
3806 return emitOpError() <<
"requires input dimension 1 to have size " << w
3807 <<
", got " << inputW;
3809 if (c == ShapedType::kDynamic)
3811 else if (inputC != ShapedType::kDynamic && c != inputC)
3812 return emitOpError() <<
"requires input dimension 2 to have size " << c
3813 <<
", got " << inputC;
3815 if (outputShape.hasRank()) {
3816 const int64_t outputN = outputShape.getDimSize(0);
3817 const int64_t outputK = outputShape.getDimSize(1);
3818 const int64_t outputC = outputShape.getDimSize(2);
3819 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3821 return emitOpError() <<
"requires values_out dimension 0 to have size "
3822 << n <<
", got " << outputN;
3823 if (k == ShapedType::kDynamic)
3825 else if (outputK != ShapedType::kDynamic && k != outputK)
3826 return emitOpError() <<
"requires values_out dimension 1 to have size "
3827 << k <<
", got " << outputK;
3828 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3830 return emitOpError() <<
"requires values_out dimension 2 to have size "
3831 << c <<
", got " << outputC;
3833 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3834 return emitOpError() <<
"requires dimensions K >= W, got K=" << k
3843 int64_t axisVal = axis.getValue().getSExtValue();
3844 if (!operandShape.
hasRank() || operandShape.
getRank() <= axisVal) {
3850 operandShape.
getDims(outputShape);
3851 outputShape[axisVal] = 1;
3856#define COMPATIBLE_RETURN_TYPES(OP) \
3857 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3858 if (l.size() != r.size() || l.size() != 1) \
3860 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3862 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3865#define REDUCE_SHAPE_INFER(OP) \
3866 LogicalResult OP::inferReturnTypeComponents( \
3867 MLIRContext *context, ::std::optional<Location> location, \
3868 OP::Adaptor adaptor, \
3869 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3871 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3872 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3873 const Properties &prop = adaptor.getProperties(); \
3874 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3875 inferredReturnShapes); \
3877 COMPATIBLE_RETURN_TYPES(OP)
3885#undef REDUCE_SHAPE_INFER
3887#undef COMPATIBLE_RETURN_TYPES
3889template <
typename T>
3892 TensorType inputType = op.getInput().getType();
3893 TensorType outputType = op.getOutput().getType();
3894 int32_t reduceAxis = op.getAxis();
3896 if (reduceAxis < 0) {
3897 op.emitOpError(
"reduce axis must not be negative");
3901 int64_t inputRank = inputType.getRank();
3904 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3905 op.emitOpError(
"expect input tensor rank (")
3906 << inputRank <<
") to be larger than reduce axis (" << reduceAxis
3912 int64_t outputRank = outputType.getRank();
3913 if (inputType.
hasRank() && outputRank != inputType.getRank()) {
3915 "expect output tensor rank to be equal to input tensor rank");
3918 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3919 op.emitOpError(
"expect output tensor rank (")
3920 << outputRank <<
") to be larger than reduce axis (" << reduceAxis
3926 if (outputRank != 0) {
3927 auto outputShape = outputType.
getShape();
3928 if (!outputType.isDynamicDim(reduceAxis) &&
3929 outputShape[reduceAxis] != 1) {
3930 op.emitOpError(
"expect reduced dimension size to be 1, got ")
3931 << outputShape[reduceAxis];
3939LogicalResult tosa::ReduceAllOp::verify() {
return verifyReduceOp(*
this); }
3940LogicalResult tosa::ReduceAnyOp::verify() {
return verifyReduceOp(*
this); }
3941LogicalResult tosa::ReduceMaxOp::verify() {
return verifyReduceOp(*
this); }
3942LogicalResult tosa::ReduceMinOp::verify() {
return verifyReduceOp(*
this); }
3943LogicalResult tosa::ReduceProductOp::verify() {
return verifyReduceOp(*
this); }
3944LogicalResult tosa::ReduceSumOp::verify() {
return verifyReduceOp(*
this); }
3958#define NARY_SHAPE_INFER(OP) \
3959 LogicalResult OP::inferReturnTypeComponents( \
3960 MLIRContext *context, ::std::optional<Location> location, \
3961 ValueShapeRange operands, DictionaryAttr attributes, \
3962 PropertyRef properties, RegionRange regions, \
3963 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3964 return NAryInferReturnTypes(operands, inferredReturnShapes); \
4004#undef PRED_SHAPE_INFER
4006LogicalResult tosa::NegateOp::inferReturnTypeComponents(
4007 MLIRContext *context, ::std::optional<Location> location,
4008 NegateOp::Adaptor adaptor,
4010 ShapeAdaptor inputShape(adaptor.getInput1().getType());
4015LogicalResult tosa::NegateOp::verify() {
4017 const Type input1Type = getInput1().getType();
4018 const Type outputType = getOutput().getType();
4023 const SmallVector<Type, 2> types = {input1Type, outputType};
4025 return emitOpError() <<
"requires the same shape for input1 and output";
4028 const Type input1ZpEType =
4030 if (input1EType != input1ZpEType) {
4031 return emitOpError(
"expect both input1 and its zero point are the same "
4032 "element type, got ")
4033 << input1EType <<
" and " << input1ZpEType;
4036 const Type outputZpEType =
4038 if (outputEType != outputZpEType) {
4039 return emitOpError(
"expect both output and its zero point are the same "
4040 "element type, got ")
4041 << outputEType <<
" and " << outputZpEType;
4044 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
4045 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
4048 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4049 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4060 outputShape.resize(4, ShapedType::kDynamic);
4075 if (ShapedType::isStatic(height)) {
4076 int64_t padded = height + pad[0] + pad[1] - kernel[0];
4077 outputShape[1] = padded / stride[0] + 1;
4080 if (ShapedType::isStatic(width)) {
4081 int64_t padded = width + pad[2] + pad[3] - kernel[1];
4082 outputShape[2] = padded / stride[1] + 1;
4089template <
typename AdaptorT>
4095 if (ShapedType::isDynamic(current))
4096 current = candidate;
4105 : adaptor(adaptor) {}
4109 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4117 outputShape[0] = outputBatch;
4118 inputSpatial[0] = inputHeight;
4119 inputSpatial[1] = inputWidth;
4124 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4132 outputShape[3] = outputChannels;
4133 weightSpatial[0] = kernelHeight;
4134 weightSpatial[1] = kernelWidth;
4143 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4144 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4145 dilationValues.assign(adaptor.getDilation().begin(),
4146 adaptor.getDilation().end());
4151 Conv2DOp::Adaptor adaptor;
4159 : adaptor(adaptor) {}
4163 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
4164 if (inputDataShape.
hasRank()) {
4169 outputShape[0] = outputBatch;
4170 inputSpatial[0] = inputHeight;
4171 inputSpatial[1] = inputWidth;
4174 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
4175 if (!inputScaleShape.
hasRank())
4189 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
4190 if (weightDataShape.
hasRank()) {
4195 outputShape[3] = outputChannels;
4196 weightSpatial[0] = kernelHeight;
4197 weightSpatial[1] = kernelWidth;
4200 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
4201 if (!weightScaleShape.
hasRank())
4230 Conv2DBlockScaledOp::Adaptor adaptor;
4238 : adaptor(adaptor) {}
4242 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4251 outputShape[0] = outputBatch;
4252 inputSpatial[0] = inputDepth;
4253 inputSpatial[1] = inputHeight;
4254 inputSpatial[2] = inputWidth;
4259 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4268 outputShape[4] = outputChannels;
4269 weightSpatial[0] = kernelDepth;
4270 weightSpatial[1] = kernelHeight;
4271 weightSpatial[2] = kernelWidth;
4280 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4281 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4282 dilationValues.assign(adaptor.getDilation().begin(),
4283 adaptor.getDilation().end());
4288 Conv3DOp::Adaptor adaptor;
4291template <
typename AdaptorT>
4297 ShapedType::kDynamic);
4299 ShapedType::kDynamic);
4301 ShapedType::kDynamic);
4303 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
4304 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
4306 const ShapeAdaptor biasShape = adaptor.getBias().getType();
4309 if (biasSize != 1) {
4310 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
4311 outputShape[outputChannelDim] =
4312 ShapedType::isDynamic(outputShape[outputChannelDim])
4314 : outputShape[outputChannelDim];
4321 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
4327 for (
int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
4328 if (!ShapedType::isStatic(inputSpatial[dim]) ||
4329 !ShapedType::isStatic(weightSpatial[dim]))
4332 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
4334 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
4335 const int64_t unstridedResult = inputSize - filterSize + 1;
4336 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
4343LogicalResult Conv2DOp::inferReturnTypeComponents(
4344 MLIRContext *context, ::std::optional<Location> location,
4345 Conv2DOp::Adaptor adaptor,
4346 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4350LogicalResult Conv2DOp::verify() {
4357LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
4358 MLIRContext *context, ::std::optional<Location> location,
4359 Conv2DBlockScaledOp::Adaptor adaptor,
4360 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4364LogicalResult Conv2DBlockScaledOp::verify() {
4366 getWeightData().
getType(),
"input_data",
4369 getWeightScale().
getType(),
"input_scale",
4372 getOutput().
getType(),
"bias",
"output")))
4376 int64_t N = ShapedType::kDynamic;
4377 int64_t IH = ShapedType::kDynamic;
4378 int64_t IW = ShapedType::kDynamic;
4379 int64_t IC = ShapedType::kDynamic;
4380 int64_t multiplesOfIC = ShapedType::kDynamic;
4381 int64_t OC = ShapedType::kDynamic;
4382 int64_t KH = ShapedType::kDynamic;
4383 int64_t KW = ShapedType::kDynamic;
4385 const ShapeAdaptor inputDataShape(getInputData().
getType());
4386 if (inputDataShape.hasRank()) {
4387 N = inputDataShape.getDimSize(0);
4388 IH = inputDataShape.getDimSize(1);
4389 IW = inputDataShape.getDimSize(2);
4390 IC = inputDataShape.getDimSize(3);
4393 const ShapeAdaptor inputScaleShape(getInputScale().
getType());
4394 if (inputScaleShape.hasRank()) {
4396 "input_scale",
"batch size")) ||
4398 "input_scale",
"input height")) ||
4400 "input_scale",
"input width")))
4402 multiplesOfIC = inputScaleShape.getDimSize(3);
4405 const ShapeAdaptor weightDataShape(getWeightData().
getType());
4406 if (weightDataShape.hasRank()) {
4407 OC = weightDataShape.getDimSize(0);
4408 KH = weightDataShape.getDimSize(1);
4409 KW = weightDataShape.getDimSize(2);
4411 "weight_data",
"input channels")))
4415 const ShapeAdaptor weightScaleShape(getWeightScale().
getType());
4416 if (weightScaleShape.hasRank()) {
4418 "weight_scale",
"output channels")) ||
4420 "weight_scale",
"kernel height")) ||
4422 "weight_scale",
"kernel width")) ||
4424 weightScaleShape.getDimSize(3),
4425 "weight_scale",
"input channel blocks")))
4429 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
4430 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4431 return emitOpError(
"expect block size to be 32, got ") << blockSize;
4433 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
4434 return emitOpError(
"expect IC to be a multiple of block size, got IC=")
4435 << IC <<
", block_size=" << blockSize;
4438 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
4439 multiplesOfIC != IC / blockSize)
4441 "expect scale operands dimension 2 to equal IC/block_size (")
4442 << IC <<
"/" << blockSize <<
")"
4443 <<
", got " << multiplesOfIC;
4446 SmallVector<int64_t> padValues;
4448 if (llvm::any_of(padValues, [](int64_t p) {
return p < 0; }))
4449 return emitOpError(
"expect all padding values to be >= 0, got ")
4453 SmallVector<int64_t> strideValues;
4455 if (llvm::any_of(strideValues, [](int64_t s) {
return s < 1; }))
4456 return emitOpError(
"expect all stride values to be >= 1, got ")
4460 SmallVector<int64_t> dilationValues;
4463 if (llvm::any_of(dilationValues, [](int64_t d) {
return d < 1; }))
4464 return emitOpError(
"expect all dilation values to be >= 1, got ")
4469 const ShapeAdaptor outputShape(getOutput().
getType());
4470 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
4471 outputShape.hasRank()) {
4473 padValues[0], padValues[1], strideValues[0],
4474 dilationValues[0],
"height",
"y",
"top",
4477 padValues[2], padValues[3], strideValues[1],
4478 dilationValues[1],
"width",
"x",
"left",
4484 const ShapeAdaptor biasShape(getBias().
getType());
4485 if (biasShape.hasRank() && outputShape.hasRank()) {
4486 const int64_t biasChannels = biasShape.getDimSize(0);
4487 const int64_t outputChannels =
4488 outputShape.getDimSize(outputShape.getRank() - 1);
4489 if (biasChannels == ShapedType::kDynamic ||
4490 outputChannels == ShapedType::kDynamic)
4494 if (biasChannels != outputChannels && biasChannels != 1)
4496 "bias channels expected to be equal to output channels (")
4497 << outputChannels <<
") or 1, got " << biasChannels;
4503LogicalResult Conv3DOp::inferReturnTypeComponents(
4504 MLIRContext *context, ::std::optional<Location> location,
4505 Conv3DOp::Adaptor adaptor,
4506 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4510LogicalResult Conv3DOp::verify() {
4517LogicalResult AvgPool2dOp::inferReturnTypeComponents(
4518 MLIRContext *context, ::std::optional<Location> location,
4519 AvgPool2dOp::Adaptor adaptor,
4520 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4521 ShapeAdaptor inputShape(adaptor.getInput().getType());
4522 const Properties &prop = adaptor.getProperties();
4524 inferredReturnShapes);
4527LogicalResult AvgPool2dAdaptiveOp::inferReturnTypeComponents(
4528 MLIRContext *context, ::std::optional<Location> location,
4529 AvgPool2dAdaptiveOp::Adaptor adaptor,
4530 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4531 ShapeAdaptor inputShape(adaptor.getInput().getType());
4533 llvm::SmallVector<int64_t> kernelValues;
4534 llvm::SmallVector<int64_t> strideValues;
4535 llvm::SmallVector<int64_t> padValues;
4542 padValues, inferredReturnShapes);
4545 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4546 if (inputShape.hasRank()) {
4548 outputShape[0] = inputShape.getDimSize(0);
4549 outputShape[3] = inputShape.getDimSize(3);
4552 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4556LogicalResult MaxPool2dOp::inferReturnTypeComponents(
4557 MLIRContext *context, ::std::optional<Location> location,
4558 MaxPool2dOp::Adaptor adaptor,
4559 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4560 ShapeAdaptor inputShape(adaptor.getInput().getType());
4561 const Properties &prop = adaptor.getProperties();
4563 inferredReturnShapes);
4566LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
4567 MLIRContext *context, ::std::optional<Location> location,
4568 MaxPool2dAdaptiveOp::Adaptor adaptor,
4569 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4570 ShapeAdaptor inputShape(adaptor.getInput().getType());
4572 llvm::SmallVector<int64_t> kernelValues;
4573 llvm::SmallVector<int64_t> strideValues;
4574 llvm::SmallVector<int64_t> padValues;
4581 padValues, inferredReturnShapes);
4584 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4585 if (inputShape.hasRank()) {
4586 outputShape[0] = inputShape.getDimSize(0);
4587 outputShape[3] = inputShape.getDimSize(3);
4589 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4593LogicalResult MaxPool2dOp::verify() {
4604LogicalResult MaxPool2dAdaptiveOp::verify() {
4609 AdaptivePoolingConstShapeValues values;
4613 values.pad, getInput(), getOutput())))
4619LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
4620 MLIRContext *context, ::std::optional<Location> location,
4621 DepthwiseConv2DOp::Adaptor adaptor,
4622 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4623 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4625 int64_t inputWidth = ShapedType::kDynamic;
4626 int64_t inputHeight = ShapedType::kDynamic;
4627 int64_t inputChannels = ShapedType::kDynamic;
4629 int64_t weightWidth = ShapedType::kDynamic;
4630 int64_t weightHeight = ShapedType::kDynamic;
4631 int64_t depthChannels = ShapedType::kDynamic;
4634 ShapeAdaptor inputShape(adaptor.getInput().getType());
4635 if (inputShape.hasRank()) {
4636 outputShape[0] = inputShape.getDimSize(0);
4637 inputHeight = inputShape.getDimSize(1);
4638 inputWidth = inputShape.getDimSize(2);
4639 inputChannels = inputShape.getDimSize(3);
4643 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4644 if (weightShape.hasRank()) {
4645 weightHeight = weightShape.getDimSize(0);
4646 weightWidth = weightShape.getDimSize(1);
4647 inputChannels = ShapedType::isDynamic(inputChannels)
4648 ? weightShape.getDimSize(2)
4650 depthChannels = weightShape.getDimSize(3);
4655 if (ShapedType::isStatic(inputChannels) &&
4656 ShapedType::isStatic(depthChannels)) {
4657 outputShape[3] = inputChannels * depthChannels;
4661 ShapeAdaptor biasShape(adaptor.getBias().getType());
4662 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4663 int64_t bc = biasShape.getDimSize(0);
4664 if (bc != ShapedType::kDynamic && bc != 1)
4665 outputShape[3] = bc;
4668 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4669 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4670 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4672 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4673 int64_t inputSize = inputHeight + padding[0] + padding[1];
4674 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4675 int64_t unstridedResult = inputSize - filterSize + 1;
4676 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4679 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4680 int64_t inputSize = inputWidth + padding[2] + padding[3];
4681 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4682 int64_t unstridedResult = inputSize - filterSize + 1;
4683 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4686 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4690LogicalResult DepthwiseConv2DOp::verify() {
4697LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4698 MLIRContext *context, ::std::optional<Location> location,
4699 TransposeConv2DOp::Adaptor adaptor,
4700 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4701 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4703 int64_t inputWidth = ShapedType::kDynamic;
4704 int64_t inputHeight = ShapedType::kDynamic;
4705 int64_t weightWidth = ShapedType::kDynamic;
4706 int64_t weightHeight = ShapedType::kDynamic;
4709 ShapeAdaptor inputShape(adaptor.getInput().getType());
4710 if (inputShape.hasRank()) {
4711 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4712 ? inputShape.getDimSize(0)
4714 inputHeight = inputShape.getDimSize(1);
4715 inputWidth = inputShape.getDimSize(2);
4719 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4720 if (weightShape.hasRank()) {
4721 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4722 ? weightShape.getDimSize(0)
4724 weightHeight = weightShape.getDimSize(1);
4725 weightWidth = weightShape.getDimSize(2);
4729 ShapeAdaptor biasShape(adaptor.getBias().getType());
4730 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4731 int64_t bc = biasShape.getDimSize(0);
4732 if (bc != ShapedType::kDynamic && bc != 1)
4733 outputShape[3] = bc;
4736 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4737 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4739 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4740 int64_t calculateSize =
4741 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4743 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4746 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4747 int64_t calculateSize =
4748 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4750 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4753 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4757LogicalResult TransposeConv2DOp::verify() {
4761 const llvm::ArrayRef<int64_t> strides = getStride();
4762 const int64_t strideY = strides[0];
4763 const int64_t strideX = strides[1];
4765 if (strideY < 1 || strideX < 1)
4766 return emitOpError(
"expect all stride values to be >= 1, got [")
4769 const auto checkPadAgainstKernelDim =
4770 [
this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4771 llvm::StringRef kernelDimName) -> LogicalResult {
4772 if (padValue <= -kernelDimSize)
4774 << padName <<
" > -" << kernelDimName <<
", but got: " << padName
4775 <<
"=" << padValue <<
" and " << kernelDimName <<
"="
4780 const llvm::ArrayRef<int64_t> padding = getOutPad();
4781 const int64_t outPadTop = padding[0];
4782 const int64_t outPadBottom = padding[1];
4783 const int64_t outPadLeft = padding[2];
4784 const int64_t outPadRight = padding[3];
4786 const auto weightType =
4787 llvm::dyn_cast<RankedTensorType>(getWeight().
getType());
4790 const int64_t kernelHeight = weightType.getDimSize(1);
4791 if (ShapedType::isStatic(kernelHeight)) {
4792 if (
failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4793 "out_pad_top",
"KH")))
4796 if (
failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4797 "out_pad_bottom",
"KH")))
4801 const int64_t kernelWidth = weightType.getDimSize(2);
4802 if (ShapedType::isStatic(kernelWidth)) {
4803 if (
failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4804 "out_pad_left",
"KW")))
4807 if (
failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4808 "out_pad_right",
"KW")))
4814 const auto outputType =
4815 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
4819 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
4820 if (inputType && weightType) {
4821 const int64_t inputHeight = inputType.getDimSize(1);
4822 const int64_t kernelHeight = weightType.getDimSize(1);
4823 const int64_t outputHeight = outputType.getDimSize(1);
4825 if (ShapedType::isStatic(inputHeight) &&
4826 ShapedType::isStatic(outputHeight)) {
4828 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4830 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4831 "+ out_pad_top + out_pad_bottom + KH, but got ")
4832 << outputHeight <<
" != (" << inputHeight <<
" - 1) * "
4833 << strideY <<
" + " << outPadTop <<
" + " << outPadBottom
4834 <<
" + " << kernelHeight;
4837 const int64_t inputWidth = inputType.getDimSize(2);
4838 const int64_t kernelWidth = weightType.getDimSize(2);
4839 const int64_t outputWidth = outputType.getDimSize(2);
4841 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4843 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4845 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4846 "+ out_pad_left + out_pad_right + KW, but got ")
4847 << outputWidth <<
" != (" << inputWidth <<
" - 1) * " << strideX
4848 <<
" + " << outPadLeft <<
" + " << outPadRight <<
" + "
4853 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().
getType());
4858 const int64_t biasChannels = biasType.getDimSize(0);
4861 if (biasChannels == ShapedType::kDynamic)
4864 const int64_t outputChannels = outputType.getDimSize(3);
4865 if (!ShapedType::isDynamic(outputChannels) &&
4866 biasChannels != outputChannels && biasChannels != 1)
4868 "bias channels expected to be equal to output channels (")
4869 << outputChannels <<
") or 1, got " << biasChannels;
4874LogicalResult RescaleOp::verify() {
4875 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
4876 auto inputElementType =
4878 if (!mlir::isa<IntegerType>(inputElementType)) {
4879 emitOpError(
"expect input to have integer element type, got ")
4880 << inputElementType;
4884 const auto outputType = llvm::cast<ShapedType>(getOutput().
getType());
4885 auto outputElementType =
4887 if (!mlir::isa<IntegerType>(outputElementType)) {
4888 emitOpError(
"expect output to have integer element type, got ")
4889 << outputElementType;
4901 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4902 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4905 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4906 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4909 const auto multiplierType = llvm::cast<ShapedType>(getMultiplier().
getType());
4911 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4912 emitOpError(
"expect i32 element type for multiplier for scale32=true, got ")
4913 << multiplierType.getElementType();
4918 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4920 "expect i16 element type for multiplier for scale32=false, got ")
4921 << multiplierType.getElementType();
4925 if (!inputType.hasRank())
4931 int64_t numChannels = 1;
4932 if (getPerChannel()) {
4933 if (inputType.getRank() < 1) {
4934 emitOpError(
"requires input to be at least rank 1 when per_channel is "
4935 "true, but got rank ")
4936 << inputType.getRank();
4939 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4942 if (outputType.hasRank()) {
4944 getOperation(), outputType, inputType.getShape())))
4948 if (multiplierType.hasRank()) {
4949 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4951 if (multiplierShape[0] != ShapedType::kDynamic &&
4952 multiplierShape[0] != numChannels) {
4954 << numChannels <<
" } for multiplier input, got { "
4955 << multiplierShape[0] <<
" }";
4960 const auto shiftType = llvm::cast<ShapedType>(getShift().
getType());
4961 if (shiftType.hasRank()) {
4962 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4964 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4966 << numChannels <<
" } for shift input, got { " << shiftShape[0]
4975LogicalResult RescaleOp::inferReturnTypeComponents(
4976 MLIRContext *context, ::std::optional<Location> location,
4977 RescaleOp::Adaptor adaptor,
4978 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4979 ShapeAdaptor inputShape(adaptor.getInput().getType());
4980 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4984LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4985 MLIRContext *context, ::std::optional<Location> location,
4986 CastFromBlockScaledOp::Adaptor adaptor,
4987 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4988 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4989 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4993LogicalResult CastFromBlockScaledOp::verify() {
4994 const Type inputDataType = getInputData().getType();
4995 const Type outputDataType = getResult().getType();
4997 return emitOpError() <<
"require compatible shapes for input_data ("
4998 << inputDataType <<
") and " <<
"output_data ("
4999 << outputDataType <<
")";
5001 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5003 if (inputDataShape.
hasRank()) {
5004 const unsigned int blockSize =
5006 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5007 return emitOpError(
"expect block size to be 32, got ") << blockSize;
5008 const int64_t inputDataLastDim =
5010 if (inputDataLastDim % blockSize != 0)
5011 return emitOpError() <<
"expect last dimension of input_data ("
5013 <<
") to be divisible by block_size (" << blockSize
5016 const Type inputScaleType = getInputScale().getType();
5017 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
5019 if (inputScaleShape.
hasRank()) {
5020 SmallVector<int64_t> inputDataDims, inputScaleDims;
5021 inputDataShape.
getDims(inputDataDims);
5022 inputScaleShape.
getDims(inputScaleDims);
5024 if (inputDataDims.size() != inputScaleDims.size() ||
5026 ArrayRef<int64_t>(inputDataDims).drop_back(1),
5027 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
5029 <<
"require compatible shapes for input_data (" << inputDataType
5030 <<
") and " <<
"input_scale (" << inputScaleType
5031 <<
") except for the last dimension";
5033 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
5034 inputScaleDims.back()};
5035 if (ShapedType::isStatic(inputDataLastDim) &&
5038 <<
"expect last dimension of input_scale ("
5039 << inputScaleDims.back()
5040 <<
") to be equal to last dimension of input_data / block_size ("
5041 << inputDataDims.back() / blockSize <<
")";
5048LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
5049 MLIRContext *context, ::std::optional<Location> location,
5050 CastToBlockScaledOp::Adaptor adaptor,
5051 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5052 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5053 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5054 if (!inputShape.hasRank())
5058 SmallVector<int64_t> outputScaleShape;
5059 inputShape.getDims(outputScaleShape);
5060 const int64_t lastDimLoc = inputShape.getRank() - 1;
5061 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
5062 if (ShapedType::isStatic(lastDimSize)) {
5063 const unsigned int blockSize =
5064 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
5065 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
5067 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
5071LogicalResult CastToBlockScaledOp::verify() {
5072 const Type inputDataType = getInputData().getType();
5073 const Type outputDataType = getResult(0).getType();
5075 return emitOpError() <<
"require compatible shapes for input_data ("
5076 << inputDataType <<
") and " <<
"output_data ("
5077 << outputDataType <<
")";
5079 const unsigned int blockSize =
5081 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5082 return emitOpError(
"expect block size to be 32, got ") << blockSize;
5083 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5084 if (inputDataShape.
hasRank()) {
5085 const int64_t inputDataLastDim =
5087 if (ShapedType::isStatic(inputDataLastDim) &&
5088 inputDataLastDim % blockSize != 0)
5089 return emitOpError() <<
"expect last dimension of input_data ("
5091 <<
") to be divisible by block_size (" << blockSize
5095 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
5096 const Type outputScaleType = getResult(1).getType();
5097 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
5099 SmallVector<int64_t> outputDataDims, outputScaleDims;
5100 outputDataShape.
getDims(outputDataDims);
5101 outputScaleShape.
getDims(outputScaleDims);
5103 if (outputDataDims.size() != outputScaleDims.size() ||
5105 ArrayRef<int64_t>(outputDataDims).drop_back(1),
5106 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
5107 return emitOpError() <<
"require compatible shapes for output_data ("
5108 << outputDataType <<
") and " <<
"output_scale ("
5110 <<
") except for the last dimension";
5112 const int64_t outputDataLastDim = outputDataDims.back();
5113 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
5114 outputScaleDims.back()};
5115 if (ShapedType::isStatic(outputDataLastDim) &&
5118 <<
"expect last dimension of output_scale ("
5119 << outputScaleDims.back()
5120 <<
") to be equal to last dimension of output_data / block_size ("
5121 << outputDataDims.back() / blockSize <<
")";
5127LogicalResult IfOp::inferReturnTypeComponents(
5128 MLIRContext *context, ::std::optional<Location> location,
5129 IfOp::Adaptor adaptor,
5130 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5131 llvm::SmallVector<tosa::YieldOp> yieldOps;
5132 for (Region *region : adaptor.getRegions()) {
5133 for (
auto &block : *region)
5134 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5135 yieldOps.push_back(returnOp);
5138 if (yieldOps.empty())
5142 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5143 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5144 for (
auto operand : yieldOps.front().getOperands()) {
5145 resultKnowledge.push_back(
5149 for (
auto yieldOp : yieldOps) {
5150 if (resultKnowledge.size() != yieldOp.getNumOperands())
5153 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5154 int32_t index = it.index();
5156 resultKnowledge[index],
5160 resultKnowledge[index] = meet;
5164 for (
const ValueKnowledge &
result : resultKnowledge) {
5165 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
5171LogicalResult WhileOp::inferReturnTypeComponents(
5172 MLIRContext *context, ::std::optional<Location> location,
5173 WhileOp::Adaptor adaptor,
5174 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5175 llvm::SmallVector<tosa::YieldOp> yieldOps;
5176 for (
auto &block : adaptor.getBodyGraph())
5177 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5178 yieldOps.push_back(returnOp);
5182 if (yieldOps.empty())
5186 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5187 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5188 for (
auto operand : yieldOps.front().getOperands()) {
5189 resultKnowledge.push_back(
5193 for (
auto yieldOp : yieldOps) {
5194 if (resultKnowledge.size() != yieldOp.getNumOperands())
5197 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5198 int32_t index = it.index();
5200 resultKnowledge[index],
5202 resultKnowledge[index] = meet;
5207 for (
const ValueKnowledge &
result : resultKnowledge) {
5208 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
5214std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
5215 if (
auto vt = llvm::dyn_cast<VectorType>(
getType()))
5216 return llvm::to_vector<4>(vt.getShape());
5217 return std::nullopt;
5223 StringRef prefix =
"") {
5224 assert(blocksArgs.size() == initializers.size() &&
5225 "expected same length of arguments and initializers");
5226 if (initializers.empty())
5229 parser << prefix <<
'(';
5230 llvm::interleaveComma(
5231 llvm::zip(blocksArgs, initializers), parser,
5232 [&](
auto it) { parser << std::get<0>(it) <<
" = " << std::get<1>(it); });
5237ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
5239 result.regions.reserve(2);
5240 Region *thenRegion =
result.addRegion();
5241 Region *elseRegion =
result.addRegion();
5243 OpAsmParser::UnresolvedOperand cond;
5248 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5249 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5252 OptionalParseResult listResult =
5260 "expected type for condition operand");
5266 "expected type for condition operand");
5274 FunctionType functionType;
5278 <<
"expected list of types for block arguments "
5279 <<
"followed by arrow type and list of return types";
5281 result.addTypes(functionType.getResults());
5283 if (functionType.getNumInputs() != operands.size()) {
5285 <<
"expected as many input types as operands " <<
"(expected "
5286 << operands.size() <<
" got " << functionType.getNumInputs()
5317void IfOp::print(OpAsmPrinter &p) {
5318 p <<
" " << getCondition();
5321 getInputList(),
" ");
5323 p << getCondition().getType();
5325 if (!getInputList().empty()) {
5327 llvm::interleaveComma(getInputList().getTypes(), p);
5336 auto &elseRegion = getElseGraph();
5337 if (!elseRegion.
empty()) {
5345LogicalResult IfOp::verify() {
5347 "'then_graph' arguments", getInputList(),
5353 "'else_graph' arguments", getInputList(),
5359 if (getThenGraph().front().mightHaveTerminator()) {
5361 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
5363 *
this, thenYield.getInputs(),
"'then_graph' results",
5364 getOutputList(),
"'output_list'")
5370 if (getElseGraph().front().mightHaveTerminator()) {
5372 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
5374 *
this, elseYield.getInputs(),
"'else_graph' results",
5375 getOutputList(),
"'output_list'")
5380 auto condType = getCondition().getType();
5382 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
5388LogicalResult WhileOp::verify() {
5390 getOutputList(),
"'output_list'")
5395 "'cond_graph' arguments", getInputList(),
5401 "'body_graph' arguments", getInputList(),
5406 if (getBodyGraph().front().mightHaveTerminator()) {
5408 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
5410 "'body_graph' results",
5411 getInputList(),
"'input_list'")
5418 if (!getCondGraph().front().mightHaveTerminator())
5422 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
5426 if (condYield.getInputs().size() != 1)
5427 return emitOpError() <<
"require 'cond_graph' only have one result";
5429 auto condOutType = condYield.getInputs()[0].getType();
5431 return emitOpError() <<
"'cond_graph' result must be a size 1 tensor, got "
5435 return emitOpError() <<
"'cond_graph' result must be a boolean tensor, got "
5441LogicalResult ReverseOp::verify() {
5446 TensorType inputType = getInput1().getType();
5447 TensorType outputType = getOutput().getType();
5448 int32_t reverseAxis = getAxis();
5450 if (reverseAxis < 0)
5451 return emitOpError(
"expected non-negative reverse axis");
5453 int64_t inputRank = inputType.getRank();
5456 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
5458 << inputRank <<
") to be larger than reverse axis (" << reverseAxis
5462 int64_t outputRank = outputType.getRank();
5463 if (inputType.
hasRank() && outputRank != inputType.getRank())
5465 "expect output tensor rank to be equal to input tensor rank");
5466 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
5468 << outputRank <<
") to be larger than reverse axis ("
5469 << reverseAxis <<
")";
5474LogicalResult tosa::SelectOp::verify() {
5485 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().
getType());
5486 if (!predicateType) {
5487 return emitOpError(
"expect shaped tensor for input1, got ")
5488 << getInput1().getType();
5490 auto predicateElementType = predicateType.getElementType();
5491 if (!predicateElementType.isInteger(1)) {
5492 return emitOpError(
"expect element type of bool for input1, got ")
5493 << predicateElementType;
5499LogicalResult tosa::VariableReadOp::verify() {
5507LogicalResult tosa::VariableWriteOp::verify() {
5516ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
5517 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5518 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5519 Region *cond =
result.addRegion();
5520 Region *body =
result.addRegion();
5522 OptionalParseResult listResult =
5527 FunctionType functionType;
5532 result.addTypes(functionType.getResults());
5534 if (functionType.getNumInputs() != operands.size()) {
5536 <<
"expected as many input types as operands " <<
"(expected "
5537 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
5547 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
5548 regionArgs[i].type = functionType.getInput(i);
5550 return failure(parser.
parseRegion(*cond, regionArgs) ||
5555void WhileOp::print(OpAsmPrinter &parser) {
5557 getInputList(),
" ");
5560 getResults().getTypes());
5574 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
5575 if (llvm::isa<FloatType>(srcElemType)) {
5577 zpType, builder.
getFloatAttr(srcElemType,
static_cast<double>(zp)));
5578 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5580 if (llvm::isa<IntegerType>(srcElemType)) {
5583 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5585 llvm::errs() <<
"zero point is not allowed for unsupported data types\n";
5586 return std::nullopt;
5594 return mlir::isa<tosa::shapeType>(t);
5601 return emitError() <<
"invalid rank (must be >= 0): " << rank;
5607 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
5608 Operation *definingOp = v.getDefiningOp();
5610 return op->
emitOpError(
"shape operand is not compile time resolvable");
5623 auto getRank = [](
const Type type) {
5624 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
5630 for (
auto type : operandTypes) {
5631 if (getRank(type) != rank) {
5632 return op->
emitOpError(
"operands don't have matching ranks");
5635 for (
auto type : resultTypes) {
5636 if (getRank(type) != rank) {
5637 return op->
emitOpError(
"result shape has different rank than operands");
5647LogicalResult tosa::ConstShapeOp::verify() {
5649 auto valuesRank = getValues().getType().getRank();
5650 if (valuesRank != 1)
5651 return emitOpError(
"expect elements in attribute values with rank 1");
5653 auto count = getValues().getNumElements();
5654 auto rank = (cast<tosa::shapeType>(getResult().
getType())).getRank();
5655 if (count != rank && (count != 1 || rank != 0)) {
5656 return emitOpError(
"expect number of elements in attribute values (")
5657 << count <<
") to be equal to the rank (" << rank
5658 <<
") for the result shape type";
5663LogicalResult tosa::DimOp::verify() {
5664 const tosa::shapeType outShapeType =
5665 cast<tosa::shapeType>(getResult().
getType());
5666 if (outShapeType.getRank() != 1)
5667 return emitOpError(
"expect output shape type to contain one element, got ")
5672 const int64_t inputRank = inputType.getRank();
5673 const int64_t axis = getAxisAttr().getInt();
5674 if (axis < 0 || axis >= inputRank)
5675 return emitOpError(
"expect axis to be in the range [0, ")
5676 << inputRank <<
"), got " << axis;
5681LogicalResult tosa::ConcatShapeOp::verify() {
5682 const tosa::shapeType outShapeType =
5683 cast<tosa::shapeType>(getResult().
getType());
5684 const int64_t outputRank = outShapeType.getRank();
5687 if (inputList.size() == 0)
5688 return emitOpError(
"requires at least one input shape");
5690 if (llvm::any_of(inputList, [](Value v) {
5691 return cast<tosa::shapeType>(v.
getType()).getRank() == 0;
5693 return emitOpError(
"requires all inputs shapes have a rank greater than 0");
5695 const int64_t inputsRank =
5696 llvm::accumulate(inputList, 0, [](int64_t acc,
const Value &input) {
5697 const tosa::shapeType inShapeType =
5698 cast<tosa::shapeType>(input.
getType());
5699 return acc + inShapeType.getRank();
5701 if (outputRank != inputsRank)
5702 return emitOpError(
"requires output shape rank to be equal to the sum of "
5703 "the input shape ranks (")
5704 << inputsRank <<
"), got " << outputRank;
5709LogicalResult tosa::SliceShapeOp::verify() {
5710 std::optional<int32_t> start;
5711 DenseIntElementsAttr startAttr;
5713 start = startAttr.getValues<int32_t>()[0];
5714 if (start && start.value() < 0)
5715 return emitOpError(
"expected non-negative start index, got ")
5718 std::optional<int32_t> size;
5719 DenseIntElementsAttr sizeAttr;
5721 size = sizeAttr.getValues<int32_t>()[0];
5722 if (size && size.value() <= 0)
5723 return emitOpError(
"expected positive size, got ") << size.value();
5728 const tosa::shapeType outShapeType =
5729 cast<tosa::shapeType>(getResult().
getType());
5730 const int64_t outputRank = outShapeType.getRank();
5731 if (outputRank != size)
5733 "expected output type size to be equal to size attribute, got ")
5734 << outputRank <<
" vs " << size.value();
5739 const tosa::shapeType inShapeType =
5740 cast<tosa::shapeType>(getInput().
getType());
5741 const int64_t inputRank = inShapeType.getRank();
5742 const int64_t sliceSize = start.value() + size.value();
5743 if (sliceSize > inputRank)
5744 return emitOpError(
"expected start + size to be less than or equal to "
5745 "input shape rank (")
5746 << inputRank <<
"), got " << sliceSize;
5755#define GET_ATTRDEF_CLASSES
5756#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5761#define GET_TYPEDEF_CLASSES
5762#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5768#define GET_OP_CLASSES
5769#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
true
Given two iterators into the same block, return "true" if a is before `b.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void printShapeToDiagnostic(InFlightDiagnostic &diag, ArrayRef< int64_t > shape)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildAvgPool2dAdaptiveOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad, TypeAttr accType)
This builder mirrors avg_pool2d quant-info handling and materializes kernel/stride/pad as const_shape...
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static LogicalResult verifyPoolingOpImpl(Operation *op, ArrayRef< int64_t > kernel, ArrayRef< int64_t > strides, ArrayRef< int64_t > padding, Value input, Value output)
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void extractAdaptivePoolingConstShapeOperands(T op, AdaptivePoolingConstShapeValues &values)
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static LogicalResult verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType, ArrayRef< int64_t > expectedShape, StringRef outputName="output")
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static FailureOr< int64_t > getConstantScalarIntValue(Value val)
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static void updateIfDynamic(int64_t ¤t, int64_t candidate)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
int64_t getOutputRank() const
int64_t getNumSpatialDims() const
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
int64_t getNumSpatialDims() const
int64_t getOutputRank() const
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
int64_t getNumSpatialDims() const
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
int64_t getOutputRank() const
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Operation is the basic unit of execution within MLIR.
ResultRange result_range
Support result iteration.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ArrayRef< T > asArrayRef() const
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)