29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
39#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
49#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
52#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
57struct TosaInlinerInterface :
public DialectInlinerInterface {
58 using DialectInlinerInterface::DialectInlinerInterface;
66 IRMapping &map)
const final {
72 IRMapping &map)
const final {
73 return (isa<tosa::IfOp>(dest->getParentOp()) ||
74 isa<tosa::WhileOp>(dest->getParentOp()));
79struct TosaDialectBytecodeInterface :
public BytecodeDialectInterface {
80 TosaDialectBytecodeInterface(Dialect *dialect)
81 : BytecodeDialectInterface(dialect) {}
86 Attribute readAttribute(DialectBytecodeReader &reader)
const override {
90 LogicalResult writeAttribute(Attribute attr,
91 DialectBytecodeWriter &writer)
const override {
92 return ::writeAttribute(attr, writer);
98 Type readType(DialectBytecodeReader &reader)
const override {
102 LogicalResult writeType(Type type,
103 DialectBytecodeWriter &writer)
const override {
104 return ::writeType(type, writer);
107 void writeVersion(DialectBytecodeWriter &writer)
const final {
111 std::unique_ptr<DialectVersion>
112 readVersion(DialectBytecodeReader &reader)
const final {
114 reader.
emitError(
"Dialect does not support versioning");
118 LogicalResult upgradeFromVersion(Operation *topLevelOp,
119 const DialectVersion &version)
const final {
132 return {&getBodyGraph()};
141 return dim == -1 ? ShapedType::kDynamic : dim;
147 Type elementType = variableOp.getType();
150 return RankedTensorType::get(
shape, elementType);
157void TosaDialect::initialize() {
159#define GET_TYPEDEF_LIST
160#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
164#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
167#define GET_ATTRDEF_LIST
168#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
170 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
171 declarePromisedInterfaces<
172 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
173 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
174 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
175 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
176 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
177 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
178 GreaterEqualOp, MatMulOp>();
185 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
186 return tosa::ConstShapeOp::create(builder, loc, type,
187 llvm::cast<DenseIntElementsAttr>(value));
189 if (llvm::isa<ElementsAttr>(value))
190 return tosa::ConstOp::create(builder, loc, type,
191 llvm::cast<ElementsAttr>(value));
201ParseResult getShapeAndElementType(
OpAsmParser &parser,
Type parsedType,
203 TypeAttr &typeAttr) {
204 if (
auto shapedType = dyn_cast<ShapedType>(parsedType)) {
205 if (!shapedType.hasRank())
207 <<
"expected ranked type";
209 auto elementType = shapedType.getElementType();
210 typeAttr = TypeAttr::get(elementType);
217 <<
"expected shaped type";
234 <<
"expected attribute";
236 if (
auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
237 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
241 <<
"expected Typed attr";
244 initialValueAttr =
nullptr;
248 <<
"expected type after colon";
250 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
255 TypeAttr typeAttr,
Attribute initialValueAttr) {
256 bool needsSpace =
false;
257 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
260 Type elementType = typeAttr.getValue();
261 RankedTensorType tensorType =
263 auto tensorTypeAttr = TypeAttr::get(tensorType);
268 if (initialValueAttr) {
279template <
typename EnumType>
280ParseResult parseAttrEntryWithEnumHandling(
OpAsmParser &parser,
282 llvm::StringRef name;
289 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
290 if (name ==
"rounding_mode" &&
292 auto sym = symbolizeRoundingMode(kw);
295 <<
"invalid rounding_mode value: " << kw;
296 auto attr = RoundingModeAttr::get(parser.
getContext(), sym.value());
302 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
304 auto sym = symbolizeResizeMode(kw);
307 <<
"invalid resize mode value: " << kw;
308 auto attr = ResizeModeAttr::get(parser.
getContext(), sym.value());
315 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
317 auto sym = symbolizeNanPropagationMode(kw);
320 <<
"invalid nan_mode value: " << kw;
321 auto attr = NanPropagationModeAttr::get(parser.
getContext(), sym.value());
328 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
330 auto sym = symbolizeBlockSize(kw);
333 <<
"invalid block_size value: " << kw;
334 auto attr = BlockSizeAttr::get(parser.
getContext(), sym.value());
346template <
typename EnumType>
351 [&]() { return parser.parseOperand(operands.emplace_back()); }))
359 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
376 result.addTypes(fnTy.getResults());
377 result.addAttributes(attrs);
383 parser << namedAttr.
getName().strref() <<
" = ";
385 if (
auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
386 parser << roundingModeAttr.getValue();
387 }
else if (
auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
388 parser << resizeModeAttr.getValue();
389 }
else if (
auto nanPropagationModeAttr =
390 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
391 parser << nanPropagationModeAttr.getValue();
392 }
else if (
auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
393 parser << blockSizeAttr.getValue();
406 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
408 if (
auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
409 if (nanAttr.getValue() == kDefaultNanValue) {
411 toPrint.erase(attr.getName());
417 if (!toPrint.empty()) {
419 llvm::interleaveComma(toPrint, parser, [&](
const NamedAttribute namedAttr) {
420 printNamedAttr(parser, namedAttr);
436 llvm::interleaveComma(op->
getAttrs(), parser,
438 printNamedAttr(parser, namedAttr);
450 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
454 printWithEnumHandling(parser, *
this);
458 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
462 printWithEnumHandling(parser, *
this);
466 return parseWithEnumHandling<tosa::ResizeMode>(parser,
result);
470 printWithEnumHandling(parser, *
this);
474 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
478 printWithNanPropagationHandling(parser, *
this);
482 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
486 printWithNanPropagationHandling(parser, *
this);
489ParseResult MaxPool2dAdaptiveOp::parse(
OpAsmParser &parser,
491 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
495 printWithNanPropagationHandling(parser, *
this);
499 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
503 printWithNanPropagationHandling(parser, *
this);
507 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
511 printWithNanPropagationHandling(parser, *
this);
515 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
519 printWithNanPropagationHandling(parser, *
this);
523 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
527 printWithNanPropagationHandling(parser, *
this);
531 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
535 printWithNanPropagationHandling(parser, *
this);
538ParseResult MatmulTBlockScaledOp::parse(
OpAsmParser &parser,
540 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
544 printWithEnumHandling(parser, *
this);
547ParseResult CastFromBlockScaledOp::parse(
OpAsmParser &parser,
549 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
552void CastFromBlockScaledOp::print(
OpAsmPrinter &parser) {
553 printWithEnumHandling(parser, *
this);
556ParseResult CastToBlockScaledOp::parse(
OpAsmParser &parser,
558 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
562 printWithEnumHandling(parser, *
this);
565ParseResult Conv2DBlockScaledOp::parse(
OpAsmParser &parser,
567 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
571 printWithEnumHandling(parser, *
this);
586 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
596 Value valZp, StringRef name) {
601 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
605 if (!bothInts || !sameBitWidth) {
607 <<
"expected " << name <<
" and " << name
608 <<
"_zp to both be integer of the same bitwidth, but got " << eType
609 <<
" vs. " << eZpType;
616 Value src, int32_t val) {
619 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
620 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
621 const auto padConstAttr{
622 llvm::isa<FloatType>(srcElemType)
627 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
631 if (dyn_cast<tosa::mxint8Type>(type))
640 const StringRef operandName,
641 const StringRef dimName) {
642 if (ShapedType::isDynamic(currDim)) {
645 }
else if (ShapedType::isStatic(newDim) && currDim != newDim) {
647 << dimName <<
" of " << operandName <<
" to match size " << currDim
648 <<
", got " << newDim;
655 auto printDim = [&](
int64_t dim) {
656 if (ShapedType::isDynamic(dim))
662 llvm::interleaveComma(
shape,
diag, printDim);
668 const int64_t stride,
const int64_t dilation,
const llvm::StringRef dimName,
669 const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
670 const llvm::StringRef padAfterName) {
671 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
676 const std::optional<int64_t> calculatedOutSizeMinusOne =
idivCheck(
677 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
679 if (!calculatedOutSizeMinusOne.has_value())
681 << dimName <<
" - 1 + pad_" << padBeforeName <<
" + pad_"
682 << padAfterName <<
" - (kernel_" << dimName <<
" - 1) * dilation_"
683 << dimAxis <<
" to be wholly divisible by stride_" << dimAxis
684 <<
", got (" << inputSize <<
" - 1 + " << padBefore <<
" + "
685 << padAfter <<
" - (" << kernelSize <<
" - 1) * " << dilation
688 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
689 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
691 << dimName <<
" did not match expected: "
692 <<
"calculated=" << calculatedOutSize <<
", expected=" << outputSize;
703 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
704 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
706 auto inputEType = inputType.getElementType();
707 auto weightEType = weightType.getElementType();
709 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
711 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
712 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
713 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
715 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
718 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
721 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
724 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
727 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
731 "expect both bias and result to have same element type, got ")
732 << biasEType <<
" and " << resultEType;
736 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
737 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
738 if (inputEType != weightEType) {
740 "expect both input and weight to have same element type, got ")
741 << inputEType <<
" and " << weightEType;
746 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
747 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
750 if (inputIsFloat != weightIsFloat) {
752 "expect both input and weight to be float or not together, got ")
753 << inputEType <<
" and " << weightEType;
758 if (inputEType != inputZpEType) {
759 return op.emitOpError(
"expect both input and its zero point are the same "
760 "element type, got ")
761 << inputEType <<
" and " << inputZpEType;
765 if (weightEType != weightZpEType) {
766 return op.emitOpError(
"expect both weight and its zero point are the same "
767 "element type, got ")
768 << weightEType <<
" and " << weightZpEType;
771 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
772 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
775 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
776 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
782LogicalResult tosa::ConstOp::verify() {
784 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().
getType());
785 auto outputType = llvm::dyn_cast<TensorType>(getOutput().
getType());
787 if (!attrType || !outputType) {
788 emitOpError(
"expected tensors for attr/result type");
792 if (
auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
793 outputType.getElementType())) {
798 if (attrType.getElementType() != outputType.getElementType()) {
799 emitOpError(
"expected same attr/result element types");
809 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
811 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
814 auto accType = op.getAccType();
815 if (inputEType.isInteger(8) && !accType.isInteger(32))
816 return op.emitOpError(
"accumulator type for i8 tensor is not i32, got ")
819 if (inputEType.isInteger(16) && !accType.isInteger(48))
820 return op.emitOpError(
"accumulator type for i16 tensor is not i48, got ")
823 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
824 !(accType.isF16() || accType.isF32()))
825 return op.emitOpError(
"accumulator type for f8 tensor is not f16/f32, got ")
828 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
829 return op.emitOpError(
830 "accumulator type for f16 tensor is not f16/f32, got ")
833 if (inputEType.isBF16() && !accType.isF32())
834 return op.emitOpError(
"accumulator type for bf16 tensor is not f32, got ")
837 if (inputEType.isF32() && !accType.isF32())
838 return op.emitOpError(
"accumulator type for f32 tensor is not f32, got ")
842 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
844 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
858 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
859 return op.emitOpError(
"expect all padding values to be >= 0, got ")
863 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
864 return op.emitOpError(
"expect all stride values to be >= 1, got ")
868 if (llvm::any_of(dilations, [](
int64_t d) {
return d < 1; }))
869 return op.emitOpError(
"expect all dilation values to be >= 1, got ")
872 const RankedTensorType outputType =
873 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
878 const RankedTensorType inputType =
879 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
880 const RankedTensorType weightType =
881 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
883 if (inputType && weightType) {
885 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
887 op, inputType.getDimSize(1), weightType.getDimSize(1),
888 outputType.getDimSize(1), padding[0], padding[1], strides[0],
889 dilations[0],
"height",
"y",
"top",
"bottom")))
893 op, inputType.getDimSize(2), weightType.getDimSize(2),
894 outputType.getDimSize(2), padding[2], padding[3], strides[1],
895 dilations[1],
"width",
"x",
"left",
"right")))
900 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
902 op, inputType.getDimSize(1), weightType.getDimSize(0),
903 outputType.getDimSize(1), padding[0], padding[1], strides[0],
904 dilations[0],
"height",
"y",
"top",
"bottom")))
908 op, inputType.getDimSize(2), weightType.getDimSize(1),
909 outputType.getDimSize(2), padding[2], padding[3], strides[1],
910 dilations[1],
"width",
"x",
"left",
"right")))
915 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
917 op, inputType.getDimSize(1), weightType.getDimSize(1),
918 outputType.getDimSize(1), padding[0], padding[1], strides[0],
919 dilations[0],
"depth",
"d",
"front",
"back")))
923 op, inputType.getDimSize(2), weightType.getDimSize(2),
924 outputType.getDimSize(2), padding[2], padding[3], strides[1],
925 dilations[1],
"height",
"y",
"top",
"bottom")))
929 op, inputType.getDimSize(3), weightType.getDimSize(3),
930 outputType.getDimSize(3), padding[4], padding[5], strides[2],
931 dilations[2],
"width",
"x",
"left",
"right")))
936 const RankedTensorType biasType =
937 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
942 const int64_t biasChannels = biasType.getDimSize(0);
944 outputType.getDimSize(outputType.getRank() - 1);
945 if (biasChannels == ShapedType::kDynamic ||
946 outputChannels == ShapedType::kDynamic)
950 if (biasChannels != outputChannels && biasChannels != 1)
951 return op.emitOpError(
952 "bias channels expected to be equal to output channels (")
953 << outputChannels <<
") or 1, got " << biasChannels;
960 StringRef name1,
Type type2,
962 auto shapeType1 = dyn_cast<ShapedType>(type1);
963 auto shapeType2 = dyn_cast<ShapedType>(type2);
964 if (!shapeType1 || !shapeType2)
967 auto elemType1 = shapeType1.getElementType();
968 auto elemType2 = shapeType2.getElementType();
969 if (elemType1 != elemType2)
971 <<
"require same element type for " << name1 <<
" (" << elemType1
972 <<
") and " << name2 <<
" (" << elemType2 <<
")";
976 <<
"require same shapes for " << name1 <<
" (" << type1 <<
") and "
977 << name2 <<
" (" << type2 <<
")";
987 if (list1.size() != list2.size())
989 <<
"require same number of values in " << name1 <<
" ("
990 << list1.size() <<
") and " << name2 <<
" (" << list2.size() <<
")";
992 for (
auto [type1, type2] :
1009template <
typename T>
1012 op->template getParentWithTrait<OpTrait::SymbolTable>();
1019 const auto varOp = symTable.
lookup<tosa::VariableOp>(op.getName());
1023 return op->emitOpError(
"'")
1024 << op.getName() <<
"' has not been declared by 'tosa.variable'";
1036template <
typename T>
1038 StringRef aName =
"input",
1039 StringRef bName =
"output") {
1040 auto aTType = llvm::dyn_cast<TensorType>(aType);
1041 auto bTType = llvm::dyn_cast<TensorType>(bType);
1043 op.emitOpError(
"expect shaped tensor for") << aName <<
", got " << aType;
1047 op.emitOpError(
"expect shaped tensor for") << bName <<
", got" << bType;
1050 auto aElementType = aTType.getElementType();
1051 auto bElementType = bTType.getElementType();
1053 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1055 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1056 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1057 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1058 aElementType != bElementType) {
1063 op.emitOpError(
"expect ")
1064 << aName <<
" and " << bName <<
" to have same element type, got "
1065 << aElementType <<
" and " << bElementType;
1071LogicalResult tosa::ArgMaxOp::verify() {
1072 const ShapedType resultType = llvm::cast<ShapedType>(
getType());
1075 if (
const auto resultETy = resultType.getElementType();
1076 !resultETy.isIntOrIndex())
1077 return emitOpError(
"result tensor is not of integer type");
1079 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
1080 if (!inputType.hasRank())
1084 const int64_t axis = getAxisAttr().getInt();
1085 if (((axis < 0) || axis >= inputType.getRank()))
1086 return emitOpError(
"specified axis is outside the rank of the tensor");
1088 if (!resultType.hasRank())
1094 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1097 << expectedOutputShape <<
"', got '" << outputShape <<
"'";
1107 const bool hasKernel = kernel.size() > 0;
1108 const bool hasStrides = strides.size() > 0;
1109 const bool hasPad = padding.size() > 0;
1111 if (hasKernel && llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1112 return op->
emitOpError(
"expect all kernel values to be >= 1, got ")
1115 if (hasStrides && llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1116 return op->
emitOpError(
"expect all stride values to be >= 1, got ")
1119 if (hasPad && llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1120 return op->
emitOpError(
"expect all padding values to be >= 0, got ")
1123 if (hasKernel && hasPad) {
1125 const int64_t kernelX = kernel[1];
1126 const int64_t padLeft = padding[2];
1127 const int64_t padRight = padding[3];
1128 if (padRight >= kernelX || padLeft >= kernelX)
1129 return op->
emitOpError(
"expected left/right padding to be less than the "
1130 "width of the kernel, got pad_left=")
1131 << padLeft <<
", pad_right=" << padRight
1132 <<
", kernel_x=" << kernelX;
1134 const int64_t kernelY = kernel[0];
1135 const int64_t padTop = padding[0];
1136 const int64_t padBottom = padding[1];
1137 if (padTop >= kernelY || padBottom >= kernelY)
1138 return op->
emitOpError(
"expected top/bottom padding to be less than the "
1139 "height of the kernel, got pad_top=")
1140 << padTop <<
", pad_bottom=" << padBottom
1141 <<
", kernel_y=" << kernelY;
1144 const auto inputType = llvm::dyn_cast<RankedTensorType>(input.
getType());
1145 const auto outputType = llvm::dyn_cast<RankedTensorType>(output.
getType());
1146 if (!inputType || !outputType)
1149 if (hasKernel && hasStrides && hasPad) {
1150 const auto verifyOutputSize =
1154 const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
1155 const llvm::StringRef padBeforeName,
1156 const llvm::StringRef padAfterName) -> LogicalResult {
1157 if (ShapedType::isDynamic(inputSize))
1160 const std::optional<int64_t> calculatedOutSizeMinusOne =
1161 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1162 if (!calculatedOutSizeMinusOne.has_value())
1164 << dimName <<
" + pad_" << padBeforeName <<
" + pad_"
1165 << padAfterName <<
" - kernel_" << dimAxis
1166 <<
" to be wholly divisible by stride_" << dimAxis <<
", got ("
1167 << inputSize <<
" + " << padBefore <<
" + " << padAfter <<
" - "
1168 << kernelSize <<
") / " << strideSize;
1170 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1171 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1173 << dimName <<
" did not match expected: " <<
"calculated="
1174 << calculatedOutSize <<
", expected=" << outputSize;
1179 if (failed(verifyOutputSize(inputType.getDimSize(1),
1180 outputType.getDimSize(1), kernel[0], strides[0],
1181 padding[0], padding[1],
"height",
"y",
"top",
1185 if (failed(verifyOutputSize(
1186 inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
1187 strides[1], padding[2], padding[3],
"width",
"x",
"left",
"right")))
1193template <
typename T>
1196 op.getPad(), op.getInput(), op.getOutput());
1199template <
typename T>
1203 const Type inputZpETy =
1205 const Type outputZpETy =
1208 auto accType = op.getAccType();
1209 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1210 return op.emitOpError(
"accumulator type for integer tensor is not i32");
1212 if (inputETy.
isF16() && !(accType.isF16() || accType.isF32()))
1213 return op.emitOpError(
"accumulator type for f16 tensor is not f16/f32");
1215 if (inputETy.
isBF16() && !accType.isF32())
1216 return op.emitOpError(
"accumulator type for bf16 tensor is not f32");
1218 if (inputETy.
isF32() && !accType.isF32())
1219 return op.emitOpError(
"accumulator type for f32 tensor is not f32");
1221 if (inputETy != inputZpETy)
1222 return op.emitOpError(
"expect both input and its zero point are the same "
1223 "element type, got ")
1224 << inputETy <<
" and " << inputZpETy;
1226 if (resultETy != outputZpETy)
1227 return op.emitOpError(
"expect both output and its zero point are the same "
1228 "element type, got ")
1229 << resultETy <<
" and " << outputZpETy;
1231 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1232 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
1235 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1236 if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
1243struct AdaptivePoolingConstShapeValues {
1244 llvm::SmallVector<int64_t> kernel;
1245 llvm::SmallVector<int64_t> stride;
1246 llvm::SmallVector<int64_t> pad;
1250template <
typename T>
1252 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
1253 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
1255template <
typename T,
1256 typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
1259 T op, AdaptivePoolingConstShapeValues &values) {
1265LogicalResult tosa::AvgPool2dOp::verify() {
1273LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
1274 AdaptivePoolingConstShapeValues values;
1283 values.pad, getInput(), getOutput())))
1292LogicalResult tosa::ClampOp::verify() {
1294 llvm::cast<ShapedType>(getInput().
getType()).getElementType();
1295 if (
auto quantType =
1296 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1300 llvm::cast<ShapedType>(getOutput().
getType()).getElementType();
1301 if (
auto quantType =
1302 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1305 if (inputETy != outputETy)
1306 return emitOpError(
"input/output element types are incompatible.");
1308 auto maxValAttr = getMaxValAttr();
1309 auto minValAttr = getMinValAttr();
1313 if (inputETy.
isInteger(dataTypeBitWidth)) {
1317 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1318 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1319 if (!intMaxValAttr || !intMinValAttr ||
1320 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1321 (intMaxValAttr.getType() != inputETy))
1322 return emitOpError(
"min/max attributes types are incompatible with "
1323 "input/output element types.");
1326 const bool isBoolean = inputETy.
isInteger(1);
1327 const APInt minVal = intMinValAttr.getValue();
1328 const APInt maxVal = intMaxValAttr.getValue();
1329 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1330 return emitOpError(
"expected min_val <= max_val, got min_val=")
1331 << minValAttr <<
", max_val=" << maxValAttr;
1336 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1337 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1338 if (!floatMaxValAttr || !floatMinValAttr ||
1339 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1340 (floatMaxValAttr.getType() != inputETy))
1341 return emitOpError(
"min/max attributes types are incompatible with "
1342 "input/output element types.");
1344 const APFloat minVal = floatMinValAttr.getValue();
1345 const APFloat maxVal = floatMaxValAttr.getValue();
1346 if (minVal.isNaN() || maxVal.isNaN())
1347 return emitOpError(
"min/max attributes should not be 'NaN', got min_val=")
1348 << minValAttr <<
", max_val=" << maxValAttr;
1350 if (maxVal < minVal)
1351 return emitOpError(
"expected min_val <= max_val, got min_val=")
1352 << minValAttr <<
", max_val=" << maxValAttr;
1372 result.addOperands({input, weight, bias, zps.first, zps.second});
1373 result.addAttribute(
"pad", pad);
1374 result.addAttribute(
"stride", stride);
1375 result.addAttribute(
"dilation", dilation);
1376 result.addAttribute(
"acc_type", accType);
1377 Type finalOutputType = outputType;
1383 result.addTypes(finalOutputType);
1394 result.addOperands({input, weight, bias, zps.first, zps.second});
1395 result.addAttribute(
"out_pad", outpad);
1396 result.addAttribute(
"stride", stride);
1397 result.addAttribute(
"acc_type", accType);
1398 Type finalOutputType = outputType;
1404 result.addTypes(finalOutputType);
1415 result.addOperands({a,
b, zps.first, zps.second});
1417 Type finalOutputType{outputType};
1420 auto inputBits = eType.getIntOrFloatBitWidth();
1422 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1423 assert(outputShapedType &&
"Output must be a shaped type");
1425 IntegerType accElementType;
1426 if (inputBits == 16)
1431 finalOutputType = outputShapedType.clone(accElementType);
1433 result.addTypes(finalOutputType);
1442 DenseArrayAttr kernel, DenseArrayAttr stride,
1443 DenseArrayAttr pad, TypeAttr accType) {
1448 if (
auto quantAttr =
1450 inputZp = quantAttr.getInputZp();
1451 outputZp = quantAttr.getOutputZp();
1453 const std::optional<Value> inputZpOp =
1458 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1460 const std::optional<Value> outputZpOp =
1463 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1464 "quantized AVG_POOL2D op");
1467 if (inputZpOp && outputZpOp) {
1468 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1473 result.addOperands({input});
1475 result.addAttribute(
"kernel", kernel);
1476 result.addAttribute(
"stride", stride);
1477 result.addAttribute(
"pad", pad);
1478 result.addAttribute(
"acc_type", accType);
1479 result.types.push_back(outputType);
1492 if (
auto quantAttr =
1494 inputZp = quantAttr.getInputZp();
1495 outputZp = quantAttr.getOutputZp();
1497 const std::optional<Value> inputZpOp =
1501 "Failed to create input zero point tensor for quantized "
1502 "AVG_POOL2D_ADAPTIVE op");
1504 const std::optional<Value> outputZpOp =
1507 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1508 "quantized AVG_POOL2D_ADAPTIVE op");
1511 if (inputZpOp && outputZpOp) {
1516 result.addOperands({input, inputZpOp.value(), outputZpOp.value(),
1517 kernelShape, strideShape, padShape});
1522 result.addOperands({input});
1524 result.addAttribute(
"acc_type", accType);
1525 result.types.push_back(outputType);
1539 input1Zp = quantAttr.getInputZp();
1540 outputZp = quantAttr.getOutputZp();
1542 const std::optional<Value> input1ZpOp =
1546 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1549 const std::optional<Value> outputZpOp =
1553 loc,
"Failed to create output zero point for quantized NEGATE op");
1556 if (input1ZpOp && outputZpOp) {
1557 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1562 result.addOperands({input});
1565 result.types.push_back(outputType);
1578 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1581 result.addOperands({input, paddings, padConstOp});
1582 result.types.push_back(outputType);
1586 StringRef name,
Type variableType,
1591 auto shapedType = dyn_cast<ShapedType>(variableType);
1593 (
void)
emitError(loc,
"variable type must be a shaped type");
1596 if (!shapedType.hasRank()) {
1597 (
void)
emitError(loc,
"variable type must be a ranked type");
1601 auto elementType = shapedType.getElementType();
1602 auto elementTypeAttr = TypeAttr::get(elementType);
1606 result.addAttribute(
"sym_name", nameAttr);
1607 result.addAttribute(
"var_shape", varShapeAttr);
1608 result.addAttribute(
"type", elementTypeAttr);
1609 result.addAttribute(
"initial_value", initialValue);
1622 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1626 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1632 for (
int i = 0, e = operands.size(); i != e; ++i) {
1634 if (!
shape.hasRank()) {
1639 outRank = std::max<int64_t>(outRank,
shape.getRank());
1642 outShape.resize(outRank, 1);
1644 for (
int i = 0, e = operands.size(); i != e; ++i) {
1646 auto rankDiff = outShape.size() -
shape.getRank();
1648 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1649 auto dim1 = outShape[i + rankDiff];
1650 auto dim2 =
shape.getDimSize(i);
1652 const FailureOr<int64_t> maybeResolvedDim =
1654 if (failed(maybeResolvedDim))
1656 const int64_t resolvedDim = *maybeResolvedDim;
1657 outShape[i + rankDiff] = resolvedDim;
// Shape inference for tosa.argmax: the result shape is the input shape with
// the reduction `axis` dimension removed (outShape reserves rank-1 entries).
// NOTE(review): lossy extract — interior lines (unranked early-exit body and
// the final push of the inferred shape) are missing from this view.
1664LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1665 MLIRContext *context, ::std::optional<Location> location,
1666 ArgMaxOp::Adaptor adaptor,
1669 IntegerAttr axis = adaptor.getProperties().axis;
1670 int32_t axisVal = axis.getValue().getSExtValue();
1672 if (!inputShape.hasRank()) {
1678 outShape.reserve(inputShape.getRank() - 1);
1679 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
// Copy every dimension except the reduction axis; the `i == axisVal` skip
// presumably sits on a dropped line — TODO confirm against upstream.
1682 outShape.push_back(inputShape.getDimSize(i));
// Shape inference for tosa.rfft2d: output is rank-3 [N, H, W/2 + 1], taking
// batch and height from the real input and halving the width (+1) when the
// input width is static. NOTE(review): lossy extract — the unranked early
// return body and trailing push of the result shape are missing from view.
1689LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1690 MLIRContext *context, ::std::optional<Location> location,
1691 RFFT2dOp::Adaptor adaptor,
1693 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1695 if (!inputShape.hasRank())
1699 outputShape.resize(3, ShapedType::kDynamic);
1700 outputShape[0] = inputShape.getDimSize(0);
1701 outputShape[1] = inputShape.getDimSize(1);
1702 int64_t inWidth = inputShape.getDimSize(2);
// Real FFT keeps only the non-redundant half of the spectrum: W/2 + 1 bins.
1706 if (inWidth != ShapedType::kDynamic)
1707 outputShape[2] = inWidth / 2 + 1;
1716 const llvm::StringRef dimName) {
1717 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1720 << dimName <<
" to be a power of two, got " << dimSize;
// Verifier for tosa.rfft2d. Checks (as visible in this lossy extract):
//  - both results have matching shapes;
//  - static height/width of the ranked real input (power-of-two checks appear
//    to be on dropped lines — TODO confirm);
//  - output batch/height match the input (shape prefix compare via drop_back);
//  - static output width equals input_width / 2 + 1.
1725LogicalResult tosa::RFFT2dOp::verify() {
1726 const auto outputTypes = getResultTypes();
1728 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
1730 const auto inputType =
1731 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1735 const int64_t height = inputType.getDimSize(1);
1736 if (ShapedType::isStatic(height) &&
1740 const int64_t width = inputType.getDimSize(2);
1741 if (ShapedType::isStatic(width) &&
1745 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1751 outputType.getShape().drop_back())))
1752 return emitOpError(
"expected batch and height dimensions of input/output "
1753 "to match, got input=")
1754 << inputType <<
" output=" << outputType;
1757 const int64_t outputWidth = outputType.getDimSize(2);
1758 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1759 (outputWidth != (width / 2) + 1))
1761 "expected output width to be equal to input_width / 2 + 1, got ")
// Shape inference for tosa.fft2d: pushes two inferred result shapes (real and
// imaginary outputs). NOTE(review): lossy extract — the expressions passed to
// the two push_back calls are on dropped lines; presumably each mirrors the
// corresponding input shape — TODO confirm against upstream.
1767LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1768 MLIRContext *context, ::std::optional<Location> location,
1769 FFT2dOp::Adaptor adaptor,
1771 inferredReturnShapes.push_back(
1773 inferredReturnShapes.push_back(
// Verifier for tosa.fft2d: requires both real and imaginary inputs to be
// ranked, then checks static height/width constraints (the checks themselves
// sit on dropped lines in this lossy extract — TODO confirm).
1778LogicalResult tosa::FFT2dOp::verify() {
1779 const auto inputRealType =
1780 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1781 const auto inputImagType =
1782 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1783 if (!inputRealType || !inputImagType)
// Picks a dimension to validate from the real/imag pair.
// NOTE(review): as extracted this reads `isDynamic(a) ? a : b`, which returns
// the dynamic dim when `a` is dynamic and ignores `a` when static — that
// looks inverted for a "try select static dim" helper; verify against the
// original (possibly `isDynamic(a) ? b : a`).
1786 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1787 return ShapedType::isDynamic(a) ? a :
b;
1790 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1791 inputImagType.getDimSize(1));
1792 if (ShapedType::isStatic(height) &&
1796 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1797 inputImagType.getDimSize(2));
1798 if (ShapedType::isStatic(width) &&
// Shape inference for tosa.concat. Two passes over the operands:
//  1. Establish the output rank and all non-axis dimensions from the first
//     ranked operand, cross-checking later ranked operands (mismatched
//     non-axis static sizes are an error).
//  2. Sum the `axis` dimension across operands; any unranked operand or
//     dynamic axis dim makes the whole axis dimension dynamic.
// NOTE(review): lossy extract — several `continue;`/error-return lines and
// the final push of the inferred shape are missing from this view.
1805LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1806 MLIRContext *context, ::std::optional<Location> location,
1807 ConcatOp::Adaptor adaptor,
1810 const Properties &prop = adaptor.getProperties();
1811 int32_t axis = prop.axis.getValue().getSExtValue();
1813 bool hasRankedInput =
false;
1814 for (
auto operand : adaptor.getOperands()) {
1816 if (!operandShape.hasRank())
// First ranked operand fixes the output rank; start fully dynamic.
1820 if (!hasRankedInput)
1821 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1824 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
1825 if (i == axis || operandShape.isDynamicDim(i))
1827 if (outputShape[i] == ShapedType::kDynamic)
1828 outputShape[i] = operandShape.getDimSize(i);
1829 if (outputShape[i] != operandShape.getDimSize(i))
1831 "Cannot concat tensors with different sizes"
1832 " on the non-axis dimension ",
1836 hasRankedInput =
true;
1839 if (adaptor.getInput1().empty())
1843 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1844 if (!hasRankedInput) {
// Second pass: accumulate the concatenation dimension.
1851 for (
auto operand : adaptor.getOperands()) {
1856 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1857 concatDimSize = ShapedType::kDynamic;
1861 concatDimSize += operandShape.getDimSize(axis);
1864 outputShape[axis] = concatDimSize;
// Verifier for tosa.concat. Visible checks in this lossy extract:
//  - non-empty input list; element types compatible with the output;
//  - axis within [0, rank) of the first ranked input;
//  - when all operands are ranked: equal ranks, and equal static sizes on all
//    non-axis dimensions;
//  - output rank matches input rank;
//  - sum of static axis dims equals the static output axis dim.
1870LogicalResult tosa::ConcatOp::verify() {
1872 auto outType = getOutput().getType();
1876 if (inputList.empty())
1879 if (!llvm::all_of(inputList, [&](
auto input) {
1881 *
this, input.getType(), outType));
1886 const int32_t axis = getAxis();
// Find the first ranked input; it anchors all rank/size comparisons below.
1888 for (
const auto &input : inputList) {
1889 const Type inputType = input.getType();
1891 if (currShape.hasRank()) {
1892 firstRankedInputShape = currShape;
1894 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1895 return emitOpError(
"expect axis to be within range 0 < axis < "
1896 "rank(input1[firstRankedTensorIdx]), got ")
1902 const auto allOperandsHasRank = [](
const Value input) {
1905 if (llvm::all_of(inputList, allOperandsHasRank)) {
1908 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1910 const int64_t inputRank = inputShape.getRank();
1911 const size_t operandNum =
index + 1;
1914 if (inputRank != firstInputRank)
1916 "expect all operands to have the same rank, but got ")
1917 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1921 for (
int i = 0; i < inputRank; i++) {
1922 const int64_t inputDim = inputShape.getDimSize(i);
// Dynamic dims and the concat axis itself are exempt from size equality.
1924 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1925 inputShape.isDynamicDim(i))
1927 if (inputDim != firstInputDim)
1928 return emitOpError(
"expect all operand shapes to have the same sizes "
1929 "on non-axis dimensions, but got ")
1930 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1931 <<
" on operands 0 and " << operandNum;
1936 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1937 return emitOpError(
"expect output rank to match inputs rank, got ")
1938 << outputShape.getRank() <<
" vs " << firstInputRank;
// Accumulate static axis sizes; a dynamic axis dim disables the sum check
// (handling presumably on dropped lines — TODO confirm).
1942 for (
const auto &input : inputList) {
1944 if (inputShape.isDynamicDim(axis)) {
1949 axisSum += inputShape.getDimSize(axis);
1952 if (axisSum >= 0 && outputShape.hasRank() &&
1953 !outputShape.isDynamicDim(axis) &&
1954 axisSum != outputShape.getDimSize(axis))
1955 return emitOpError(
"requires sum of axis dimensions of input1 "
1956 "equal to output axis dimension, got ")
1957 << axisSum <<
" and " << outputShape.getDimSize(axis);
// Shape inference for tosa.equal: the result element type is always i1
// (boolean); the shape logic is on dropped lines in this lossy extract.
1963LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1964 MLIRContext *context, ::std::optional<Location> location,
1968 auto elementType = IntegerType::get(context, 1);
1981 if (l.size() != r.size() || l.size() != 1)
// Shape inference for tosa.matmul: output is rank-3 [N, H, W] — batch and
// height from the LHS, width from the RHS; the RHS batch fills in N when the
// LHS left it dynamic.
1986LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1987 MLIRContext *context, ::std::optional<Location> location,
1988 MatMulOp::Adaptor adaptor,
1995 outShape.resize(3, ShapedType::kDynamic);
1997 if (lhsShape.hasRank()) {
1998 outShape[0] = lhsShape.getDimSize(0);
1999 outShape[1] = lhsShape.getDimSize(1);
2002 if (rhsShape.hasRank()) {
2003 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
2005 outShape[2] = rhsShape.getDimSize(2);
// Verifier for tosa.matmul. Visible checks in this lossy extract:
//  - operands are either both quantized or both not, with equal storage
//    widths when quantized;
//  - a/b element types match their zero-point element types;
//  - a_zp/b_zp values pass the per-operand zero-point verifiers;
//  - the [N,H,W] shape implied by the operands is compatible with the output.
2012LogicalResult MatMulOp::verify() {
2015 const Type aElementType = aShape.getElementType();
2016 const Type bElementType = bShape.getElementType();
2018 const auto aQuantizedEType =
2019 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
2020 const auto bQuantizedEType =
2021 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
2023 if (aQuantizedEType || bQuantizedEType) {
2024 if (!aQuantizedEType || !bQuantizedEType) {
2025 return emitOpError(
"expect operands to be both quantized or both not "
2027 << aElementType <<
" and " << bElementType;
2030 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
2031 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
2032 if (aQuantWidth != bQuantWidth) {
2033 return emitOpError(
"expect quantized operands to have same widths, got ")
2034 << aQuantWidth <<
" and " << bQuantWidth;
2041 if (aEType != aZpEType)
2042 return emitOpError(
"expect input a and a_zp have the same "
2043 "element type, got ")
2044 << aEType <<
" and " << aZpEType;
2048 if (bEType != bZpEType)
2049 return emitOpError(
"expect input b and b_zp have the same "
2050 "element type, got ")
2051 << bEType <<
" and " << bZpEType;
// Zero-point value constraints (e.g. must be zero for non-int8) are enforced
// by the generated verify{A,B}ZeroPoint helpers.
2053 FailureOr<int64_t> maybeAZp = getAZeroPoint();
2054 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
2057 FailureOr<int64_t> maybeBZp = getBZeroPoint();
2058 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
2062 int64_t N = ShapedType::kDynamic;
2063 int64_t H = ShapedType::kDynamic;
2067 if (aShape.hasRank()) {
2068 N = aShape.getDimSize(0);
2069 H = aShape.getDimSize(1);
2070 C = aShape.getDimSize(2);
2073 if (bShape.hasRank()) {
2079 W = bShape.getDimSize(2);
2083 const auto outputType = cast<ShapedType>(getResult().
getType());
2084 if (outputType.hasRank() &&
2089 opError <<
" to be compatible with expected output shape ";
// Shape inference for tosa.matmul_t_block_scaled: output [N, H, W] is
// assembled from a_data/a_scale (batch, height) and b_data/b_scale (width),
// with each source only filling in dimensions still dynamic. A b-side batch
// of 1 is broadcastable and does not pin the output batch.
2097LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
2098 MLIRContext *context, ::std::optional<Location> location,
2099 MatmulTBlockScaledOp::Adaptor adaptor,
2103 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
2104 if (aDataShape.hasRank()) {
2105 outShape[0] = aDataShape.getDimSize(0);
2106 outShape[1] = aDataShape.getDimSize(1);
2109 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
2110 if (aScaleShape.hasRank()) {
2111 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
2113 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
2118 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
2119 if (bDataShape.hasRank()) {
2120 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
2121 if (bDataBatchSize != 1)
2123 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
2124 outShape[2] = bDataShape.getDimSize(1);
2127 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
2128 if (bScaleShape.hasRank()) {
2129 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
2130 if (bScaleBatchSize != 1)
2132 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
2133 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
// Verifier for tosa.matmul_t_block_scaled. Visible checks in this lossy
// extract: per-operand dimension agreement (batch/height/channels/width via
// the string-tagged checks), B batch broadcast-compatibility with N, block
// size must be 32, C divisible by block size, scale dim-2 == C/block_size,
// and final output-shape compatibility.
2141LogicalResult MatmulTBlockScaledOp::verify() {
2143 const Type aDataType = getAData().getType();
2144 const Type bDataType = getBData().getType();
2150 int64_t N = ShapedType::kDynamic;
2151 int64_t D = ShapedType::kDynamic;
2152 int64_t H = ShapedType::kDynamic;
2155 int64_t multiplesOfC = ShapedType::kDynamic;
2167 "a_scale",
"batch")) ||
2169 "a_scale",
"height")))
2177 "b_data",
"batch")) ||
2179 "b_data",
"channels")))
2187 "b_scale",
"batch")) ||
2189 "b_scale",
"width")) ||
// D is the B-side batch; it may be 1 (broadcast) or equal to N.
2197 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2198 return emitOpError(
"expect B matrix batch size to be broadcast compatible "
2200 << D <<
" vs N=" << N;
2203 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2204 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2205 return emitOpError(
"expect block size to be 32, got ") << blockSize;
2206 if (ShapedType::isStatic(C) && C % blockSize != 0)
2207 return emitOpError(
"expect C to be a multiple of block size, got C=")
2208 <<
C <<
", block_size=" << blockSize;
2211 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2212 multiplesOfC != C / blockSize)
2214 "expect scale operands dimension 2 to equal C/block_size (")
2215 <<
C <<
"/" << blockSize <<
")" <<
", got " << multiplesOfC;
// Resolve the effective batch before comparing against the output type.
2218 N = ShapedType::isDynamic(N) ? D : N;
2220 const auto outputType = cast<ShapedType>(getResult().
getType());
2221 if (outputType.hasRank() &&
2226 opError <<
" to be compatible with expected output shape ";
// Shape inference for tosa.pad. If the input is unranked, the rank is derived
// from the padding shape operand (paddingRank / 2, one front/back pair per
// dim). With constant padding values, each static input dim grows by
// padFront + padBack; negative pad entries (dynamic markers) yield a dynamic
// output dim.
2234LogicalResult tosa::PadOp::inferReturnTypeComponents(
2235 MLIRContext *context, ::std::optional<Location> location,
2236 PadOp::Adaptor adaptor,
2238 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2240 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2245 if (!inputShape.hasRank()) {
2246 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
// Padding values unavailable (non-constant): all dims become dynamic.
2255 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2260 outputShape.reserve(inputShape.getRank());
2261 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2262 if (inputShape.isDynamicDim(i)) {
2263 outputShape.push_back(ShapedType::kDynamic);
2266 auto padFront = paddingValues[i * 2];
2267 auto padBack = paddingValues[i * 2 + 1];
2268 if (padFront < 0 || padBack < 0) {
2270 outputShape.push_back(ShapedType::kDynamic);
2274 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
// Verifier for tosa.pad. Visible checks in this lossy extract:
//  - padding tensor has exactly inputRank * 2 elements;
//  - each pad value is >= 0 or exactly -1 (dynamic-padding marker);
//  - for static dims: output[i] == input[i] + padStart + padEnd.
2281LogicalResult tosa::PadOp::verify() {
2288 if (
auto padConst = getPadConst()) {
2296 RankedTensorType inputType =
2297 llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2298 RankedTensorType outputType =
2299 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
2300 if (!inputType || !outputType)
2307 auto inputRank = inputType.getRank();
2312 auto paddingValues = paddingAttr.getValues<APInt>();
2313 if (paddingValues.size() !=
static_cast<size_t>(inputRank * 2))
2314 return emitOpError() <<
"padding tensor must have " << inputRank
2315 <<
" * 2 = " << inputRank * 2 <<
" elements, but got "
2316 << paddingValues.size();
2318 auto inputShape = inputType.getShape();
2319 auto outputShape = outputType.getShape();
2321 for (
int64_t i = 0; i < inputRank; ++i) {
2322 int64_t padStart = paddingValues[i * 2].getSExtValue();
2323 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
// -1 is the only permitted negative value (dynamic padding sentinel).
2325 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2327 <<
"invalid padding values at dimension " << i
2328 <<
": values must be non-negative or -1 for dynamic padding, got ["
2329 << padStart <<
", " << padEnd <<
"]";
// Dynamic dims cannot be arithmetic-checked; skip them.
2333 if (inputShape[i] == ShapedType::kDynamic ||
2334 outputShape[i] == ShapedType::kDynamic)
2337 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2338 return emitOpError() <<
"mismatch in output shape at dimension " << i
2339 <<
": expected " << inputShape[i] <<
" + "
2340 << padStart <<
" + " << padEnd <<
" = "
2341 << (inputShape[i] + padStart + padEnd)
2342 <<
", but got " << outputShape[i];
// Shape inference for tosa.slice. For each dimension with valid start/size
// values: size == -1 means "to the end" (inputDim - start); otherwise the
// output dim is `size` when it fits inside the (possibly dynamic) input dim.
2349LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2350 MLIRContext *context, ::std::optional<Location> location,
2351 SliceOp::Adaptor adaptor,
2360 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2368 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2371 if (inputShape.hasRank()) {
2372 for (
size_t i = 0; i < size.size(); i++) {
2373 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2374 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2375 start[i] < inputShape.getDimSize(i))) {
// Dynamic input dim: only a concrete (non -1) size can be propagated.
2377 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2380 outputShape[i] = size[i];
2384 if (size[i] == -1) {
2385 outputShape[i] = inputShape.getDimSize(i) - start[i];
2386 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2388 outputShape[i] = size[i];
// Verifier for tosa.slice. Visible checks in this lossy extract:
//  - input/output ranks match; start/size shape operands have the input rank;
//  - all start values are non-negative; all size values are > 0 (or the -1
//    sentinel, presumably filtered inside the dropped lambda body);
//  - a fully-specified size vector matches the static output shape;
//  - start + size stays within each static input dimension.
2400LogicalResult tosa::SliceOp::verify() {
2401 const Value input = getInput1();
2402 const Value output = getOutput();
2408 const Value start = getStart();
2409 const Value size = getSize();
2413 if (inputShape.hasRank()) {
2414 const auto inputRank = inputShape.getRank();
2415 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2417 "expect input1 and output to have the same ranks, got ")
2418 << inputRank <<
" and " << outputShape.getRank();
2420 const auto startShapeRank =
2421 llvm::cast<tosa::shapeType>(start.
getType()).getRank();
2422 if (inputRank != startShapeRank)
2423 return emitOpError(
"length of start is not equal to rank of input shape");
2425 const auto sizeShapeRank =
2426 llvm::cast<tosa::shapeType>(size.
getType()).getRank();
2427 if (inputRank != sizeShapeRank)
2428 return emitOpError(
"length of size is not equal to rank of input shape");
2433 if (startValues.size()) {
2434 if (llvm::any_of(startValues, [](
const int64_t v) {
2437 return emitOpError(
"start values must be non-negative, got [")
2438 << startValues <<
"]";
2445 if (llvm::any_of(sizeValues, [](
const int64_t v) {
2448 return emitOpError(
"size values must be > 0, got [") << sizeValues <<
"]";
2449 if (outputShape.hasRank()) {
2451 outputShape.getDims(outputDims);
2452 const bool hasNoInferableDims = llvm::all_of(
2454 if (hasNoInferableDims &&
2456 return emitOpError(
"expected output shape to match size values, got ")
2457 << output.
getType() <<
" vs [" << sizeValues <<
"]";
2460 if (inputShape.hasRank() && startValues.size()) {
2462 inputShape.getDims(inputDims);
2463 for (
const auto &[
index, vals] :
2464 llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
2465 const auto &[start, size, inputDim] = vals;
// Dynamic dims (and presumably -1 sizes, on a dropped line) are skipped.
2467 ShapedType::isDynamic(inputDim))
2469 if (start + size > inputDim)
2470 return emitOpError(
"start + size must be less than or equal to input "
2471 "dimension size, got start=")
2472 << start <<
", size=" << size
2473 <<
" vs input dim size=" << inputDim <<
" at dimension "
// Shape inference for tosa.mul. NOTE(review): only the signature survives in
// this lossy extract; the body (original lines 2483-2495) was dropped.
2481LogicalResult tosa::MulOp::inferReturnTypeComponents(
2482 MLIRContext *context, ::std::optional<Location> location,
// Verifier for tosa.mul. Visible checks in this lossy extract:
//  - integer results: both operands share one integer element type no wider
//    than the result;
//  - float results: shift must be 0 (shift only applies to integer mul);
//  - ranked a/b must have matching ranks and broadcast-compatible shapes;
//  - ranked result must match the rank of each ranked operand.
2497LogicalResult tosa::MulOp::verify() {
2498 const Value output = getOutput();
2503 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2504 IntegerType lhsIntType =
2506 IntegerType rhsIntType =
2508 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2509 return emitOpError(
"requires the same element type for all operands");
// Widening from operand to result is allowed, narrowing is not.
2514 if (lhsIntType.getWidth() > resIntType.getWidth())
2515 return emitOpError(
"invalid data type size for operands or result");
2520 for (
int i = 0; i < 2; ++i) {
2523 "requires the same element type for all operands and results");
2527 ElementsAttr shiftElem;
2529 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2531 return emitOpError() <<
"require shift to be 0 for float type";
2539 TypeRange operandTypes = getOperandTypes();
2540 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2541 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2543 const bool aHasRank = aType.hasRank();
2544 const bool bHasRank = bType.hasRank();
2545 if (aHasRank && bHasRank) {
2546 const int64_t aRank = aType.getRank();
2547 const int64_t bRank = bType.getRank();
2549 return emitOpError(
"a and b operands don't have matching ranks, got ")
2550 << aRank <<
" and " << bRank;
2555 aType.getShape(), bType.getShape(), resultShape))
2556 return emitOpError(
"a and b operands don't have broadcast-compatible "
2558 << aType <<
" and " << bType;
2561 ShapedType resultType = cast<ShapedType>(output.
getType());
2562 if (!resultType.hasRank())
2565 const int64_t resultRank = resultType.getRank();
2566 if (aHasRank && resultRank != aType.getRank())
2567 return emitOpError(
"result type has different rank than a, got ")
2568 << resultRank <<
" vs " << aType.getRank();
2569 if (bHasRank && resultRank != bType.getRank())
2570 return emitOpError(
"result type has different rank than b, got ")
2571 << resultRank <<
" vs " << bType.getRank();
// Shape inference for tosa.table: the output shape mirrors the input shape
// exactly (element-wise lookup); unranked input is handled by the dropped
// early-exit body in this lossy extract.
2576LogicalResult tosa::TableOp::inferReturnTypeComponents(
2577 MLIRContext *context, ::std::optional<Location> location,
2578 TableOp::Adaptor adaptor,
2580 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2582 if (!inputShape.hasRank()) {
2587 inferredReturnShapes.resize(1);
2588 inputShape.getDims(inferredReturnShapes[0]);
// Verifier for tosa.table: every static output dimension must equal the
// corresponding input dimension (table lookup is shape-preserving).
2592LogicalResult tosa::TableOp::verify() {
2593 const TensorType inputType = getInput1().getType();
2594 const TensorType outputType = getOutput().getType();
2603 auto inputDims = inputType.
getShape();
2604 auto outputDims = outputType.
getShape();
2605 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2607 auto [inputDim, outputDim] = it.value();
2608 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2609 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2610 <<
" doesn't match dim(input, " << dim
2611 <<
") = " << inputDim;
2624 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2625 [](
const APInt &val) { return val.getSExtValue(); });
// Shape inference for tosa.tile: each static input dim is multiplied by the
// corresponding multiple; dynamic multiples or dims stay dynamic. An unranked
// input yields a fully-dynamic shape of length multiples.size().
2629LogicalResult tosa::TileOp::inferReturnTypeComponents(
2630 MLIRContext *context, ::std::optional<Location> location,
2631 TileOp::Adaptor adaptor,
2638 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2645 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2647 if (!inputShape.hasRank()) {
2648 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2649 inferredReturnShapes.push_back(
// Rank of input must agree with the number of multiples to infer anything.
2653 if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2657 outputShape.reserve(multiples.size());
2658 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2659 if (multiples[i] == ShapedType::kDynamic) {
2660 outputShape.push_back(ShapedType::kDynamic);
2662 int64_t dim = inputShape.getDimSize(i);
2663 if (dim != ShapedType::kDynamic)
2664 dim *= multiples[i];
2665 outputShape.push_back(dim);
// Verifier for tosa.tile: 'multiples' must have the input rank (or the output
// rank when the input is unranked), and each constant multiple must be a
// positive integer or the -1 dynamic sentinel.
2673LogicalResult tosa::TileOp::verify() {
2679 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2680 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2682 shapeType multiplesType =
2683 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2685 auto multiplesRank = multiplesType.getRank();
2687 if (inputType.hasRank()) {
2688 if (inputType.getRank() != multiplesRank)
2689 return emitOpError(
"expect 'multiples' to have rank ")
2690 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2691 if (outputType.hasRank() &&
2695 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2696 return emitOpError(
"expect 'multiples' array to have length ")
2697 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
2700 if (getConstantMultiples(multiples).succeeded() &&
2701 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2703 "expect element of 'multiples' to be positive integer or -1.");
2709 if (l.size() != r.size() || l.size() != 1)
// Shape inference for tosa.reshape: with a fully-static input, a single
// dynamic (-1) entry in the new shape is resolved as
// numElements / product(static entries); otherwise the new shape is reported
// as-is (fully dynamic when the input shape is unknown).
2714LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2715 MLIRContext *context, ::std::optional<Location> location,
2716 ReshapeOp::Adaptor adaptor,
2718 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2723 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2732 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2733 inferredReturnShapes.push_back(
// Static input: accumulate the product of the known new-shape entries …
2741 int64_t numElements = inputShape.getNumElements();
2743 for (
auto val : newShapeValue) {
2744 if (ShapedType::isStatic(val)) {
// … then solve the single dynamic entry so total elements are preserved.
2750 for (
auto &val : newShapeValue) {
2751 if (ShapedType::isDynamic(val))
2752 val = numElements / staticMul;
2755 inferredReturnShapes.push_back(
// Verifier for tosa.reshape. Visible checks in this lossy extract:
//  - at most one -1 (inferred) entry in the target shape;
//  - the new shape length matches the ranked result's rank and agrees with
//    every static result dimension;
//  - with static input/output: element counts match exactly, and the product
//    of the (static) new-shape entries cannot exceed the input element count.
2760llvm::LogicalResult tosa::ReshapeOp::verify() {
2766 TensorType inputType = getInput1().getType();
2771 return mlir::success();
2775 if (missingDims > 1)
2776 return emitOpError() <<
"expected at most one target dimension to be "
2779 const auto outputType = dyn_cast<RankedTensorType>(
getType());
2783 if ((
int64_t)shapeValues.size() != outputType.getRank())
2784 return emitOpError() <<
"new shape does not match result rank";
2786 for (
auto [newShapeDim, outputShapeDim] :
2787 zip(shapeValues, outputType.getShape())) {
2789 newShapeDim != ShapedType::kDynamic &&
2790 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2791 return emitOpError() <<
"new shape is inconsistent with result shape";
2794 return emitOpError() <<
"new shape has invalid tensor dimension size "
2798 if (inputType.hasStaticShape()) {
2799 int64_t inputElementsNum = inputType.getNumElements();
2800 if (outputType.hasStaticShape()) {
2801 int64_t outputElementsNum = outputType.getNumElements();
2802 if (inputElementsNum != outputElementsNum) {
2803 return emitOpError() <<
"cannot reshape " << inputElementsNum
2804 <<
" elements into " << outputElementsNum;
// Fold ignores non-positive (dynamic) entries when accumulating the product.
2810 return (dim > 0) ?
acc * dim :
acc;
2812 bool isStaticNewShape =
2813 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
2814 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2815 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2816 return emitOpError() <<
"cannot reshape " << inputElementsNum
2817 <<
" elements into " << newShapeElementsNum;
2821 return mlir::success();
// Return-type compatibility for tosa.reshape_block_scaled: the op produces
// one or two results (data, and optionally scale), so l and r must have equal
// size in [1, 2]. Remaining comparison logic is on dropped lines.
2824bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(
TypeRange l,
2826 if (l.size() != r.size() || l.size() < 1 || l.size() > 2)
// Shape inference for tosa.reshape_block_scaled: the first result takes the
// requested new shape; with two inputs, the scale result takes the same shape
// but with the last dimension divided by the block size (dynamic entries are
// re-derived from the resolved data shape).
2834LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
2835 MLIRContext *context, ::std::optional<Location> location,
2836 ReshapeBlockScaledOp::Adaptor adaptor,
2839 const auto numInputs = adaptor.getInput().size();
2840 ShapeAdaptor inputShape(adaptor.getInput()[0].getType());
2843 const auto newShape = adaptor.getNewValueShape();
2845 auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
2854 const uint32_t blockSize =
2855 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
2858 if (numInputs == 2) {
2859 newScaleShapeValue.assign(newShapeValue.begin(), newShapeValue.end());
2860 if (ShapedType::isStatic(newScaleShapeValue.back()))
2861 newScaleShapeValue.back() /= blockSize;
2864 inferredReturnShapes.push_back(
2866 if (numInputs == 2) {
// Fill scale-shape holes from the (now resolved) data shape, re-applying
// the block-size division on the innermost dimension.
2868 for (
size_t idx = 0; idx < newShapeValue.size(); idx++) {
2869 if (ShapedType::isDynamic(newScaleShapeValue[idx])) {
2870 newScaleShapeValue[idx] = newShapeValue[idx];
2871 if (idx == (newShapeValue.size() - 1))
2872 newScaleShapeValue[idx] /= blockSize;
// Verifier for tosa.reshape_block_scaled. Visible checks in this lossy
// extract:
//  - one or two inputs, with result count equal to input count;
//  - two-input (scaled) form: block size must be 32, no rank-0 operands or
//    results, data/scale ranks match, data/scale agree on all leading static
//    dims, and the innermost data dim is divisible by block_size with the
//    scale's innermost dim equal to data_innermost / block_size;
//  - single-input form: block size must be 1;
//  - the new shape matches the (ranked) result rank/static dims, with the
//    scale result checked against new_shape with its last dim / block_size;
//  - static input/output element counts must be consistent (same product
//    rules as tosa.reshape).
2883llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
2887 if (inputList.size() == 0)
2888 return emitOpError(
"requires at least one input");
2890 if (inputList.size() > 2)
2891 return emitOpError(
"requires at most two inputs");
2893 if (inputList.size() != outputList.size())
2894 return emitOpError(
"requires number of results to match inputs");
2902 const auto inputType = llvm::cast<ShapedType>(inputList[0].
getType());
2903 if (!inputType.hasRank())
2905 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2907 if (inputList.size() == 2) {
2908 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2909 return emitOpError(
"expect block size to be 32, got ") << blockSize;
2910 if (llvm::any_of(inputList, [](
Value v) {
2911 const auto input = cast<ShapedType>(v.
getType());
2912 return input.hasRank() && input.getRank() == 0;
2915 "requires all input shapes have a rank greater than 0");
2916 if (llvm::any_of(outputList, [](
Value v) {
2917 const auto output = cast<ShapedType>(v.
getType());
2918 return output.hasRank() && output.getRank() == 0;
2921 "requires all result shapes have a rank greater than 0");
2929 const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].
getType());
2930 if (inputScaleType.hasRank()) {
2931 if (inputType.getRank() != inputScaleType.getRank())
2932 return emitOpError(
"input shapes do not have same rank");
// All dims except the innermost must agree between data and scale.
2935 for (
auto dimIdx = 0; dimIdx < inputType.getRank() - 1; dimIdx++) {
2936 const int64_t inputValueDim = inputType.getDimSize(dimIdx);
2937 const int64_t inputScaleDim = inputScaleType.getShape()[dimIdx];
2938 if (ShapedType::isStatic(inputValueDim) &&
2939 ShapedType::isStatic(inputScaleDim) &&
2940 inputValueDim != inputScaleDim)
2941 return emitOpError(
"input shapes for data and scale do not match on "
2948 inputType.getDimSize(inputType.getRank() - 1);
2949 if (ShapedType::isStatic(lastValueDim)) {
2950 if (lastValueDim % blockSize != 0)
2951 return emitOpError(
"expect last dimension of input_data (")
2952 << lastValueDim <<
") to be divisible by block_size ("
2953 << blockSize <<
")";
2956 inputScaleType.getDimSize(inputScaleType.getRank() - 1);
2958 if (ShapedType::isStatic(lastScaleDim) &&
2959 lastScaleDim != lastValueDim / blockSize)
2960 return emitOpError(
"expect last dimension of scale_data (")
2961 << lastScaleDim <<
") to be " << lastValueDim <<
"/"
// Single-input form has no scale operand, so only block size 1 is legal.
2966 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
2967 return emitOpError(
"expect block size to be 1, got ") << blockSize;
2975 return mlir::success();
2978 if (inputList.size() == 2) {
2979 if (
static_cast<int64_t>(shapeValues.size()) == 0)
2980 return emitOpError(
"requires new shape to have a rank greater than 0");
2982 const int64_t lastShapeDim = shapeValues.back();
2983 if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
2984 return emitOpError(
"expect last dimension of new shape (")
2985 << lastShapeDim <<
") to be divisible by block_size (" << blockSize
2989 const auto outputType = llvm::cast<ShapedType>(outputList[0].
getType());
2990 if (!outputType.hasRank())
2993 if (
static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
2994 return emitOpError() <<
"result does not match new shape rank";
2996 for (
auto [newShapeDim, outputShapeDim] :
2997 zip(shapeValues, outputType.getShape())) {
2998 if (ShapedType::isStatic(newShapeDim) &&
2999 ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
3000 return emitOpError() <<
"result shape is inconsistent with new shape";
3003 if (outputList.size() == 2) {
// Expected scale-result shape: new shape with last dim / block_size.
3007 scaleShapeValues.back() /= blockSize;
3009 const auto outputScaleType =
3010 llvm::cast<ShapedType>(outputList[1].
getType());
3011 if (outputScaleType.hasRank()) {
3012 if ((
int64_t)scaleShapeValues.size() != outputScaleType.getRank())
3013 return emitOpError() <<
"result scale does not match new shape rank";
3015 for (
auto [newScaleShapeDim, outputScaleShapeDim] :
3016 zip(scaleShapeValues, outputScaleType.getShape())) {
3017 if (ShapedType::isStatic(newScaleShapeDim) &&
3018 ShapedType::isStatic(outputScaleShapeDim) &&
3019 newScaleShapeDim != outputScaleShapeDim)
3021 <<
"result scale shape is inconsistent with new shape";
3026 if (inputType.hasStaticShape()) {
3027 int64_t inputElementsNum = inputType.getNumElements();
3028 if (outputType.hasStaticShape()) {
3029 int64_t outputElementsNum = outputType.getNumElements();
3030 if (inputElementsNum != outputElementsNum) {
3031 return emitOpError() <<
"cannot reshape " << inputElementsNum
3032 <<
" elements into " << outputElementsNum;
3038 return (dim > 0) ?
acc * dim :
acc;
3040 bool isStaticNewShape =
3041 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
3042 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
3043 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
3044 return emitOpError() <<
"cannot reshape " << inputElementsNum
3045 <<
" elements into " << newShapeElementsNum;
3049 return mlir::success();
3056 ElementsAttr zpAttr;
3061 Type zpElemType = zpAttr.getElementType();
3063 if (llvm::isa<FloatType>(zpElemType)) {
3064 if (zpAttr.getValues<APFloat>()[0].isZero()) {
3071 if (llvm::isa<IntegerType>(zpElemType)) {
3073 return zpAttr.getValues<APInt>()[0].getSExtValue();
3074 return zpAttr.getValues<APInt>()[0].getZExtValue();
3086 if (!llvm::isa<IntegerType>(attr.getElementType()) ||
3087 attr.getNumElements() != 1)
3090 return attr.getValues<APInt>()[0].getSExtValue();
// Generic zero-point check shared by ops via ZERO_POINT_HELPER: a non-zero
// zero point is only permitted for int8 storage; the operand name is
// lower-cased for the diagnostic. NOTE(review): lossy extract — the zp
// element-type lookup preceding the check is on dropped lines.
3093template <
typename T>
3095 const std::string &operand) {
3098 if (!zpElemType.
isInteger(8) && zp != 0) {
3100 std::string lower = operand;
3101 llvm::transform(lower, lower.begin(), ::tolower);
3102 return op.emitOpError()
3103 << lower <<
" zero point must be zero for non-int8 integer types";
3111 const std::string &operand) {
3112 bool isInputZp = (operand ==
"Input");
3114 bool tensorUnsigned =
3115 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
3116 StringRef tensorName = isInputZp ?
"input" :
"output";
3122 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
3123 return op.emitOpError()
3124 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
3126 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
3127 return op.emitOpError() <<
"expect " << tensorName
3128 <<
"_zp of 0 or 32768 for unsigned int16 "
3129 << tensorName <<
", got " << zp;
// Stamps out the per-operand zero-point accessors for an op: a
// get<Operand>ZeroPoint() that extracts the constant zp value (with the given
// sign-extension behavior) and a verify<Operand>ZeroPoint(zp) that applies
// the shared verifyZeroPoint check. No comments may appear inside the macro
// body: each line must end in the '\' continuation. The instantiation list
// (original lines 3142-3161) is missing from this lossy extract.
3136#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
3137 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
3138 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
3140 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
3141 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
3162#undef ZERO_POINT_HELPER
// Shape inference for tosa.transpose: output dim i is input dim perms[i].
// Special cases: perms size must equal the input rank; rank-0 passes through;
// if every input dim is equal the permutation is irrelevant and the shape is
// copied directly; out-of-range perm entries abort inference.
3164LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
3165 MLIRContext *context, ::std::optional<Location> location,
3166 TransposeOp::Adaptor adaptor,
3168 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3177 const auto inputRank = inputShape.
getRank();
3181 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
3187 if (inputRank == 0) {
// Check whether all dims are identical — permuting then changes nothing.
3193 bool allTheSame =
true;
3194 for (
int i = 1, s = inputRank; i < s; i++) {
3204 outputShape.resize(inputRank, inputShape.
getDimSize(0));
3209 outputShape.resize(inputRank, ShapedType::kDynamic);
3212 if (llvm::any_of(adaptor.getPerms(),
3213 [inputRank](
const auto i) { return i >= inputRank; }))
3216 outputShape.reserve(inputRank);
3217 for (
int i = 0, s = inputRank; i < s; i++) {
3218 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
// Verifier for tosa.transpose: perms length must match input/output rank,
// perms must be a valid permutation, element counts must agree, and each
// static output dim must equal the permuted input dim.
// NOTE(review): extraction-garbled listing; code left byte-identical.
3225LogicalResult tosa::TransposeOp::verify() {
3237 if (inputShape.hasRank() &&
3238 constantPerms.size() !=
static_cast<size_t>(inputShape.getRank()))
3239 return emitOpError() <<
"expected perms attribute to have size "
3240 << inputShape.getRank()
3241 <<
" (input rank) but got size "
3242 << constantPerms.size();
3244 if (inputShape.hasRank() && outputShape.hasRank() &&
3245 inputShape.getRank() != outputShape.getRank())
3247 <<
"expected input tensor rank to equal result tensor rank";
3249 if (outputShape.hasRank() &&
3250 constantPerms.size() !=
static_cast<size_t>(outputShape.getRank()))
3251 return emitOpError() <<
"expected perms attribute to have size "
3252 << outputShape.getRank()
3253 <<
" (output rank) but got size "
3254 << constantPerms.size();
// perms must contain each index [0, size) exactly once (valid permutation).
3256 if (!llvm::all_of(constantPerms,
3257 [&constantPerms](int32_t s) {
3259 static_cast<size_t>(s) < constantPerms.size();
3262 constantPerms, [](int32_t v) ->
int64_t {
return v; })))
3263 return emitOpError() <<
"expected valid permutation indices";
// A transpose never changes the number of elements.
3266 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
3267 inputShape.getNumElements() != outputShape.getNumElements())
3268 return emitOpError() <<
"expected input1 and output to have same numbers "
3270 << inputShape.getNumElements() <<
" and "
3271 << outputShape.getNumElements();
// Per-dimension check: output dim i must equal input dim perms[i], skipping
// pairs where either side is dynamic.
3275 if (inputShape.hasRank() && outputShape.hasRank()) {
3276 for (
auto i = 0; i < outputShape.getRank(); i++) {
3277 if (inputShape.isDynamicDim(constantPerms[i]) ||
3278 outputShape.isDynamicDim(i))
3281 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
3283 <<
"expected output tensor dim " << i <<
" to match "
3284 <<
"input dim " << constantPerms[i] <<
" with value of "
3285 << inputShape.getDimSize(constantPerms[i]);
// Reifies the result shape of tosa.transpose as OpFoldResults: static input
// dims become index attributes, dynamic dims become tensor.dim ops.
// NOTE(review): garbled listing; code left byte-identical, comments only.
3292LogicalResult TransposeOp::reifyResultShapes(
3295 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
3297 Value input = getInput1();
3298 auto inputType = cast<TensorType>(input.
getType());
3300 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
3301 for (
auto dim : transposePerms) {
// Map each output position through the permutation to its input dim.
3302 int32_t dimInInput = transposePerms[dim];
3303 if (inputType.isDynamicDim(dimInInput))
3305 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
3309 builder.
getIndexAttr(inputType.getDimSize(dimInInput));
3312 reifiedReturnShapes.emplace_back(std::move(returnedDims));
// Shape inference for tosa.gather. Output is rank-3 [N, W, C]: N and C come
// from values, W from indices; either source can fill in a dynamic N.
3316LogicalResult tosa::GatherOp::inferReturnTypeComponents(
3317 MLIRContext *context, ::std::optional<Location> location,
3318 GatherOp::Adaptor adaptor,
3319 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3320 llvm::SmallVector<int64_t> outputShape;
3321 outputShape.resize(3, ShapedType::kDynamic);
3323 ShapeAdaptor valuesShape(adaptor.getValues().getType());
3324 if (valuesShape.hasRank()) {
3325 outputShape[0] = valuesShape.getDimSize(0);
3326 outputShape[2] = valuesShape.getDimSize(2);
// indices may refine N (dim 0) and supplies W (dim 1).
3329 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3330 if (indicesShape.hasRank()) {
3331 if (outputShape[0] == ShapedType::kDynamic)
3332 outputShape[0] = indicesShape.getDimSize(0);
3333 if (outputShape[1] == ShapedType::kDynamic)
3334 outputShape[1] = indicesShape.getDimSize(1);
3337 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Shape inference for tosa.row_gather_block_scaled. First result is the data
// tensor [N, W*row_count, C]; when a scale tensor is present (values.size()==2)
// a second result with C/block_size channels is inferred.
3341LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
3342 MLIRContext *context, ::std::optional<Location> location,
3343 RowGatherBlockScaledOp::Adaptor adaptor,
3344 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3345 const auto values = adaptor.getValues();
3349 SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
3350 const ShapeAdaptor valuesShape(values.front().getType());
3351 if (valuesShape.hasRank()) {
3352 dataShape[0] = valuesShape.getDimSize(0);
3353 dataShape[2] = valuesShape.getDimSize(2);
3356 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3357 if (indicesShape.hasRank()) {
3358 if (dataShape[0] == ShapedType::kDynamic)
3359 dataShape[0] = indicesShape.getDimSize(0);
// Output rows = indices width * row_count, when both are static/positive.
3362 succeeded(rowCount) && rowCount.value() > 0) {
3363 const int64_t indicesW = indicesShape.getDimSize(1);
3364 if (ShapedType::isStatic(indicesW))
3365 dataShape[1] = indicesW * rowCount.value();
3369 inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
// Single-tensor form has no scale result to infer.
3370 if (values.size() == 1)
3373 SmallVector<int64_t> scaleShape = dataShape;
3374 const uint32_t blockSize =
3375 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
3376 if (ShapedType::isStatic(dataShape[2]))
3377 scaleShape[2] = dataShape[2] / blockSize;
3379 inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
// Verifier for tosa.gather: checks N/W/C consistency between values, indices
// and output, skipping any dimension that is dynamic on either side.
3383LogicalResult tosa::GatherOp::verify() {
3390 const ShapeAdaptor valuesShape(getValues().
getType());
3392 const ShapeAdaptor outputShape(getOutput().
getType());
// N/W/C start dynamic and are pinned by the first ranked operand seen.
3394 int64_t n = ShapedType::kDynamic;
3395 int64_t w = ShapedType::kDynamic;
3396 int64_t c = ShapedType::kDynamic;
3398 if (valuesShape.hasRank()) {
3399 n = valuesShape.getDimSize(0);
3400 c = valuesShape.getDimSize(2);
3402 if (indicesShape.hasRank()) {
3403 const int64_t indicesN = indicesShape.getDimSize(0);
3404 w = indicesShape.getDimSize(1);
3405 if (n == ShapedType::kDynamic)
3407 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3408 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3409 <<
", got " << indicesN;
3411 if (outputShape.hasRank()) {
3412 const int64_t outputN = outputShape.getDimSize(0);
3413 const int64_t outputW = outputShape.getDimSize(1);
3414 const int64_t outputC = outputShape.getDimSize(2);
3415 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3417 return emitOpError() <<
"requires output dimension 0 to have size " << n
3418 <<
", got " << outputN;
3420 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
3422 return emitOpError() <<
"requires output dimension 1 to have size " << w
3423 <<
", got " << outputW;
3424 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3426 return emitOpError() <<
"requires output dimension 2 to have size " << c
3427 <<
", got " << outputC;
// Verifier for tosa.row_gather_block_scaled: checks values/output list
// lengths, block_size vs. list-length coupling, row_count positivity, and
// N/K/C/W shape consistency across data, indices, scale and output tensors.
// NOTE(review): extraction-garbled; several guard expressions and returns are
// missing from this listing. Code left byte-identical, comments only.
3432LogicalResult tosa::RowGatherBlockScaledOp::verify() {
3433 const OperandRange values = getValues();
3434 const ResultRange output = getOutput();
3435 if (values.empty() || values.size() > 2)
3437 <<
"expects values tensor list length to be 1 or 2, got "
3439 if (output.size() != values.size())
3441 <<
"expects output tensor list length to match values tensor list "
3443 << output.size() <<
" results for " << values.size()
3444 <<
" input tensors";
// block_size must be BLOCK_SIZE_1 iff no scale tensor is supplied.
3446 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
3447 if (values.size() == 1 && blockSize != 1)
3449 <<
"requires block_size to be BLOCK_SIZE_1 when values tensor list "
3451 if (values.size() == 2 && blockSize == 1)
3453 <<
"requires block_size to not be BLOCK_SIZE_1 when values tensor "
3457 output[0].
getType(),
"values[0]",
3462 "values[1]",
"output[1]")))
3466 succeeded(rowCount) && rowCount.value() <= 0)
3467 return emitOpError() <<
"requires row_count to be > 0, got "
3468 << rowCount.value();
// N/K/C/W plus the expected scale-channel count, pinned as operands are seen.
3470 int64_t n = ShapedType::kDynamic;
3471 int64_t k = ShapedType::kDynamic;
3472 int64_t c = ShapedType::kDynamic;
3473 int64_t w = ShapedType::kDynamic;
3474 int64_t multiplesOfC = ShapedType::kDynamic;
3476 const ShapeAdaptor valuesDataShape(values[0].
getType());
3477 if (valuesDataShape.hasRank()) {
3478 n = valuesDataShape.getDimSize(0);
3479 k = valuesDataShape.getDimSize(1);
3480 c = valuesDataShape.getDimSize(2);
3483 if (ShapedType::isStatic(c) && c % blockSize != 0)
3484 return emitOpError() <<
"expects channels of values[0] (" << c
3485 <<
") to be divisible by block_size (" << blockSize
3489 if (indicesShape.hasRank()) {
3491 "indices",
"batch")))
3493 w = indicesShape.getDimSize(1);
3496 const ShapeAdaptor outputDataShape(output[0].
getType());
3497 if (outputDataShape.hasRank()) {
3499 "output[0]",
"batch")) ||
3501 "output[0]",
"channels")))
// Static check: output rows must equal indices width * row_count.
3505 succeeded(rowCount) && rowCount.value() > 0 &&
3506 ShapedType::isStatic(w)) {
3507 const int64_t expectedOutputRows = w * rowCount.value();
3508 if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
3509 outputDataShape.getDimSize(1) != expectedOutputRows)
3510 return emitOpError() <<
"requires output[0] dimension 1 to have size "
3511 << expectedOutputRows <<
", got "
3512 << outputDataShape.getDimSize(1);
// Two-tensor form: validate the scale tensors as well.
3516 if (values.size() == 2) {
3517 const ShapeAdaptor valuesScaleShape(values[1].
getType());
3518 if (valuesScaleShape.hasRank()) {
3520 "values[1]",
"batch")) ||
3522 "values[1]",
"rows")))
3524 multiplesOfC = valuesScaleShape.getDimSize(2);
3527 const ShapeAdaptor outputScaleShape(output[1].
getType());
3528 if (outputScaleShape.hasRank()) {
3530 "output[1]",
"batch")))
3534 succeeded(rowCount) && rowCount.value() > 0 &&
3535 ShapedType::isStatic(w)) {
3536 const int64_t expectedOutputRows = w * rowCount.value();
3537 if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
3538 outputScaleShape.getDimSize(1) != expectedOutputRows)
3539 return emitOpError() <<
"requires output[1] dimension 1 to have size "
3540 << expectedOutputRows <<
", got "
3541 << outputScaleShape.getDimSize(1);
3544 if (ShapedType::isDynamic(multiplesOfC))
3545 multiplesOfC = outputScaleShape.getDimSize(2);
3546 else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
3547 multiplesOfC != outputScaleShape.getDimSize(2))
3549 <<
"expected channels of output[1] to match size "
3550 << multiplesOfC <<
", got " << outputScaleShape.getDimSize(2);
// Scale channels must equal data channels divided by block size.
3553 if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
3554 multiplesOfC != c / blockSize)
3556 <<
"expects channels of scale tensors to equal C/block_size (" << c
3557 <<
"/" << blockSize <<
"), got " << multiplesOfC;
// Shape inference for tosa.resize (NHWC). Batch and channels pass through;
// H/W are computed from the scale/offset/border attributes when the input
// spatial dims are static.
3563LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3564 MLIRContext *context, ::std::optional<Location> location,
3565 ResizeOp::Adaptor adaptor,
3566 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3567 llvm::SmallVector<int64_t, 4> outputShape;
3568 outputShape.resize(4, ShapedType::kDynamic);
3570 ShapeAdaptor inputShape(adaptor.getInput().getType());
3571 if (!inputShape.hasRank())
3574 outputShape[0] = inputShape.getDimSize(0);
3575 outputShape[3] = inputShape.getDimSize(3);
3576 int64_t inputHeight = inputShape.getDimSize(1);
3577 int64_t inputWidth = inputShape.getDimSize(2);
// Cannot compute spatial output sizes from dynamic input H/W.
3579 if ((inputHeight == ShapedType::kDynamic) ||
3580 (inputWidth == ShapedType::kDynamic))
3583 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
// output = ((in - 1) * scale_n - offset + border) / scale_d + 1
3594 const int64_t outputHeight =
3595 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3599 const int64_t outputWidth =
3600 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
3604 if (outputHeight < 0 || outputWidth < 0) {
3607 "calculated output height and width must be non-negative, "
3609 outputHeight,
", width = ", outputWidth);
3612 outputShape[1] = outputHeight;
3613 outputShape[2] = outputWidth;
3614 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifier for tosa.resize: scale values must be positive, and when the input
// spatial dims are static the scaled-size formula must divide evenly and
// match the declared output H/W.
// NOTE(review): extraction-garbled listing; code left byte-identical.
3618LogicalResult tosa::ResizeOp::verify() {
3619 const Value input = getInput();
3620 const Value output = getOutput();
3621 const RankedTensorType inputType =
3622 llvm::dyn_cast<RankedTensorType>(input.
getType());
3623 const RankedTensorType outputType =
3624 llvm::dyn_cast<RankedTensorType>(output.
getType());
3626 SmallVector<int64_t> scaleValues;
3627 SmallVector<int64_t> offsetValues;
3628 SmallVector<int64_t> borderValues;
3636 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
3637 return emitOpError(
"expect all scale values to be > 0, got ")
// Unpack scale = [y_n, y_d, x_n, x_d], offset = [y, x], border = [y, x].
3640 const int64_t scaleYN = scaleValues[0];
3641 const int64_t scaleYD = scaleValues[1];
3642 const int64_t scaleXN = scaleValues[2];
3643 const int64_t scaleXD = scaleValues[3];
3645 const int64_t offsetY = offsetValues[0];
3646 const int64_t offsetX = offsetValues[1];
3648 const int64_t borderY = borderValues[0];
3649 const int64_t borderX = borderValues[1];
3656 const int64_t oh = outputType.getDimSize(1);
3657 const int64_t ow = outputType.getDimSize(2);
3658 const int64_t ih = inputType.getDimSize(1);
3659 const int64_t iw = inputType.getDimSize(2);
// Height: (ih-1)*scale_y_n - offset_y + border_y must divide by scale_y_d.
3665 if (ih != ShapedType::kDynamic && ih != 1) {
3666 const std::optional<int64_t> calculatedOutHeightMinusOne =
3667 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3668 if (!calculatedOutHeightMinusOne.has_value())
3669 return emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
3671 <<
"to be wholly divisible by scale_y_d, got ((" << ih
3672 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
3673 <<
") / " << scaleYD;
3674 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3675 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3676 return emitOpError(
"calculated output height did not match expected: ")
3677 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
// Width: same divisibility and equality checks along the x axis.
3684 if (iw != ShapedType::kDynamic && iw != 1) {
3685 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3686 const std::optional<int64_t> calculatedOutWidthMinusOne =
3688 if (!calculatedOutWidthMinusOne.has_value())
3689 return emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
3691 <<
"to be wholly divisible by scale_x_d, got ((" << iw
3692 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
3693 <<
") / " << scaleXD;
3694 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3695 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3696 return emitOpError(
"calculated output width did not match expected: ")
3697 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
// Shape inference for tosa.scatter. Output is [N, K, C] taken from values_in,
// with dynamic dims refined from indices (N) and input (N, C).
3703LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3704 MLIRContext *context, ::std::optional<Location> location,
3705 ScatterOp::Adaptor adaptor,
3706 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3707 llvm::SmallVector<int64_t> outputShape;
3708 outputShape.resize(3, ShapedType::kDynamic);
3710 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3711 if (valuesInShape.hasRank()) {
3712 outputShape[0] = valuesInShape.getDimSize(0);
3713 outputShape[1] = valuesInShape.getDimSize(1);
3714 outputShape[2] = valuesInShape.getDimSize(2);
3717 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3718 if (indicesShape.hasRank()) {
3719 if (outputShape[0] == ShapedType::kDynamic)
3720 outputShape[0] = indicesShape.getDimSize(0);
3723 ShapeAdaptor inputShape(adaptor.getInput().getType());
3724 if (inputShape.hasRank()) {
3725 if (outputShape[0] == ShapedType::kDynamic)
3726 outputShape[0] = inputShape.getDimSize(0);
3727 if (outputShape[2] == ShapedType::kDynamic)
3728 outputShape[2] = inputShape.getDimSize(2);
3731 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifier for tosa.scatter: enforces N/K/W/C consistency across values_in,
// indices, input and values_out, and requires K >= W so every scattered row
// fits; dynamic dims on either side of a comparison are skipped.
3735LogicalResult tosa::ScatterOp::verify() {
3745 const ShapeAdaptor valuesInShape(getValuesIn().
getType());
3747 const ShapeAdaptor inputShape(getInput().
getType());
3748 const ShapeAdaptor outputShape(getValuesOut().
getType());
3750 int64_t n = ShapedType::kDynamic;
3751 int64_t k = ShapedType::kDynamic;
3752 int64_t w = ShapedType::kDynamic;
3753 int64_t c = ShapedType::kDynamic;
3754 if (valuesInShape.hasRank()) {
3755 n = valuesInShape.getDimSize(0);
3756 k = valuesInShape.getDimSize(1);
3757 c = valuesInShape.getDimSize(2);
3759 if (indicesShape.hasRank()) {
3760 const int64_t indicesN = indicesShape.getDimSize(0);
3761 w = indicesShape.getDimSize(1);
3762 if (n == ShapedType::kDynamic)
3764 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3765 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3766 <<
", got " << indicesN;
3768 if (inputShape.hasRank()) {
3769 const int64_t inputN = inputShape.getDimSize(0);
3770 const int64_t inputW = inputShape.getDimSize(1);
3771 const int64_t inputC = inputShape.getDimSize(2);
3772 if (n == ShapedType::kDynamic)
3774 else if (inputN != ShapedType::kDynamic && n != inputN)
3775 return emitOpError() <<
"requires input dimension 0 to have size " << n
3776 <<
", got " << inputN;
3777 if (w == ShapedType::kDynamic)
3779 else if (inputW != ShapedType::kDynamic && w != inputW)
3780 return emitOpError() <<
"requires input dimension 1 to have size " << w
3781 <<
", got " << inputW;
3783 if (c == ShapedType::kDynamic)
3785 else if (inputC != ShapedType::kDynamic && c != inputC)
3786 return emitOpError() <<
"requires input dimension 2 to have size " << c
3787 <<
", got " << inputC;
3789 if (outputShape.hasRank()) {
3790 const int64_t outputN = outputShape.getDimSize(0);
3791 const int64_t outputK = outputShape.getDimSize(1);
3792 const int64_t outputC = outputShape.getDimSize(2);
3793 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3795 return emitOpError() <<
"requires values_out dimension 0 to have size "
3796 << n <<
", got " << outputN;
3797 if (k == ShapedType::kDynamic)
3799 else if (outputK != ShapedType::kDynamic && k != outputK)
3800 return emitOpError() <<
"requires values_out dimension 1 to have size "
3801 << k <<
", got " << outputK;
3802 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3804 return emitOpError() <<
"requires values_out dimension 2 to have size "
3805 << c <<
", got " << outputC;
// Each of the W scattered rows needs a distinct slot among K.
3807 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3808 return emitOpError() <<
"requires dimensions K >= W, got K=" << k
// Fragment of the shared reduce-op shape-inference helper: copies the operand
// dims and collapses the reduced axis to size 1; bails out when the operand
// is unranked or the axis is out of range.
// NOTE(review): the helper's signature lines are missing from this listing.
3817 int64_t axisVal = axis.getValue().getSExtValue();
3818 if (!operandShape.
hasRank() || operandShape.
getRank() <= axisVal) {
3824 operandShape.
getDims(outputShape);
3825 outputShape[axisVal] = 1;
// COMPATIBLE_RETURN_TYPES(OP): result-type compatibility = single result with
// matching element type and compatible shape. REDUCE_SHAPE_INFER(OP): stamps
// out inferReturnTypeComponents for a reduce op by delegating to
// ReduceInferReturnTypes with the op's axis property. (Comments kept outside
// the backslash continuations, which splice before comment removal.)
3830#define COMPATIBLE_RETURN_TYPES(OP) \
3831 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3832 if (l.size() != r.size() || l.size() != 1) \
3834 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3836 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3839#define REDUCE_SHAPE_INFER(OP) \
3840 LogicalResult OP::inferReturnTypeComponents( \
3841 MLIRContext *context, ::std::optional<Location> location, \
3842 OP::Adaptor adaptor, \
3843 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3845 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3846 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3847 const Properties &prop = adaptor.getProperties(); \
3848 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3849 inferredReturnShapes); \
3851 COMPATIBLE_RETURN_TYPES(OP)
3859#undef REDUCE_SHAPE_INFER
3861#undef COMPATIBLE_RETURN_TYPES
// Shared verifier for all tosa reduce ops: axis must be non-negative and in
// range for both input and output ranks (rank-0 tensors only allow axis 0),
// output rank must equal input rank, and the reduced output dim must be 1.
3863template <
typename T>
3866 TensorType inputType = op.getInput().getType();
3867 TensorType outputType = op.getOutput().getType();
3868 int32_t reduceAxis = op.getAxis();
3870 if (reduceAxis < 0) {
3871 op.emitOpError(
"reduce axis must not be negative");
3875 int64_t inputRank = inputType.getRank();
// Rank-0 special case: axis 0 is permitted even though rank == 0.
3878 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3879 op.emitOpError(
"expect input tensor rank (")
3880 << inputRank <<
") to be larger than reduce axis (" << reduceAxis
3886 int64_t outputRank = outputType.getRank();
3887 if (inputType.
hasRank() && outputRank != inputType.getRank()) {
3889 "expect output tensor rank to be equal to input tensor rank");
3892 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3893 op.emitOpError(
"expect output tensor rank (")
3894 << outputRank <<
") to be larger than reduce axis (" << reduceAxis
// A reduce collapses the chosen axis to size 1 in the result.
3900 if (outputRank != 0) {
3901 auto outputShape = outputType.
getShape();
3902 if (!outputType.isDynamicDim(reduceAxis) &&
3903 outputShape[reduceAxis] != 1) {
3904 op.emitOpError(
"expect reduced dimension size to be 1, got ")
3905 << outputShape[reduceAxis];
3913LogicalResult tosa::ReduceAllOp::verify() {
return verifyReduceOp(*
this); }
3914LogicalResult tosa::ReduceAnyOp::verify() {
return verifyReduceOp(*
this); }
3915LogicalResult tosa::ReduceMaxOp::verify() {
return verifyReduceOp(*
this); }
3916LogicalResult tosa::ReduceMinOp::verify() {
return verifyReduceOp(*
this); }
3917LogicalResult tosa::ReduceProductOp::verify() {
return verifyReduceOp(*
this); }
3918LogicalResult tosa::ReduceSumOp::verify() {
return verifyReduceOp(*
this); }
// NARY_SHAPE_INFER(OP): stamps out inferReturnTypeComponents for elementwise
// n-ary ops by delegating to NAryInferReturnTypes. (Comment kept above the
// #define: a // inside a backslash-continued macro would swallow the rest of
// the spliced line. The instantiations before #undef are missing from this
// garbled listing.)
3932#define NARY_SHAPE_INFER(OP) \
3933 LogicalResult OP::inferReturnTypeComponents( \
3934 MLIRContext *context, ::std::optional<Location> location, \
3935 ValueShapeRange operands, DictionaryAttr attributes, \
3936 PropertyRef properties, RegionRange regions, \
3937 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3938 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3978#undef PRED_SHAPE_INFER
// Shape inference for tosa.negate: the result shape mirrors input1.
// NOTE(review): body tail is missing from this garbled listing.
3980LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3981 MLIRContext *context, ::std::optional<Location> location,
3982 NegateOp::Adaptor adaptor,
3984 ShapeAdaptor inputShape(adaptor.getInput1().getType());
// Verifier for tosa.negate: input1/output shapes must match, each tensor's
// element type must match its zero-point element type, and the zero-point
// values themselves must pass the shared zero-point checks.
3989LogicalResult tosa::NegateOp::verify() {
3991 const Type input1Type = getInput1().getType();
3992 const Type outputType = getOutput().getType();
3997 const SmallVector<Type, 2> types = {input1Type, outputType};
3999 return emitOpError() <<
"requires the same shape for input1 and output";
4002 const Type input1ZpEType =
4004 if (input1EType != input1ZpEType) {
4005 return emitOpError(
"expect both input1 and its zero point are the same "
4006 "element type, got ")
4007 << input1EType <<
" and " << input1ZpEType;
4010 const Type outputZpEType =
4012 if (outputEType != outputZpEType) {
4013 return emitOpError(
"expect both output and its zero point are the same "
4014 "element type, got ")
4015 << outputEType <<
" and " << outputZpEType;
// Value-level zero-point validation via the generated verify*ZeroPoint.
4018 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
4019 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
4022 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4023 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
// Fragment of a pooling shape-inference helper: for static H/W, output
// spatial size = (in + pad_before + pad_after - kernel) / stride + 1.
// NOTE(review): the enclosing function's signature is missing from this
// garbled listing.
4034 outputShape.resize(4, ShapedType::kDynamic);
4049 if (ShapedType::isStatic(height)) {
4050 int64_t padded = height + pad[0] + pad[1] - kernel[0];
4051 outputShape[1] = padded / stride[0] + 1;
4054 if (ShapedType::isStatic(width)) {
4055 int64_t padded = width + pad[2] + pad[3] - kernel[1];
4056 outputShape[2] = padded / stride[1] + 1;
// Fragments of a Conv2D shape-adaptor helper: fills output batch/channel dims
// and records input/weight spatial dims plus pad/stride/dilation vectors from
// the op adaptor. A dynamic "current" dim is replaced by "candidate".
// NOTE(review): struct/function headers are missing from this garbled listing.
4063template <
typename AdaptorT>
4069 if (ShapedType::isDynamic(current))
4070 current = candidate;
4079 : adaptor(adaptor) {}
4083 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4091 outputShape[0] = outputBatch;
4092 inputSpatial[0] = inputHeight;
4093 inputSpatial[1] = inputWidth;
4098 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4106 outputShape[3] = outputChannels;
4107 weightSpatial[0] = kernelHeight;
4108 weightSpatial[1] = kernelWidth;
4117 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4118 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4119 dilationValues.assign(adaptor.getDilation().begin(),
4120 adaptor.getDilation().end());
4125 Conv2DOp::Adaptor adaptor;
// Fragments of the Conv2DBlockScaled shape-adaptor: same role as the Conv2D
// adaptor above, but reads data/scale operand pairs (input_data/input_scale
// and weight_data/weight_scale) and requires the scale shapes to be ranked.
// NOTE(review): headers and interior lines missing from this garbled listing.
4133 : adaptor(adaptor) {}
4137 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
4138 if (inputDataShape.
hasRank()) {
4143 outputShape[0] = outputBatch;
4144 inputSpatial[0] = inputHeight;
4145 inputSpatial[1] = inputWidth;
4148 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
4149 if (!inputScaleShape.
hasRank())
4163 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
4164 if (weightDataShape.
hasRank()) {
4169 outputShape[3] = outputChannels;
4170 weightSpatial[0] = kernelHeight;
4171 weightSpatial[1] = kernelWidth;
4174 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
4175 if (!weightScaleShape.
hasRank())
4204 Conv2DBlockScaledOp::Adaptor adaptor;
// Fragments of the Conv3D shape-adaptor: as the 2D variant but with three
// spatial dims (depth/height/width); output layout is NDHWC with channels at
// index 4.
// NOTE(review): headers and interior lines missing from this garbled listing.
4212 : adaptor(adaptor) {}
4216 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4225 outputShape[0] = outputBatch;
4226 inputSpatial[0] = inputDepth;
4227 inputSpatial[1] = inputHeight;
4228 inputSpatial[2] = inputWidth;
4233 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4242 outputShape[4] = outputChannels;
4243 weightSpatial[0] = kernelDepth;
4244 weightSpatial[1] = kernelHeight;
4245 weightSpatial[2] = kernelWidth;
4254 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4255 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4256 dilationValues.assign(adaptor.getDilation().begin(),
4257 adaptor.getDilation().end());
4262 Conv3DOp::Adaptor adaptor;
// Generic convolution shape-inference template: gathers input/weight spatial
// dims and bias size via a conv shape adaptor, then computes each output
// spatial dim with the standard dilated-convolution formula:
//   out = ((in + pad_lo + pad_hi) - ((k - 1) * dilation + 1) - 1) / stride + 1
// NOTE(review): interior lines missing from this garbled listing.
4265template <
typename AdaptorT>
4271 ShapedType::kDynamic);
4273 ShapedType::kDynamic);
4275 ShapedType::kDynamic);
4277 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
4278 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
// A non-broadcast bias pins the output channel dim when it is dynamic.
4280 const ShapeAdaptor biasShape = adaptor.getBias().getType();
4283 if (biasSize != 1) {
4284 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
4285 outputShape[outputChannelDim] =
4286 ShapedType::isDynamic(outputShape[outputChannelDim])
4288 : outputShape[outputChannelDim];
4295 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
// Only fully static (input, weight) spatial pairs can be computed.
4301 for (
int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
4302 if (!ShapedType::isStatic(inputSpatial[dim]) ||
4303 !ShapedType::isStatic(weightSpatial[dim]))
4306 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
4308 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
4309 const int64_t unstridedResult = inputSize - filterSize + 1;
4310 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
// Headers of Conv2DOp::inferReturnTypeComponents / Conv2DOp::verify and
// Conv2DBlockScaledOp::inferReturnTypeComponents; their bodies (presumably
// delegating to the shared conv helpers above — TODO confirm) are missing
// from this garbled listing.
4317LogicalResult Conv2DOp::inferReturnTypeComponents(
4318 MLIRContext *context, ::std::optional<Location> location,
4319 Conv2DOp::Adaptor adaptor,
4320 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4324LogicalResult Conv2DOp::verify() {
4331LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
4332 MLIRContext *context, ::std::optional<Location> location,
4333 Conv2DBlockScaledOp::Adaptor adaptor,
4334 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Verifier for tosa Conv2DBlockScaled: checks data/scale shape agreement,
// block size (must be 32), IC divisibility by block size, scale channel
// counts (IC/block_size), pad/stride/dilation value ranges, output spatial
// sizing, and bias channels (equal to output channels or broadcast 1).
// NOTE(review): extraction-garbled; many guard expressions and returns are
// missing from this listing. Code left byte-identical, comments only.
4338LogicalResult Conv2DBlockScaledOp::verify() {
4340 getWeightData().
getType(),
"input_data",
4343 getWeightScale().
getType(),
"input_scale",
4346 getOutput().
getType(),
"bias",
"output")))
// Dimensions pinned from the operands as they are inspected below.
4350 int64_t N = ShapedType::kDynamic;
4351 int64_t IH = ShapedType::kDynamic;
4352 int64_t IW = ShapedType::kDynamic;
4353 int64_t IC = ShapedType::kDynamic;
4354 int64_t multiplesOfIC = ShapedType::kDynamic;
4355 int64_t OC = ShapedType::kDynamic;
4356 int64_t KH = ShapedType::kDynamic;
4357 int64_t KW = ShapedType::kDynamic;
4359 const ShapeAdaptor inputDataShape(getInputData().
getType());
4360 if (inputDataShape.hasRank()) {
4361 N = inputDataShape.getDimSize(0);
4362 IH = inputDataShape.getDimSize(1);
4363 IW = inputDataShape.getDimSize(2);
4364 IC = inputDataShape.getDimSize(3);
4367 const ShapeAdaptor inputScaleShape(getInputScale().
getType());
4368 if (inputScaleShape.hasRank()) {
4370 "input_scale",
"batch size")) ||
4372 "input_scale",
"input height")) ||
4374 "input_scale",
"input width")))
4376 multiplesOfIC = inputScaleShape.getDimSize(3);
4379 const ShapeAdaptor weightDataShape(getWeightData().
getType());
4380 if (weightDataShape.hasRank()) {
4381 OC = weightDataShape.getDimSize(0);
4382 KH = weightDataShape.getDimSize(1);
4383 KW = weightDataShape.getDimSize(2);
4385 "weight_data",
"input channels")))
4389 const ShapeAdaptor weightScaleShape(getWeightScale().
getType());
4390 if (weightScaleShape.hasRank()) {
4392 "weight_scale",
"output channels")) ||
4394 "weight_scale",
"kernel height")) ||
4396 "weight_scale",
"kernel width")) ||
4398 weightScaleShape.getDimSize(3),
4399 "weight_scale",
"input channel blocks")))
// Only BLOCK_SIZE_32 is accepted for this op.
4403 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
4404 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4405 return emitOpError(
"expect block size to be 32, got ") << blockSize;
4407 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
4408 return emitOpError(
"expect IC to be a multiple of block size, got IC=")
4409 << IC <<
", block_size=" << blockSize;
4412 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
4413 multiplesOfIC != IC / blockSize)
4415 "expect scale operands dimension 2 to equal IC/block_size (")
4416 << IC <<
"/" << blockSize <<
")"
4417 <<
", got " << multiplesOfIC;
// Attribute value ranges: pad >= 0, stride >= 1, dilation >= 1.
4420 SmallVector<int64_t> padValues;
4422 if (llvm::any_of(padValues, [](int64_t p) {
return p < 0; }))
4423 return emitOpError(
"expect all padding values to be >= 0, got ")
4427 SmallVector<int64_t> strideValues;
4429 if (llvm::any_of(strideValues, [](int64_t s) {
return s < 1; }))
4430 return emitOpError(
"expect all stride values to be >= 1, got ")
4434 SmallVector<int64_t> dilationValues;
4437 if (llvm::any_of(dilationValues, [](int64_t d) {
return d < 1; }))
4438 return emitOpError(
"expect all dilation values to be >= 1, got ")
4443 const ShapeAdaptor outputShape(getOutput().
getType());
4444 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
4445 outputShape.hasRank()) {
4447 padValues[0], padValues[1], strideValues[0],
4448 dilationValues[0],
"height",
"y",
"top",
4451 padValues[2], padValues[3], strideValues[1],
4452 dilationValues[1],
"width",
"x",
"left",
4458 const ShapeAdaptor biasShape(getBias().
getType());
4459 if (biasShape.hasRank() && outputShape.hasRank()) {
4460 const int64_t biasChannels = biasShape.getDimSize(0);
4461 const int64_t outputChannels =
4462 outputShape.getDimSize(outputShape.getRank() - 1);
4463 if (biasChannels == ShapedType::kDynamic ||
4464 outputChannels == ShapedType::kDynamic)
// Bias must either match output channels or be a broadcastable size 1.
4468 if (biasChannels != outputChannels && biasChannels != 1)
4470 "bias channels expected to be equal to output channels (")
4471 << outputChannels <<
") or 1, got " << biasChannels;
// Headers of Conv3DOp::inferReturnTypeComponents / Conv3DOp::verify; bodies
// are missing from this garbled listing.
4477LogicalResult Conv3DOp::inferReturnTypeComponents(
4478 MLIRContext *context, ::std::optional<Location> location,
4479 Conv3DOp::Adaptor adaptor,
4480 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4484LogicalResult Conv3DOp::verify() {
// Shape inference for tosa.avg_pool2d: delegates to the shared pooling
// helper using the op's kernel/stride/pad properties.
4491LogicalResult AvgPool2dOp::inferReturnTypeComponents(
4492 MLIRContext *context, ::std::optional<Location> location,
4493 AvgPool2dOp::Adaptor adaptor,
4494 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4495 ShapeAdaptor inputShape(adaptor.getInput().getType());
4496 const Properties &prop = adaptor.getProperties();
4498 inferredReturnShapes);
// Shape inference for the adaptive avg-pool variant: uses the pooling helper
// when kernel/stride/pad values resolve; otherwise only batch and channel
// dims are propagated from the input, spatial dims remain dynamic.
4501LogicalResult AvgPool2dAdaptiveOp::inferReturnTypeComponents(
4502 MLIRContext *context, ::std::optional<Location> location,
4503 AvgPool2dAdaptiveOp::Adaptor adaptor,
4504 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4505 ShapeAdaptor inputShape(adaptor.getInput().getType());
4507 llvm::SmallVector<int64_t> kernelValues;
4508 llvm::SmallVector<int64_t> strideValues;
4509 llvm::SmallVector<int64_t> padValues;
4516 padValues, inferredReturnShapes);
4519 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4520 if (inputShape.hasRank()) {
4522 outputShape[0] = inputShape.getDimSize(0);
4523 outputShape[3] = inputShape.getDimSize(3);
4526 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Shape inference for tosa.max_pool2d: delegates to the shared pooling
// helper using the op's kernel/stride/pad properties.
4530LogicalResult MaxPool2dOp::inferReturnTypeComponents(
4531 MLIRContext *context, ::std::optional<Location> location,
4532 MaxPool2dOp::Adaptor adaptor,
4533 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4534 ShapeAdaptor inputShape(adaptor.getInput().getType());
4535 const Properties &prop = adaptor.getProperties();
4537 inferredReturnShapes);
// Shape inference for the adaptive max-pool variant: mirrors the adaptive
// avg-pool logic above — pooling helper when parameters resolve, otherwise
// batch/channel pass-through with dynamic spatial dims.
4540LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
4541 MLIRContext *context, ::std::optional<Location> location,
4542 MaxPool2dAdaptiveOp::Adaptor adaptor,
4543 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4544 ShapeAdaptor inputShape(adaptor.getInput().getType());
4546 llvm::SmallVector<int64_t> kernelValues;
4547 llvm::SmallVector<int64_t> strideValues;
4548 llvm::SmallVector<int64_t> padValues;
4555 padValues, inferredReturnShapes);
4558 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4559 if (inputShape.hasRank()) {
4560 outputShape[0] = inputShape.getDimSize(0);
4561 outputShape[3] = inputShape.getDimSize(3);
4563 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4567LogicalResult MaxPool2dOp::verify() {
4578LogicalResult MaxPool2dAdaptiveOp::verify() {
4583 AdaptivePoolingConstShapeValues values;
4587 values.pad, getInput(), getOutput())))
4593LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
4594 MLIRContext *context, ::std::optional<Location> location,
4595 DepthwiseConv2DOp::Adaptor adaptor,
4596 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4597 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4599 int64_t inputWidth = ShapedType::kDynamic;
4600 int64_t inputHeight = ShapedType::kDynamic;
4601 int64_t inputChannels = ShapedType::kDynamic;
4603 int64_t weightWidth = ShapedType::kDynamic;
4604 int64_t weightHeight = ShapedType::kDynamic;
4605 int64_t depthChannels = ShapedType::kDynamic;
4608 ShapeAdaptor inputShape(adaptor.getInput().getType());
4609 if (inputShape.hasRank()) {
4610 outputShape[0] = inputShape.getDimSize(0);
4611 inputHeight = inputShape.getDimSize(1);
4612 inputWidth = inputShape.getDimSize(2);
4613 inputChannels = inputShape.getDimSize(3);
4617 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4618 if (weightShape.hasRank()) {
4619 weightHeight = weightShape.getDimSize(0);
4620 weightWidth = weightShape.getDimSize(1);
4621 inputChannels = ShapedType::isDynamic(inputChannels)
4622 ? weightShape.getDimSize(2)
4624 depthChannels = weightShape.getDimSize(3);
4629 if (ShapedType::isStatic(inputChannels) &&
4630 ShapedType::isStatic(depthChannels)) {
4631 outputShape[3] = inputChannels * depthChannels;
4635 ShapeAdaptor biasShape(adaptor.getBias().getType());
4636 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4637 int64_t bc = biasShape.getDimSize(0);
4638 if (bc != ShapedType::kDynamic && bc != 1)
4639 outputShape[3] = bc;
4642 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4643 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4644 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4646 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4647 int64_t inputSize = inputHeight + padding[0] + padding[1];
4648 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4649 int64_t unstridedResult = inputSize - filterSize + 1;
4650 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4653 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4654 int64_t inputSize = inputWidth + padding[2] + padding[3];
4655 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4656 int64_t unstridedResult = inputSize - filterSize + 1;
4657 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4660 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4664LogicalResult DepthwiseConv2DOp::verify() {
4671LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4672 MLIRContext *context, ::std::optional<Location> location,
4673 TransposeConv2DOp::Adaptor adaptor,
4674 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4675 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4677 int64_t inputWidth = ShapedType::kDynamic;
4678 int64_t inputHeight = ShapedType::kDynamic;
4679 int64_t weightWidth = ShapedType::kDynamic;
4680 int64_t weightHeight = ShapedType::kDynamic;
4683 ShapeAdaptor inputShape(adaptor.getInput().getType());
4684 if (inputShape.hasRank()) {
4685 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4686 ? inputShape.getDimSize(0)
4688 inputHeight = inputShape.getDimSize(1);
4689 inputWidth = inputShape.getDimSize(2);
4693 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4694 if (weightShape.hasRank()) {
4695 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4696 ? weightShape.getDimSize(0)
4698 weightHeight = weightShape.getDimSize(1);
4699 weightWidth = weightShape.getDimSize(2);
4703 ShapeAdaptor biasShape(adaptor.getBias().getType());
4704 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4705 int64_t bc = biasShape.getDimSize(0);
4706 if (bc != ShapedType::kDynamic && bc != 1)
4707 outputShape[3] = bc;
4710 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4711 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4713 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4714 int64_t calculateSize =
4715 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4717 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4720 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4721 int64_t calculateSize =
4722 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4724 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4727 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4731LogicalResult TransposeConv2DOp::verify() {
4735 const llvm::ArrayRef<int64_t> strides = getStride();
4736 const int64_t strideY = strides[0];
4737 const int64_t strideX = strides[1];
4739 if (strideY < 1 || strideX < 1)
4740 return emitOpError(
"expect all stride values to be >= 1, got [")
4743 const auto checkPadAgainstKernelDim =
4744 [
this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4745 llvm::StringRef kernelDimName) -> LogicalResult {
4746 if (padValue <= -kernelDimSize)
4748 << padName <<
" > -" << kernelDimName <<
", but got: " << padName
4749 <<
"=" << padValue <<
" and " << kernelDimName <<
"="
4754 const llvm::ArrayRef<int64_t> padding = getOutPad();
4755 const int64_t outPadTop = padding[0];
4756 const int64_t outPadBottom = padding[1];
4757 const int64_t outPadLeft = padding[2];
4758 const int64_t outPadRight = padding[3];
4760 const auto weightType =
4761 llvm::dyn_cast<RankedTensorType>(getWeight().
getType());
4764 const int64_t kernelHeight = weightType.getDimSize(1);
4765 if (ShapedType::isStatic(kernelHeight)) {
4766 if (
failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4767 "out_pad_top",
"KH")))
4770 if (
failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4771 "out_pad_bottom",
"KH")))
4775 const int64_t kernelWidth = weightType.getDimSize(2);
4776 if (ShapedType::isStatic(kernelWidth)) {
4777 if (
failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4778 "out_pad_left",
"KW")))
4781 if (
failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4782 "out_pad_right",
"KW")))
4788 const auto outputType =
4789 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
4793 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
4794 if (inputType && weightType) {
4795 const int64_t inputHeight = inputType.getDimSize(1);
4796 const int64_t kernelHeight = weightType.getDimSize(1);
4797 const int64_t outputHeight = outputType.getDimSize(1);
4799 if (ShapedType::isStatic(inputHeight) &&
4800 ShapedType::isStatic(outputHeight)) {
4802 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4804 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4805 "+ out_pad_top + out_pad_bottom + KH, but got ")
4806 << outputHeight <<
" != (" << inputHeight <<
" - 1) * "
4807 << strideY <<
" + " << outPadTop <<
" + " << outPadBottom
4808 <<
" + " << kernelHeight;
4811 const int64_t inputWidth = inputType.getDimSize(2);
4812 const int64_t kernelWidth = weightType.getDimSize(2);
4813 const int64_t outputWidth = outputType.getDimSize(2);
4815 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4817 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4819 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4820 "+ out_pad_left + out_pad_right + KW, but got ")
4821 << outputWidth <<
" != (" << inputWidth <<
" - 1) * " << strideX
4822 <<
" + " << outPadLeft <<
" + " << outPadRight <<
" + "
4827 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().
getType());
4832 const int64_t biasChannels = biasType.getDimSize(0);
4835 if (biasChannels == ShapedType::kDynamic)
4838 const int64_t outputChannels = outputType.getDimSize(3);
4839 if (!ShapedType::isDynamic(outputChannels) &&
4840 biasChannels != outputChannels && biasChannels != 1)
4842 "bias channels expected to be equal to output channels (")
4843 << outputChannels <<
") or 1, got " << biasChannels;
4848LogicalResult RescaleOp::verify() {
4849 auto inputType = llvm::dyn_cast<ShapedType>(getInput().
getType());
4851 emitOpError(
"expect shaped tensor for input, got ") << getInput().getType();
4855 auto inputElementType =
4857 if (!mlir::isa<IntegerType>(inputElementType)) {
4858 emitOpError(
"expect input to have integer element type, got ")
4859 << inputElementType;
4863 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().
getType());
4865 emitOpError(
"expect shaped tensor for output, got ")
4866 << getOutput().getType();
4870 auto outputElementType =
4872 if (!mlir::isa<IntegerType>(outputElementType)) {
4873 emitOpError(
"expect output to have integer element type, got ")
4874 << outputElementType;
4886 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4887 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4890 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4891 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4894 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().
getType());
4895 if (!multiplierType) {
4896 emitOpError(
"expect shaped tensor for multiplier, got ")
4897 << getMultiplier().getType();
4901 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().
getType());
4903 emitOpError(
"expect shaped tensor for shift, got ") << getShift().getType();
4908 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4909 emitOpError(
"expect i32 element type for multiplier for scale32=true, got ")
4910 << multiplierType.getElementType();
4915 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4917 "expect i16 element type for multiplier for scale32=false, got ")
4918 << multiplierType.getElementType();
4922 if (!inputType.hasRank())
4928 int64_t numChannels = 1;
4929 if (getPerChannel()) {
4930 if (inputType.getRank() < 1) {
4931 emitOpError(
"requires input to be at least rank 1 when per_channel is "
4932 "true, but got rank ")
4933 << inputType.getRank();
4936 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4939 if (!multiplierType.hasRank())
4942 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4944 if (multiplierShape[0] != ShapedType::kDynamic &&
4945 multiplierShape[0] != numChannels) {
4947 << numChannels <<
" } for multiplier input, got { "
4948 << multiplierShape[0] <<
" }";
4952 if (!shiftType.hasRank())
4955 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4957 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4959 << numChannels <<
" } for shift input, got { " << shiftShape[0] <<
" }";
4966LogicalResult RescaleOp::inferReturnTypeComponents(
4967 MLIRContext *context, ::std::optional<Location> location,
4968 RescaleOp::Adaptor adaptor,
4969 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4970 ShapeAdaptor inputShape(adaptor.getInput().getType());
4971 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4975LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4976 MLIRContext *context, ::std::optional<Location> location,
4977 CastFromBlockScaledOp::Adaptor adaptor,
4978 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4979 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4980 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4984LogicalResult CastFromBlockScaledOp::verify() {
4985 const Type inputDataType = getInputData().getType();
4986 const Type outputDataType = getResult().getType();
4988 return emitOpError() <<
"require compatible shapes for input_data ("
4989 << inputDataType <<
") and " <<
"output_data ("
4990 << outputDataType <<
")";
4992 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4994 if (inputDataShape.
hasRank()) {
4995 const unsigned int blockSize =
4997 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4998 return emitOpError(
"expect block size to be 32, got ") << blockSize;
4999 const int64_t inputDataLastDim =
5001 if (inputDataLastDim % blockSize != 0)
5002 return emitOpError() <<
"expect last dimension of input_data ("
5004 <<
") to be divisible by block_size (" << blockSize
5007 const Type inputScaleType = getInputScale().getType();
5008 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
5010 if (inputScaleShape.
hasRank()) {
5011 SmallVector<int64_t> inputDataDims, inputScaleDims;
5012 inputDataShape.
getDims(inputDataDims);
5013 inputScaleShape.
getDims(inputScaleDims);
5015 if (inputDataDims.size() != inputScaleDims.size() ||
5017 ArrayRef<int64_t>(inputDataDims).drop_back(1),
5018 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
5020 <<
"require compatible shapes for input_data (" << inputDataType
5021 <<
") and " <<
"input_scale (" << inputScaleType
5022 <<
") except for the last dimension";
5024 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
5025 inputScaleDims.back()};
5026 if (ShapedType::isStatic(inputDataLastDim) &&
5029 <<
"expect last dimension of input_scale ("
5030 << inputScaleDims.back()
5031 <<
") to be equal to last dimension of input_data / block_size ("
5032 << inputDataDims.back() / blockSize <<
")";
5039LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
5040 MLIRContext *context, ::std::optional<Location> location,
5041 CastToBlockScaledOp::Adaptor adaptor,
5042 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5043 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5044 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5045 if (!inputShape.hasRank())
5049 SmallVector<int64_t> outputScaleShape;
5050 inputShape.getDims(outputScaleShape);
5051 const int64_t lastDimLoc = inputShape.getRank() - 1;
5052 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
5053 if (ShapedType::isStatic(lastDimSize)) {
5054 const unsigned int blockSize =
5055 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
5056 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
5058 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
5062LogicalResult CastToBlockScaledOp::verify() {
5063 const Type inputDataType = getInputData().getType();
5064 const Type outputDataType = getResult(0).getType();
5066 return emitOpError() <<
"require compatible shapes for input_data ("
5067 << inputDataType <<
") and " <<
"output_data ("
5068 << outputDataType <<
")";
5070 const unsigned int blockSize =
5072 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5073 return emitOpError(
"expect block size to be 32, got ") << blockSize;
5074 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5075 if (inputDataShape.
hasRank()) {
5076 const int64_t inputDataLastDim =
5078 if (ShapedType::isStatic(inputDataLastDim) &&
5079 inputDataLastDim % blockSize != 0)
5080 return emitOpError() <<
"expect last dimension of input_data ("
5082 <<
") to be divisible by block_size (" << blockSize
5086 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
5087 const Type outputScaleType = getResult(1).getType();
5088 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
5090 SmallVector<int64_t> outputDataDims, outputScaleDims;
5091 outputDataShape.
getDims(outputDataDims);
5092 outputScaleShape.
getDims(outputScaleDims);
5094 if (outputDataDims.size() != outputScaleDims.size() ||
5096 ArrayRef<int64_t>(outputDataDims).drop_back(1),
5097 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
5098 return emitOpError() <<
"require compatible shapes for output_data ("
5099 << outputDataType <<
") and " <<
"output_scale ("
5101 <<
") except for the last dimension";
5103 const int64_t outputDataLastDim = outputDataDims.back();
5104 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
5105 outputScaleDims.back()};
5106 if (ShapedType::isStatic(outputDataLastDim) &&
5109 <<
"expect last dimension of output_scale ("
5110 << outputScaleDims.back()
5111 <<
") to be equal to last dimension of output_data / block_size ("
5112 << outputDataDims.back() / blockSize <<
")";
5118LogicalResult IfOp::inferReturnTypeComponents(
5119 MLIRContext *context, ::std::optional<Location> location,
5120 IfOp::Adaptor adaptor,
5121 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5122 llvm::SmallVector<tosa::YieldOp> yieldOps;
5123 for (Region *region : adaptor.getRegions()) {
5124 for (
auto &block : *region)
5125 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5126 yieldOps.push_back(returnOp);
5129 if (yieldOps.empty())
5133 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5134 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5135 for (
auto operand : yieldOps.front().getOperands()) {
5136 resultKnowledge.push_back(
5140 for (
auto yieldOp : yieldOps) {
5141 if (resultKnowledge.size() != yieldOp.getNumOperands())
5144 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5145 int32_t index = it.index();
5147 resultKnowledge[index],
5151 resultKnowledge[index] = meet;
5155 for (
const ValueKnowledge &
result : resultKnowledge) {
5156 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
5162LogicalResult WhileOp::inferReturnTypeComponents(
5163 MLIRContext *context, ::std::optional<Location> location,
5164 WhileOp::Adaptor adaptor,
5165 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5166 llvm::SmallVector<tosa::YieldOp> yieldOps;
5167 for (
auto &block : adaptor.getBodyGraph())
5168 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5169 yieldOps.push_back(returnOp);
5173 if (yieldOps.empty())
5177 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5178 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5179 for (
auto operand : yieldOps.front().getOperands()) {
5180 resultKnowledge.push_back(
5184 for (
auto yieldOp : yieldOps) {
5185 if (resultKnowledge.size() != yieldOp.getNumOperands())
5188 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5189 int32_t index = it.index();
5191 resultKnowledge[index],
5193 resultKnowledge[index] = meet;
5198 for (
const ValueKnowledge &
result : resultKnowledge) {
5199 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
5205std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
5206 if (
auto vt = llvm::dyn_cast<VectorType>(
getType()))
5207 return llvm::to_vector<4>(vt.getShape());
5208 return std::nullopt;
5214 StringRef prefix =
"") {
5215 assert(blocksArgs.size() == initializers.size() &&
5216 "expected same length of arguments and initializers");
5217 if (initializers.empty())
5220 parser << prefix <<
'(';
5221 llvm::interleaveComma(
5222 llvm::zip(blocksArgs, initializers), parser,
5223 [&](
auto it) { parser << std::get<0>(it) <<
" = " << std::get<1>(it); });
5228ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
5230 result.regions.reserve(2);
5231 Region *thenRegion =
result.addRegion();
5232 Region *elseRegion =
result.addRegion();
5234 OpAsmParser::UnresolvedOperand cond;
5239 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5240 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5243 OptionalParseResult listResult =
5251 "expected type for condition operand");
5257 "expected type for condition operand");
5265 FunctionType functionType;
5269 <<
"expected list of types for block arguments "
5270 <<
"followed by arrow type and list of return types";
5272 result.addTypes(functionType.getResults());
5274 if (functionType.getNumInputs() != operands.size()) {
5276 <<
"expected as many input types as operands " <<
"(expected "
5277 << operands.size() <<
" got " << functionType.getNumInputs()
5308void IfOp::print(OpAsmPrinter &p) {
5309 p <<
" " << getCondition();
5312 getInputList(),
" ");
5314 p << getCondition().getType();
5316 if (!getInputList().empty()) {
5318 llvm::interleaveComma(getInputList().getTypes(), p);
5327 auto &elseRegion = getElseGraph();
5328 if (!elseRegion.
empty()) {
5336LogicalResult IfOp::verify() {
5338 "'then_graph' arguments", getInputList(),
5344 "'else_graph' arguments", getInputList(),
5350 if (getThenGraph().front().mightHaveTerminator()) {
5352 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
5354 *
this, thenYield.getInputs(),
"'then_graph' results",
5355 getOutputList(),
"'output_list'")
5361 if (getElseGraph().front().mightHaveTerminator()) {
5363 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
5365 *
this, elseYield.getInputs(),
"'else_graph' results",
5366 getOutputList(),
"'output_list'")
5371 auto condType = getCondition().getType();
5373 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
5379LogicalResult WhileOp::verify() {
5381 getOutputList(),
"'output_list'")
5386 "'cond_graph' arguments", getInputList(),
5392 "'body_graph' arguments", getInputList(),
5397 if (getBodyGraph().front().mightHaveTerminator()) {
5399 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
5401 "'body_graph' results",
5402 getInputList(),
"'input_list'")
5409 if (!getCondGraph().front().mightHaveTerminator())
5413 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
5417 if (condYield.getInputs().size() != 1)
5418 return emitOpError() <<
"require 'cond_graph' only have one result";
5420 auto condOutType = condYield.getInputs()[0].getType();
5422 return emitOpError() <<
"'cond_graph' result must be a size 1 tensor, got "
5426 return emitOpError() <<
"'cond_graph' result must be a boolean tensor, got "
5432LogicalResult ReverseOp::verify() {
5437 TensorType inputType = getInput1().getType();
5438 TensorType outputType = getOutput().getType();
5439 int32_t reverseAxis = getAxis();
5441 if (reverseAxis < 0)
5442 return emitOpError(
"expected non-negative reverse axis");
5444 int64_t inputRank = inputType.getRank();
5447 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
5449 << inputRank <<
") to be larger than reverse axis (" << reverseAxis
5453 int64_t outputRank = outputType.getRank();
5454 if (inputType.
hasRank() && outputRank != inputType.getRank())
5456 "expect output tensor rank to be equal to input tensor rank");
5457 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
5459 << outputRank <<
") to be larger than reverse axis ("
5460 << reverseAxis <<
")";
5465LogicalResult tosa::SelectOp::verify() {
5476 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().
getType());
5477 if (!predicateType) {
5478 return emitOpError(
"expect shaped tensor for input1, got ")
5479 << getInput1().getType();
5481 auto predicateElementType = predicateType.getElementType();
5482 if (!predicateElementType.isInteger(1)) {
5483 return emitOpError(
"expect element type of bool for input1, got ")
5484 << predicateElementType;
5490LogicalResult tosa::VariableReadOp::verify() {
5498LogicalResult tosa::VariableWriteOp::verify() {
5507ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
5508 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5509 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5510 Region *cond =
result.addRegion();
5511 Region *body =
result.addRegion();
5513 OptionalParseResult listResult =
5518 FunctionType functionType;
5523 result.addTypes(functionType.getResults());
5525 if (functionType.getNumInputs() != operands.size()) {
5527 <<
"expected as many input types as operands " <<
"(expected "
5528 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
5538 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
5539 regionArgs[i].type = functionType.getInput(i);
5541 return failure(parser.
parseRegion(*cond, regionArgs) ||
5546void WhileOp::print(OpAsmPrinter &parser) {
5548 getInputList(),
" ");
5551 getResults().getTypes());
5565 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
5566 if (llvm::isa<FloatType>(srcElemType)) {
5568 zpType, builder.
getFloatAttr(srcElemType,
static_cast<double>(zp)));
5569 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5571 if (llvm::isa<IntegerType>(srcElemType)) {
5574 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5576 llvm::errs() <<
"zero point is not allowed for unsupported data types\n";
5577 return std::nullopt;
5585 return mlir::isa<tosa::shapeType>(t);
5592 return emitError() <<
"invalid rank (must be >= 0): " << rank;
5598 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
5599 Operation *definingOp = v.getDefiningOp();
5601 return op->
emitOpError(
"shape operand is not compile time resolvable");
5614 auto getRank = [](
const Type type) {
5615 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
5621 for (
auto type : operandTypes) {
5622 if (getRank(type) != rank) {
5623 return op->
emitOpError(
"operands don't have matching ranks");
5626 for (
auto type : resultTypes) {
5627 if (getRank(type) != rank) {
5628 return op->
emitOpError(
"result shape has different rank than operands");
5638LogicalResult tosa::ConstShapeOp::verify() {
5640 auto valuesRank = getValues().getType().getRank();
5641 if (valuesRank != 1)
5642 return emitOpError(
"expect elements in attribute values with rank 1");
5644 auto count = getValues().getNumElements();
5645 auto rank = (cast<tosa::shapeType>(getResult().
getType())).getRank();
5646 if (count != rank && (count != 1 || rank != 0)) {
5647 return emitOpError(
"expect number of elements in attribute values (")
5648 << count <<
") to be equal to the rank (" << rank
5649 <<
") for the result shape type";
5654LogicalResult tosa::DimOp::verify() {
5655 const tosa::shapeType outShapeType =
5656 cast<tosa::shapeType>(getResult().
getType());
5657 if (outShapeType.getRank() != 1)
5658 return emitOpError(
"expect output shape type to contain one element, got ")
5663 const int64_t inputRank = inputType.getRank();
5664 const int64_t axis = getAxisAttr().getInt();
5665 if (axis < 0 || axis >= inputRank)
5666 return emitOpError(
"expect axis to be in the range [0, ")
5667 << inputRank <<
"), got " << axis;
5672LogicalResult tosa::ConcatShapeOp::verify() {
5673 const tosa::shapeType outShapeType =
5674 cast<tosa::shapeType>(getResult().
getType());
5675 const int64_t outputRank = outShapeType.getRank();
5678 if (inputList.size() == 0)
5679 return emitOpError(
"requires at least one input shape");
5681 if (llvm::any_of(inputList, [](Value v) {
5682 return cast<tosa::shapeType>(v.
getType()).getRank() == 0;
5684 return emitOpError(
"requires all inputs shapes have a rank greater than 0");
5686 const int64_t inputsRank =
5687 llvm::accumulate(inputList, 0, [](int64_t acc,
const Value &input) {
5688 const tosa::shapeType inShapeType =
5689 cast<tosa::shapeType>(input.
getType());
5690 return acc + inShapeType.getRank();
5692 if (outputRank != inputsRank)
5693 return emitOpError(
"requires output shape rank to be equal to the sum of "
5694 "the input shape ranks (")
5695 << inputsRank <<
"), got " << outputRank;
5700LogicalResult tosa::SliceShapeOp::verify() {
5701 std::optional<int32_t> start;
5702 DenseIntElementsAttr startAttr;
5704 start = startAttr.getValues<int32_t>()[0];
5705 if (start && start.value() < 0)
5706 return emitOpError(
"expected non-negative start index, got ")
5709 std::optional<int32_t> size;
5710 DenseIntElementsAttr sizeAttr;
5712 size = sizeAttr.getValues<int32_t>()[0];
5713 if (size && size.value() <= 0)
5714 return emitOpError(
"expected positive size, got ") << size.value();
5719 const tosa::shapeType outShapeType =
5720 cast<tosa::shapeType>(getResult().
getType());
5721 const int64_t outputRank = outShapeType.getRank();
5722 if (outputRank != size)
5724 "expected output type size to be equal to size attribute, got ")
5725 << outputRank <<
" vs " << size.value();
5730 const tosa::shapeType inShapeType =
5731 cast<tosa::shapeType>(getInput().
getType());
5732 const int64_t inputRank = inShapeType.getRank();
5733 const int64_t sliceSize = start.value() + size.value();
5734 if (sliceSize > inputRank)
5735 return emitOpError(
"expected start + size to be less than or equal to "
5736 "input shape rank (")
5737 << inputRank <<
"), got " << sliceSize;
5746#define GET_ATTRDEF_CLASSES
5747#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5752#define GET_TYPEDEF_CLASSES
5753#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5759#define GET_OP_CLASSES
5760#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
true
Given two iterators into the same block, return "true" if a is before `b.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void printShapeToDiagnostic(InFlightDiagnostic &diag, ArrayRef< int64_t > shape)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildAvgPool2dAdaptiveOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad, TypeAttr accType)
This builder mirrors avg_pool2d quant-info handling and materializes kernel/stride/pad as const_shape...
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static LogicalResult verifyPoolingOpImpl(Operation *op, ArrayRef< int64_t > kernel, ArrayRef< int64_t > strides, ArrayRef< int64_t > padding, Value input, Value output)
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void extractAdaptivePoolingConstShapeOperands(T op, AdaptivePoolingConstShapeValues &values)
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static FailureOr< int64_t > getConstantScalarIntValue(Value val)
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static void updateIfDynamic(int64_t &current, int64_t candidate)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
int64_t getOutputRank() const
int64_t getNumSpatialDims() const
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
int64_t getNumSpatialDims() const
int64_t getOutputRank() const
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
int64_t getNumSpatialDims() const
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
int64_t getOutputRank() const
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Operation is the basic unit of execution within MLIR.
ResultRange result_range
Support result iteration.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ArrayRef< T > asArrayRef() const
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)