#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
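// Inliner hooks for the TOSA dialect: any operation may be inlined, and a
// region may be inlined when the destination region belongs to a
// tosa.cond_if or tosa.while_loop, as implemented below.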
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
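// Bytecode interface for the TOSA dialect. Attribute and type
// (de)serialization is delegated to the helpers generated into
// TosaDialectBytecode.cpp.inc; dialect versioning is not supported, so
// readVersion simply reports an error.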
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {}

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};
  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));

  Type elementType = variableOp.getType();

  return RankedTensorType::get(shape, elementType);
void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
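// Constant materialization: TOSA folds produce either tosa.const_shape (for
// !tosa.shape values backed by a DenseIntElementsAttr) or tosa.const (for any
// ElementsAttr); other attribute kinds are rejected by returning nullptr.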
Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
  return nullptr;
}
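// The helpers below implement the custom assembly for tosa.variable's
// type/initial-value clause. Illustrative forms only (a sketch; the exact
// syntax is defined by the op's assembly format):
//   tosa.variable @v = dense<0.0> : tensor<4xf32>   // typed initial value
//   tosa.variable @v : tensor<4xf32>                // declared type only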
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   DenseElementsAttr &varShapeAttr,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
    typeAttr = TypeAttr::get(elementType);
    return success();
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

    return parser.emitError(parser.getCurrentLocation())
           << "expected attribute";

  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                  typeAttr);
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected Typed attr";

  initialValueAttr = nullptr;

    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";

  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}
                     TypeAttr typeAttr, Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {

    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
    auto tensorTypeAttr = TypeAttr::get(tensorType);
  }

  if (initialValueAttr) {
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &attrs) {
  llvm::StringRef name;

  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
  }

  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
  }

  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
      auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
  }

  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
      auto sym = symbolizeBlockSize(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid block_size value: " << kw;
      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
  }
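// With this helper, enum-valued attributes are parsed as bare keywords in the
// custom assembly rather than as quoted strings, e.g. (illustrative only):
//   rounding_mode = SINGLE_ROUND, mode = NEAREST_NEIGHBOR, nan_mode = IGNORE.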
template <typename EnumType>
ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {

          [&]() { return parser.parseOperand(operands.emplace_back()); }))

    if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))

  result.addTypes(fnTy.getResults());
  result.addAttributes(attrs);
static void printNamedAttr(OpAsmPrinter &parser,
                           const NamedAttribute namedAttr) {
  parser << namedAttr.getName().strref() << " = ";
  Attribute attr = namedAttr.getValue();
  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();
  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
    parser << blockSizeAttr.getValue();
  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;

  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
    if (nanAttr.getValue() == kDefaultNanValue) {
      toPrint.erase(attr.getName());
    }
  }

  if (!toPrint.empty()) {
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);
    });
  }

  llvm::interleaveComma(op->getAttrs(), parser,
                        [&](const NamedAttribute namedAttr) {
                          printNamedAttr(parser, namedAttr);
                        });
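// When printing, a nan_mode attribute equal to the default PROPAGATE value is
// elided so that round-tripped IR stays minimal.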
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
                                        OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

  printWithEnumHandling(parser, *this);

ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

  printWithEnumHandling(parser, *this);

ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

  printWithEnumHandling(parser, *this);
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))

                                      Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
        << "expected " << name << " and " << name
        << "_zp to both be integer of the same bitwidth, but got " << eType
        << " vs. " << eZpType;

                                 Value src, int32_t val) {
  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
  if (dyn_cast<tosa::mxint8Type>(type))

                               const StringRef operandName,
                               const StringRef dimName) {
  if (ShapedType::isDynamic(currDim)) {

  } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
        << dimName << " of " << operandName << " to match size " << currDim
        << ", got " << newDim;
  }
                                 const int64_t stride, const int64_t dilation,
                                 const llvm::StringRef dimName,
                                 const llvm::StringRef dimAxis,
                                 const llvm::StringRef padBeforeName,
                                 const llvm::StringRef padAfterName) {
  if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
    return success();

  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      stride);
  if (!calculatedOutSizeMinusOne.has_value())
    return op.emitOpError("expected input_")
           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
           << dimAxis << " to be wholly divisible by stride_" << dimAxis
           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
           << ") / " << stride;

  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
    return op.emitOpError("calculated output ")
           << dimName << " did not match expected: "
           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;

  return success();
}
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    return op.emitOpError(
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      return op.emitOpError(
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
    return op.emitOpError(
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;
  }

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}
LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
           << accType;

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
           << accType;

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
      !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
           << accType;

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError(
               "accumulator type for f16 tensor is not f16/f32, got ")
           << accType;

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
           << accType;

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
           << accType;

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
              op, inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

              op, inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

              op, inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
                                        StringRef name1, Type type2,
                                        StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return success();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

    return op->emitOpError()
           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

      op->template getParentWithTrait<OpTrait::SymbolTable>();

  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

    return op->emitOpError("'")
           << op.getName() << "' has not been declared by 'tosa.variable'";
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
                                            StringRef aName = "input",
                                            StringRef bName = "output") {
  auto aTType = llvm::dyn_cast<TensorType>(aType);
  auto bTType = llvm::dyn_cast<TensorType>(bType);
  if (!aTType) {
    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
    return failure();
  }
  if (!bTType) {
    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
    return failure();
  }

  auto aElementType = aTType.getElementType();
  auto bElementType = bTType.getElementType();
  auto aQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
  auto bQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
      aElementType != bElementType) {
    op.emitOpError("expect ")
        << aName << " and " << bName << " to have same element type, got "
        << aElementType << " and " << bElementType;
    return failure();
  }

  return success();
}
LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
        << expectedOutputShape << "', got '" << outputShape << "'";
1079template <
typename T>
1082 if (llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1083 return op.emitOpError(
"expect all kernel values to be >= 1, got ")
1087 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1088 return op.emitOpError(
"expect all stride values to be >= 1, got ")
1092 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1093 return op.emitOpError(
"expect all padding values to be >= 0, got ")
1097 const int64_t kernelX = kernel[1];
1098 const int64_t padLeft = padding[2];
1099 const int64_t padRight = padding[3];
1100 if (padRight >= kernelX || padLeft >= kernelX)
1101 return op.emitOpError(
"expected left/right padding to be less than the "
1102 "width of the kernel, got pad_left=")
1103 << padLeft <<
", pad_right=" << padRight <<
", kernel_x=" << kernelX;
1105 const int64_t kernelY = kernel[0];
1106 const int64_t padTop = padding[0];
1107 const int64_t padBottom = padding[1];
1108 if (padTop >= kernelY || padBottom >= kernelY)
1109 return op.emitOpError(
"expected top/bottom padding to be less than the "
1110 "height of the kernel, got pad_top=")
1111 << padTop <<
", pad_bottom=" << padBottom
1112 <<
", kernel_y=" << kernelY;
1114 const auto inputType =
1115 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1116 const auto outputType =
1117 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1118 if (!inputType || !outputType)
1121 const auto verifyOutputSize =
1125 const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
1126 const llvm::StringRef padBeforeName,
1127 const llvm::StringRef padAfterName) -> LogicalResult {
1128 if (ShapedType::isDynamic(inputSize))
1131 const std::optional<int64_t> calculatedOutSizeMinusOne =
1132 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1133 if (!calculatedOutSizeMinusOne.has_value())
1134 return op.emitOpError(
"expected input_")
1135 << dimName <<
" + pad_" << padBeforeName <<
" + pad_"
1136 << padAfterName <<
" - kernel_" << dimAxis
1137 <<
" to be wholly divisible by stride_" << dimAxis <<
", got ("
1138 << inputSize <<
" + " << padBefore <<
" + " << padAfter <<
" - "
1139 << kernelSize <<
") / " << strideSize;
1141 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1142 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1143 return op.emitOpError(
"calculated output ")
1144 << dimName <<
" did not match expected: " <<
"calculated="
1145 << calculatedOutSize <<
", expected=" << outputSize;
1150 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1151 kernel[0], strides[0], padding[0], padding[1],
1152 "height",
"y",
"top",
"bottom")))
1155 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1156 kernel[1], strides[1], padding[2], padding[3],
1157 "width",
"x",
"left",
"right")))
1163LogicalResult tosa::AvgPool2dOp::verify() {
1172 auto accType = getAccType();
1173 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1174 return emitOpError(
"accumulator type for integer tensor is not i32");
1176 if (inputETy.
isF16() && !(accType.isF16() || accType.isF32()))
1177 return emitOpError(
"accumulator type for f16 tensor is not f16/f32");
1179 if (inputETy.
isBF16() && !accType.isF32())
1180 return emitOpError(
"accumulator type for bf16 tensor is not f32");
1182 if (inputETy.
isF32() && !accType.isF32())
1183 return emitOpError(
"accumulator type for f32 tensor is not f32");
1185 if (inputETy != inputZpETy)
1186 return emitOpError(
"expect both input and its zero point are the same "
1187 "element type, got ")
1188 << inputETy <<
" and " << inputZpETy;
1190 if (resultETy != outputZpETy)
1191 return emitOpError(
"expect both output and its zero point are the same "
1192 "element type, got ")
1193 << resultETy <<
" and " << outputZpETy;
1195 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1196 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1199 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1200 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    return success();
  }

  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
  if (!floatMaxValAttr || !floatMinValAttr ||
      (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
      (floatMaxValAttr.getType() != inputETy))
    return emitOpError("min/max attributes types are incompatible with "
                       "input/output element types.");

  const APFloat minVal = floatMinValAttr.getValue();
  const APFloat maxVal = floatMaxValAttr.getValue();
  if (minVal.isNaN() || maxVal.isNaN())
    return emitOpError("min/max attributes should not be 'NaN', got min_val=")
           << minValAttr << ", max_val=" << maxValAttr;

  if (maxVal < minVal)
    return emitOpError("expected min_val <= max_val, got min_val=")
           << minValAttr << ", max_val=" << maxValAttr;

  return success();
}
1286 result.addOperands({input, weight, bias, zps.first, zps.second});
1287 result.addAttribute(
"pad", pad);
1288 result.addAttribute(
"stride", stride);
1289 result.addAttribute(
"dilation", dilation);
1290 result.addAttribute(
"acc_type", accType);
1291 Type finalOutputType = outputType;
1297 result.addTypes(finalOutputType);
1308 result.addOperands({input, weight, bias, zps.first, zps.second});
1309 result.addAttribute(
"out_pad", outpad);
1310 result.addAttribute(
"stride", stride);
1311 result.addAttribute(
"acc_type", accType);
1312 Type finalOutputType = outputType;
1318 result.addTypes(finalOutputType);
1329 result.addOperands({a,
b, zps.first, zps.second});
1331 Type finalOutputType{outputType};
1334 auto inputBits = eType.getIntOrFloatBitWidth();
1336 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1337 assert(outputShapedType &&
"Output must be a shaped type");
1339 IntegerType accElementType;
1340 if (inputBits == 16)
1345 finalOutputType = outputShapedType.clone(accElementType);
1347 result.addTypes(finalOutputType);
1356 DenseArrayAttr kernel, DenseArrayAttr stride,
1357 DenseArrayAttr pad, TypeAttr accType) {
1362 if (
auto quantAttr =
1364 inputZp = quantAttr.getInputZp();
1365 outputZp = quantAttr.getOutputZp();
1367 const std::optional<Value> inputZpOp =
1372 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1374 const std::optional<Value> outputZpOp =
1377 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1378 "quantized AVG_POOL2D op");
1381 if (inputZpOp && outputZpOp) {
1382 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1387 result.addOperands({input});
1389 result.addAttribute(
"kernel", kernel);
1390 result.addAttribute(
"stride", stride);
1391 result.addAttribute(
"pad", pad);
1392 result.addAttribute(
"acc_type", accType);
1393 result.types.push_back(outputType);
1407 input1Zp = quantAttr.getInputZp();
1408 outputZp = quantAttr.getOutputZp();
1410 const std::optional<Value> input1ZpOp =
1414 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1417 const std::optional<Value> outputZpOp =
1421 loc,
"Failed to create output zero point for quantized NEGATE op");
1424 if (input1ZpOp && outputZpOp) {
1425 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1430 result.addOperands({input});
1433 result.types.push_back(outputType);
1446 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1449 result.addOperands({input, paddings, padConstOp});
1450 result.types.push_back(outputType);
1454 StringRef name,
Type variableType,
1459 auto shapedType = dyn_cast<ShapedType>(variableType);
1461 (
void)
emitError(loc,
"variable type must be a shaped type");
1464 if (!shapedType.hasRank()) {
1465 (
void)
emitError(loc,
"variable type must be a ranked type");
1469 auto elementType = shapedType.getElementType();
1470 auto elementTypeAttr = TypeAttr::get(elementType);
1474 result.addAttribute(
"sym_name", nameAttr);
1475 result.addAttribute(
"var_shape", varShapeAttr);
1476 result.addAttribute(
"type", elementTypeAttr);
1477 result.addAttribute(
"initial_value", initialValue);
1487 for (
int i = 0, e = operands.size(); i != e; ++i) {
1489 if (!
shape.hasRank()) {
1494 outRank = std::max<int64_t>(outRank,
shape.getRank());
1497 outShape.resize(outRank, 1);
1499 for (
int i = 0, e = operands.size(); i != e; ++i) {
1501 auto rankDiff = outShape.size() -
shape.getRank();
1503 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1504 auto dim1 = outShape[i + rankDiff];
1505 auto dim2 =
shape.getDimSize(i);
1506 auto resolvedDim = dim1;
1510 }
else if (dim2 == 1) {
1512 }
else if (dim1 != dim2) {
1515 outShape[i + rankDiff] = resolvedDim;
1522LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1523 MLIRContext *context, ::std::optional<Location> location,
1524 ArgMaxOp::Adaptor adaptor,
1527 IntegerAttr axis = adaptor.getProperties().axis;
1528 int32_t axisVal = axis.getValue().getSExtValue();
1530 if (!inputShape.hasRank()) {
1536 outShape.reserve(inputShape.getRank() - 1);
1537 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
1540 outShape.push_back(inputShape.getDimSize(i));
1547LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1548 MLIRContext *context, ::std::optional<Location> location,
1549 RFFT2dOp::Adaptor adaptor,
1551 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1553 if (!inputShape.hasRank())
1557 outputShape.resize(3, ShapedType::kDynamic);
1558 outputShape[0] = inputShape.getDimSize(0);
1559 outputShape[1] = inputShape.getDimSize(1);
1560 int64_t inWidth = inputShape.getDimSize(2);
1564 if (inWidth != ShapedType::kDynamic)
1565 outputShape[2] = inWidth / 2 + 1;
1574 const llvm::StringRef dimName) {
1575 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1578 << dimName <<
" to be a power of two, got " << dimSize;
1583LogicalResult tosa::RFFT2dOp::verify() {
1584 const auto outputTypes = getResultTypes();
1586 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
1588 const auto inputType =
1589 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1593 const int64_t height = inputType.getDimSize(1);
1594 if (ShapedType::isStatic(height) &&
1598 const int64_t width = inputType.getDimSize(2);
1599 if (ShapedType::isStatic(width) &&
1603 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1609 outputType.getShape().drop_back())))
1610 return emitOpError(
"expected batch and height dimensions of input/output "
1611 "to match, got input=")
1612 << inputType <<
" output=" << outputType;
1615 const int64_t outputWidth = outputType.getDimSize(2);
1616 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1617 (outputWidth != (width / 2) + 1))
1619 "expected output width to be equal to input_width / 2 + 1, got ")
1625LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1626 MLIRContext *context, ::std::optional<Location> location,
1627 FFT2dOp::Adaptor adaptor,
1629 inferredReturnShapes.push_back(
1631 inferredReturnShapes.push_back(
1636LogicalResult tosa::FFT2dOp::verify() {
1637 const auto inputRealType =
1638 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1639 const auto inputImagType =
1640 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1641 if (!inputRealType || !inputImagType)
1644 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1645 return ShapedType::isDynamic(a) ? a :
b;
1648 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1649 inputImagType.getDimSize(1));
1650 if (ShapedType::isStatic(height) &&
1654 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1655 inputImagType.getDimSize(2));
1656 if (ShapedType::isStatic(width) &&
1663LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1664 MLIRContext *context, ::std::optional<Location> location,
1665 ConcatOp::Adaptor adaptor,
1668 const Properties &prop = adaptor.getProperties();
1669 int32_t axis = prop.axis.getValue().getSExtValue();
1671 bool hasRankedInput =
false;
1672 for (
auto operand : adaptor.getOperands()) {
1674 if (!operandShape.hasRank())
1678 if (!hasRankedInput)
1679 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1682 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
1683 if (i == axis || operandShape.isDynamicDim(i))
1685 if (outputShape[i] == ShapedType::kDynamic)
1686 outputShape[i] = operandShape.getDimSize(i);
1687 if (outputShape[i] != operandShape.getDimSize(i))
1689 "Cannot concat tensors with different sizes"
1690 " on the non-axis dimension ",
1694 hasRankedInput =
true;
1697 if (adaptor.getInput1().empty())
1701 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1702 if (!hasRankedInput) {
1709 for (
auto operand : adaptor.getOperands()) {
1714 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1715 concatDimSize = ShapedType::kDynamic;
1719 concatDimSize += operandShape.getDimSize(axis);
1722 outputShape[axis] = concatDimSize;
1728LogicalResult tosa::ConcatOp::verify() {
1730 auto outType = getOutput().getType();
1734 if (inputList.empty())
1737 if (!llvm::all_of(inputList, [&](
auto input) {
1739 *
this, input.getType(), outType));
1744 const int32_t axis = getAxis();
1746 for (
const auto &input : inputList) {
1747 const Type inputType = input.getType();
1749 if (currShape.hasRank()) {
1750 firstRankedInputShape = currShape;
1752 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1753 return emitOpError(
"expect axis to be within range 0 < axis < "
1754 "rank(input1[firstRankedTensorIdx]), got ")
1760 const auto allOperandsHasRank = [](
const Value input) {
1763 if (llvm::all_of(inputList, allOperandsHasRank)) {
1766 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1768 const int64_t inputRank = inputShape.getRank();
1769 const size_t operandNum =
index + 1;
1772 if (inputRank != firstInputRank)
1774 "expect all operands to have the same rank, but got ")
1775 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1779 for (
int i = 0; i < inputRank; i++) {
1780 const int64_t inputDim = inputShape.getDimSize(i);
1782 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1783 inputShape.isDynamicDim(i))
1785 if (inputDim != firstInputDim)
1786 return emitOpError(
"expect all operand shapes to have the same sizes "
1787 "on non-axis dimensions, but got ")
1788 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1789 <<
" on operands 0 and " << operandNum;
1794 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1795 return emitOpError(
"expect output rank to match inputs rank, got ")
1796 << outputShape.getRank() <<
" vs " << firstInputRank;
1800 for (
const auto &input : inputList) {
1802 if (inputShape.isDynamicDim(axis)) {
1807 axisSum += inputShape.getDimSize(axis);
1810 if (axisSum >= 0 && outputShape.hasRank() &&
1811 !outputShape.isDynamicDim(axis) &&
1812 axisSum != outputShape.getDimSize(axis))
1813 return emitOpError(
"requires sum of axis dimensions of input1 "
1814 "equal to output axis dimension, got ")
1815 << axisSum <<
" and " << outputShape.getDimSize(axis);
1821LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1822 MLIRContext *context, ::std::optional<Location> location,
1826 auto elementType = IntegerType::get(context, 1);
1839 if (l.size() != r.size() || l.size() != 1)
1844LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1845 MLIRContext *context, ::std::optional<Location> location,
1846 MatMulOp::Adaptor adaptor,
1853 outShape.resize(3, ShapedType::kDynamic);
1855 if (lhsShape.hasRank()) {
1856 outShape[0] = lhsShape.getDimSize(0);
1857 outShape[1] = lhsShape.getDimSize(1);
1860 if (rhsShape.hasRank()) {
1861 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1863 outShape[2] = rhsShape.getDimSize(2);
LogicalResult MatMulOp::verify() {
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();
  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    }
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }
  }

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;
  }

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;
  }

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();

  return success();
}
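// Note: a_zp/b_zp must match the element types of a and b, and when the
// operands are quantized both must use the same storage bit width.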
1934LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1935 MLIRContext *context, ::std::optional<Location> location,
1936 MatmulTBlockScaledOp::Adaptor adaptor,
1940 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1941 if (aDataShape.hasRank()) {
1942 outShape[0] = aDataShape.getDimSize(0);
1943 outShape[1] = aDataShape.getDimSize(1);
1946 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1947 if (aScaleShape.hasRank()) {
1948 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1950 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1955 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1956 if (bDataShape.hasRank()) {
1957 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1958 if (bDataBatchSize != 1)
1960 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1961 outShape[2] = bDataShape.getDimSize(1);
1964 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1965 if (bScaleShape.hasRank()) {
1966 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1967 if (bScaleBatchSize != 1)
1969 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1970 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1978LogicalResult MatmulTBlockScaledOp::verify() {
1980 const Type aDataType = getAData().getType();
1981 const Type bDataType = getBData().getType();
1987 int64_t N = ShapedType::kDynamic;
1988 int64_t D = ShapedType::kDynamic;
1989 int64_t H = ShapedType::kDynamic;
1992 int64_t multiplesOfC = ShapedType::kDynamic;
2004 "a_scale",
"batch")) ||
2006 "a_scale",
"height")))
2014 "b_data",
"batch")) ||
2016 "b_data",
"channels")))
2024 "b_scale",
"batch")) ||
2026 "b_scale",
"width")) ||
2034 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2035 return emitOpError(
"expect B matrix batch size to be broadcast compatible "
2037 << D <<
" vs N=" << N;
2040 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2041 if (ShapedType::isStatic(C) && C % blockSize != 0)
2042 return emitOpError(
"expect C to be a multiple of block size, got C=")
2043 <<
C <<
", block_size=" << blockSize;
2046 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2047 multiplesOfC != C / blockSize)
2049 "expect scale operands dimension 2 to equal C/block_size (")
2050 <<
C <<
"/" << blockSize <<
")" <<
", got " << multiplesOfC;
2053 N = ShapedType::isDynamic(N) ? D : N;
2055 const auto outputType = cast<ShapedType>(getResult().
getType());
2056 if (outputType.hasRank() &&
2060 auto stringifyDim = [&](
int64_t d) {
2061 if (ShapedType::isDynamic(d))
2066 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2067 opError <<
" to be compatible with expected output shape ";
2068 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2075LogicalResult tosa::PadOp::inferReturnTypeComponents(
2076 MLIRContext *context, ::std::optional<Location> location,
2077 PadOp::Adaptor adaptor,
2079 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2081 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2086 if (!inputShape.hasRank()) {
2087 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2096 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2101 outputShape.reserve(inputShape.getRank());
2102 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2103 if (inputShape.isDynamicDim(i)) {
2104 outputShape.push_back(ShapedType::kDynamic);
2107 auto padFront = paddingValues[i * 2];
2108 auto padBack = paddingValues[i * 2 + 1];
2109 if (padFront < 0 || padBack < 0) {
2111 outputShape.push_back(ShapedType::kDynamic);
2115 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
LogicalResult tosa::PadOp::verify() {

  if (auto padConst = getPadConst()) {
  }

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)
    return success();

  auto inputRank = inputType.getRank();

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  }

  return success();
}
2191LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2192 MLIRContext *context, ::std::optional<Location> location,
2193 SliceOp::Adaptor adaptor,
2202 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2210 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2213 if (inputShape.hasRank()) {
2214 for (
size_t i = 0; i < size.size(); i++) {
2215 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2216 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2217 start[i] < inputShape.getDimSize(i))) {
2219 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2222 outputShape[i] = size[i];
2226 if (size[i] == -1) {
2227 outputShape[i] = inputShape.getDimSize(i) - start[i];
2228 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2230 outputShape[i] = size[i];
2242LogicalResult tosa::SliceOp::verify() {
2249 if (inputShape.hasRank()) {
2250 const auto inputRank = inputShape.getRank();
2252 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2254 "expect input1 and output to have the same ranks, got ")
2255 << inputRank <<
" and " << outputShape.getRank();
2257 const auto startShapeRank =
2258 llvm::cast<tosa::shapeType>(getStart().
getType()).getRank();
2259 if (inputRank != startShapeRank)
2260 return emitOpError(
"length of start is not equal to rank of input shape");
2262 const auto sizeShapeRank =
2263 llvm::cast<tosa::shapeType>(getSize().
getType()).getRank();
2264 if (inputRank != sizeShapeRank)
2265 return emitOpError(
"length of size is not equal to rank of input shape");
2271LogicalResult tosa::MulOp::inferReturnTypeComponents(
2272 MLIRContext *context, ::std::optional<Location> location,
2287LogicalResult tosa::MulOp::verify() {
2288 const Value output = getOutput();
2293 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2294 IntegerType lhsIntType =
2296 IntegerType rhsIntType =
2298 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2299 return emitOpError(
"requires the same element type for all operands");
2304 if (lhsIntType.getWidth() > resIntType.getWidth())
2305 return emitOpError(
"invalid data type size for operands or result");
2310 for (
int i = 0; i < 2; ++i) {
2313 "requires the same element type for all operands and results");
2317 ElementsAttr shiftElem;
2319 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2321 return emitOpError() <<
"require shift to be 0 for float type";
2329 TypeRange operandTypes = getOperandTypes();
2330 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2331 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2333 const bool aHasRank = aType.hasRank();
2334 const bool bHasRank = bType.hasRank();
2335 if (aHasRank && bHasRank) {
2336 const int64_t aRank = aType.getRank();
2337 const int64_t bRank = bType.getRank();
2339 return emitOpError(
"a and b operands don't have matching ranks, got ")
2340 << aRank <<
" and " << bRank;
2345 aType.getShape(), bType.getShape(), resultShape))
2346 return emitOpError(
"a and b operands don't have broadcast-compatible "
2348 << aType <<
" and " << bType;
2351 ShapedType resultType = cast<ShapedType>(output.
getType());
2352 if (!resultType.hasRank())
2355 const int64_t resultRank = resultType.getRank();
2356 if (aHasRank && resultRank != aType.getRank())
2357 return emitOpError(
"result type has different rank than a, got ")
2358 << resultRank <<
" vs " << aType.getRank();
2359 if (bHasRank && resultRank != bType.getRank())
2360 return emitOpError(
"result type has different rank than b, got ")
2361 << resultRank <<
" vs " << bType.getRank();
2366LogicalResult tosa::TableOp::inferReturnTypeComponents(
2367 MLIRContext *context, ::std::optional<Location> location,
2368 TableOp::Adaptor adaptor,
2370 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2372 if (!inputShape.hasRank()) {
2377 inferredReturnShapes.resize(1);
2378 inputShape.getDims(inferredReturnShapes[0]);
2382LogicalResult tosa::TableOp::verify() {
2383 const TensorType inputType = getInput1().getType();
2384 const TensorType outputType = getOutput().getType();
2393 auto inputDims = inputType.
getShape();
2394 auto outputDims = outputType.
getShape();
2395 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2397 auto [inputDim, outputDim] = it.value();
2398 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2399 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2400 <<
" doesn't match dim(input, " << dim
2401 <<
") = " << inputDim;
2413 multiples = llvm::to_vector(
2414 llvm::map_range(multiplesAttr.getValues<APInt>(),
2415 [](
const APInt &val) { return val.getSExtValue(); }));
2419LogicalResult tosa::TileOp::inferReturnTypeComponents(
2420 MLIRContext *context, ::std::optional<Location> location,
2421 TileOp::Adaptor adaptor,
2428 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2436 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2438 if (!inputShape.hasRank()) {
2439 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2440 inferredReturnShapes.push_back(
2443 }
else if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2447 outputShape.reserve(multiples.size());
2448 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2449 if (multiples[i] == ShapedType::kDynamic) {
2450 outputShape.push_back(ShapedType::kDynamic);
2452 int64_t dim = inputShape.getDimSize(i);
2453 if (dim != ShapedType::kDynamic)
2454 dim *= multiples[i];
2455 outputShape.push_back(dim);
2463LogicalResult tosa::TileOp::verify() {
2469 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2470 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2472 shapeType multiplesType =
2473 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2475 auto multiplesRank = multiplesType.getRank();
2477 if (inputType.hasRank()) {
2478 if (inputType.getRank() != multiplesRank)
2479 return emitOpError(
"expect 'multiples' to have rank ")
2480 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2481 if (outputType.hasRank() &&
2485 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2486 return emitOpError(
"expect 'multiples' array to have length ")
2487 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
2490 if (getConstantMultiples(multiples).succeeded() &&
2491 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2493 "expect element of 'multiples' to be positive integer or -1.");
2499 if (l.size() != r.size() || l.size() != 1)
2504LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2505 MLIRContext *context, ::std::optional<Location> location,
2506 ReshapeOp::Adaptor adaptor,
2508 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2513 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2523 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2524 inferredReturnShapes.push_back(
2532 int64_t numElements = inputShape.getNumElements();
2534 for (
auto val : newShapeValue) {
2535 if (ShapedType::isStatic(val)) {
2541 for (
auto &val : newShapeValue) {
2542 if (ShapedType::isDynamic(val))
2543 val = numElements / staticMul;
2546 inferredReturnShapes.push_back(
llvm::LogicalResult tosa::ReshapeOp::verify() {

  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "
                           << newShapeDim;
  }

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;
      }
    }

          return (dim > 0) ? acc * dim : acc;

    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
    }
  }

  return mlir::success();
}
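// The element-count checks above keep reshape size-preserving: with a fully
// static new shape the product of its dimensions must equal the static input
// element count, and with a -1 placeholder the known dimensions may not
// already exceed it.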
2617 ElementsAttr zpAttr;
2622 Type zpElemType = zpAttr.getElementType();
2624 if (llvm::isa<FloatType>(zpElemType)) {
2625 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2632 if (llvm::isa<IntegerType>(zpElemType)) {
2634 return zpAttr.getValues<APInt>()[0].getSExtValue();
2636 return zpAttr.getValues<APInt>()[0].getZExtValue();
2643template <
typename T>
2645 const std::string &operand) {
2648 if (!zpElemType.
isInteger(8) && zp != 0) {
2650 std::string lower = operand;
2651 llvm::transform(lower, lower.begin(), ::tolower);
2652 return op.emitOpError()
2653 << lower <<
" zero point must be zero for non-int8 integer types";
2661 const std::string &operand) {
2662 bool isInputZp = (operand ==
"Input");
2664 bool tensorUnsigned =
2665 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2666 StringRef tensorName = isInputZp ?
"input" :
"output";
2672 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
2673 return op.emitOpError()
2674 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
2676 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
2677 return op.emitOpError() <<
"expect " << tensorName
2678 <<
"_zp of 0 or 32768 for unsigned int16 "
2679 << tensorName <<
", got " << zp;
2686#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2687 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2688 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2690 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2691 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2710#undef ZERO_POINT_HELPER
2712LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2713 MLIRContext *context, ::std::optional<Location> location,
2714 TransposeOp::Adaptor adaptor,
2716 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2725 const auto inputRank = inputShape.
getRank();
2729 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
2735 if (inputRank == 0) {
2741 bool allTheSame =
true;
2742 for (
int i = 1, s = inputRank; i < s; i++) {
2752 outputShape.resize(inputRank, inputShape.
getDimSize(0));
2757 outputShape.resize(inputRank, ShapedType::kDynamic);
2760 if (llvm::any_of(adaptor.getPerms(),
2761 [inputRank](
const auto i) { return i >= inputRank; }))
2764 outputShape.reserve(inputRank);
2765 for (
int i = 0, s = inputRank; i < s; i++) {
2766 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
LogicalResult tosa::TransposeOp::verify() {

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::to_vector(llvm::map_range(
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                            "of elements, got "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }

  return success();
}
LogicalResult TransposeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  const llvm::ArrayRef<int32_t> transposePerms = getPerms();

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
      returnedDims[dim] =
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
              .getResult();
    else
      returnedDims[dim] =
          builder.getIndexAttr(inputType.getDimSize(dimInInput));
  }

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
  return success();
}
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
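// Shape relationship illustrated with example values (illustrative only):
// values is [N, K, C] and indices is [N, W], so the inferred output is
// [N, W, C]; e.g. values tensor<2x8x4xf32> with indices tensor<2x5xi32>
// gives an output of tensor<2x5x4xf32>.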
LogicalResult tosa::GatherOp::verify() {
  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  int64_t n = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    n = valuesShape.getDimSize(0);
    c = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << n
                           << ", got " << outputN;
    if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        w != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << w
                           << ", got " << outputW;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << c
                           << ", got " << outputC;
  }

  return success();
}
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic)) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(), scaleInt) ||
      !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(), offsetInt) ||
      !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(), borderInt)) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Output size: ((in - 1) * scale_n - offset + border) / scale_d + 1.
  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) + 1;

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) + 1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ", outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
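// Worked example of the arithmetic above (illustrative values): with
// inputHeight = 8, scale = [2, 1, 2, 1], offset = [0, 0] and border = [0, 0],
// outputHeight = ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15 (and likewise for width).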
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  SmallVector<int64_t> scaleValues;
  SmallVector<int64_t> offsetValues;
  SmallVector<int64_t> borderValues;
  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
      !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
      !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
    // Skip the following checks if the values are not constant.
    return success();
  }

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues;

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  if (!inputType || !outputType)
    return success();

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
3078LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3079 MLIRContext *context, ::std::optional<Location> location,
3080 ScatterOp::Adaptor adaptor,
3081 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3082 llvm::SmallVector<int64_t> outputShape;
3083 outputShape.resize(3, ShapedType::kDynamic);
3085 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3086 if (valuesInShape.hasRank()) {
3087 outputShape[0] = valuesInShape.getDimSize(0);
3088 outputShape[1] = valuesInShape.getDimSize(1);
3089 outputShape[2] = valuesInShape.getDimSize(2);
3092 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3093 if (indicesShape.hasRank()) {
3094 if (outputShape[0] == ShapedType::kDynamic)
3095 outputShape[0] = indicesShape.getDimSize(0);
3098 ShapeAdaptor inputShape(adaptor.getInput().getType());
3099 if (inputShape.hasRank()) {
3100 if (outputShape[0] == ShapedType::kDynamic)
3101 outputShape[0] = inputShape.getDimSize(0);
3102 if (outputShape[2] == ShapedType::kDynamic)
3103 outputShape[2] = inputShape.getDimSize(2);
3106 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
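// Illustrative shapes for the inference above: values_in [N, K, C], indices
// [N, W] and input [N, W, C] produce values_out [N, K, C]; e.g. values_in
// tensor<1x10x4xf32>, indices tensor<1x6xi32> and input tensor<1x6x4xf32>
// infer values_out tensor<1x10x4xf32> (the verifier below also requires
// K >= W, here 10 >= 6).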
LogicalResult tosa::ScatterOp::verify() {
  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  int64_t n = ShapedType::kDynamic;
  int64_t k = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    n = valuesInShape.getDimSize(0);
    k = valuesInShape.getDimSize(1);
    c = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (n == ShapedType::kDynamic)
      n = inputN;
    else if (inputN != ShapedType::kDynamic && n != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << n
                           << ", got " << inputN;
    if (w == ShapedType::kDynamic)
      w = inputW;
    else if (inputW != ShapedType::kDynamic && w != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << w
                           << ", got " << inputW;
    if (c == ShapedType::kDynamic)
      c = inputC;
    else if (inputC != ShapedType::kDynamic && c != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << c
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << n << ", got " << outputN;
    if (k == ShapedType::kDynamic)
      k = outputK;
    else if (outputK != ShapedType::kDynamic && k != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << k << ", got " << outputK;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << c << ", got " << outputC;
  }
  if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
    return emitOpError() << "requires dimensions K >= W, got K=" << k
                         << " and W=" << w;

  return success();
}
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}

#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER

#undef COMPATIBLE_RETURN_TYPES
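// Illustrative use (assumption; the per-operator instantiations are elided in
// this listing): an instantiation such as REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
// would generate ReduceSumOp::inferReturnTypeComponents forwarding to
// ReduceInferReturnTypes with the op's axis property, and the trailing
// COMPATIBLE_RETURN_TYPES(tosa::ReduceSumOp) would generate the matching
// isCompatibleReturnTypes element-type and shape compatibility check.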
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // Allow the special case where the input/output shape has rank 0 and the
    // axis is also 0.
    if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // Only verify the reduced dimension size when the output rank is not 0.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }
  return success();
}

LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  // Verify that input1 and output have the same shape.
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();
  const SmallVector<Type, 2> types = {input1Type, outputType};
  if (failed(verifyCompatibleShapes(types)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(input1Type);
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(outputType);
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
3409 outputShape.resize(4, ShapedType::kDynamic);
3424 if (ShapedType::isStatic(height)) {
3425 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3426 outputShape[1] = padded / stride[0] + 1;
3429 if (ShapedType::isStatic(width)) {
3430 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3431 outputShape[2] = padded / stride[1] + 1;
3438LogicalResult Conv2DOp::inferReturnTypeComponents(
3439 MLIRContext *context, ::std::optional<Location> location,
3440 Conv2DOp::Adaptor adaptor,
3441 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3442 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3444 int64_t inputWidth = ShapedType::kDynamic;
3445 int64_t inputHeight = ShapedType::kDynamic;
3446 int64_t weightWidth = ShapedType::kDynamic;
3447 int64_t weightHeight = ShapedType::kDynamic;
3451 ShapeAdaptor inputShape(adaptor.getInput().getType());
3452 if (inputShape.hasRank()) {
3453 outputShape[0] = inputShape.getDimSize(0);
3454 inputHeight = inputShape.getDimSize(1);
3455 inputWidth = inputShape.getDimSize(2);
3459 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3460 if (weightShape.hasRank()) {
3461 outputShape[3] = weightShape.getDimSize(0);
3462 weightHeight = weightShape.getDimSize(1);
3463 weightWidth = weightShape.getDimSize(2);
3467 ShapeAdaptor biasShape(adaptor.getBias().getType());
3468 if (biasShape.hasRank()) {
3469 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3470 ? biasShape.getDimSize(0)
3474 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3475 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3476 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3478 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3479 int64_t inputSize = inputHeight + padding[0] + padding[1];
3480 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3481 int64_t unstridedResult = inputSize - filterSize + 1;
3482 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3485 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3486 int64_t inputSize = inputWidth + padding[2] + padding[3];
3487 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3488 int64_t unstridedResult = inputSize - filterSize + 1;
3489 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3492 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
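// Worked example of the output-size arithmetic above (illustrative values):
// inputHeight = 224, pad = [1, 1, 1, 1], weightHeight = 3, dilation = [1, 1],
// stride = [2, 2]: inputSize = 226, filterSize = 3, unstridedResult = 224,
// so outputShape[1] = (224 - 1) / 2 + 1 = 112.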
3496LogicalResult Conv2DOp::verify() {
3503LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
3504 MLIRContext *context, ::std::optional<Location> location,
3505 Conv2DBlockScaledOp::Adaptor adaptor,
3506 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3507 SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
3509 int64_t inputWidth = ShapedType::kDynamic;
3510 int64_t inputHeight = ShapedType::kDynamic;
3511 int64_t weightWidth = ShapedType::kDynamic;
3512 int64_t weightHeight = ShapedType::kDynamic;
3515 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
3516 if (inputDataShape.hasRank()) {
3517 outShape[0] = inputDataShape.getDimSize(0);
3518 inputHeight = inputDataShape.getDimSize(1);
3519 inputWidth = inputDataShape.getDimSize(2);
3521 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
3522 if (inputScaleShape.hasRank()) {
3523 outShape[0] = ShapedType::isDynamic(outShape[0])
3524 ? inputScaleShape.getDimSize(0)
3526 inputHeight = ShapedType::isDynamic(inputHeight)
3527 ? inputScaleShape.getDimSize(1)
3529 inputWidth = ShapedType::isDynamic(inputWidth)
3530 ? inputScaleShape.getDimSize(2)
3535 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
3536 if (weightDataShape.hasRank()) {
3537 outShape[3] = weightDataShape.getDimSize(0);
3538 weightHeight = weightDataShape.getDimSize(1);
3539 weightWidth = weightDataShape.getDimSize(2);
3541 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
3542 if (weightScaleShape.hasRank()) {
3543 outShape[3] = ShapedType::isDynamic(outShape[3])
3544 ? weightScaleShape.getDimSize(0)
3546 weightHeight = ShapedType::isDynamic(weightHeight)
3547 ? weightScaleShape.getDimSize(1)
3549 weightWidth = ShapedType::isDynamic(weightWidth)
3550 ? weightScaleShape.getDimSize(2)
3555 const ShapeAdaptor biasShape(adaptor.getBias().getType());
3556 if (biasShape.hasRank()) {
3557 const int64_t biasSize = biasShape.getDimSize(0);
3559 if (biasSize != 1) {
3560 outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
3564 SmallVector<int64_t> padValues;
3565 SmallVector<int64_t> strideValues;
3566 SmallVector<int64_t> dilationValues;
3572 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3576 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3577 const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
3578 const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
3579 const int64_t unstridedResult = inputSize - filterSize + 1;
3580 outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
3583 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3584 const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
3585 const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
3586 const int64_t unstridedResult = inputSize - filterSize + 1;
3587 outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
3590 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3594LogicalResult Conv2DBlockScaledOp::verify() {
3596 getWeightData().
getType(),
"input_data",
3599 getWeightScale().
getType(),
"input_scale",
3602 getOutput().
getType(),
"bias",
"output")))
3606 int64_t N = ShapedType::kDynamic;
3607 int64_t IH = ShapedType::kDynamic;
3608 int64_t IW = ShapedType::kDynamic;
3609 int64_t IC = ShapedType::kDynamic;
3610 int64_t multiplesOfIC = ShapedType::kDynamic;
3611 int64_t OC = ShapedType::kDynamic;
3612 int64_t KH = ShapedType::kDynamic;
3613 int64_t KW = ShapedType::kDynamic;
3615 const ShapeAdaptor inputDataShape(getInputData().
getType());
3616 if (inputDataShape.hasRank()) {
3617 N = inputDataShape.getDimSize(0);
3618 IH = inputDataShape.getDimSize(1);
3619 IW = inputDataShape.getDimSize(2);
3620 IC = inputDataShape.getDimSize(3);
3623 const ShapeAdaptor inputScaleShape(getInputScale().
getType());
3624 if (inputScaleShape.hasRank()) {
3626 "input_scale",
"batch size")) ||
3628 "input_scale",
"input height")) ||
3630 "input_scale",
"input width")))
3632 multiplesOfIC = inputScaleShape.getDimSize(3);
3635 const ShapeAdaptor weightDataShape(getWeightData().
getType());
3636 if (weightDataShape.hasRank()) {
3637 OC = weightDataShape.getDimSize(0);
3638 KH = weightDataShape.getDimSize(1);
3639 KW = weightDataShape.getDimSize(2);
3641 "weight_data",
"input channels")))
3645 const ShapeAdaptor weightScaleShape(getWeightScale().
getType());
3646 if (weightScaleShape.hasRank()) {
3648 "weight_scale",
"output channels")) ||
3650 "weight_scale",
"kernel height")) ||
3652 "weight_scale",
"kernel width")) ||
3654 weightScaleShape.getDimSize(3),
3655 "weight_scale",
"input channel blocks")))
3660 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
    return emitOpError("expect IC to be a multiple of block size, got IC=")
           << IC << ", block_size=" << blockSize;

  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
      multiplesOfIC != IC / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal IC/block_size (")
           << IC << "/" << blockSize << ")"
           << ", got " << multiplesOfIC;
3674 SmallVector<int64_t> padValues;
3676 if (llvm::any_of(padValues, [](int64_t p) {
return p < 0; }))
3677 return emitOpError(
"expect all padding values to be >= 0, got ")
3681 SmallVector<int64_t> strideValues;
3683 if (llvm::any_of(strideValues, [](int64_t s) {
return s < 1; }))
3684 return emitOpError(
"expect all stride values to be >= 1, got ")
3688 SmallVector<int64_t> dilationValues;
3691 if (llvm::any_of(dilationValues, [](int64_t d) {
return d < 1; }))
3692 return emitOpError(
"expect all dilation values to be >= 1, got ")
3697 const ShapeAdaptor outputShape(getOutput().
getType());
3698 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
3699 outputShape.hasRank()) {
3701 padValues[0], padValues[1], strideValues[0],
3702 dilationValues[0],
"height",
"y",
"top",
3705 padValues[2], padValues[3], strideValues[1],
3706 dilationValues[1],
"width",
"x",
"left",
  const ShapeAdaptor biasShape(getBias().getType());
  if (biasShape.hasRank() && outputShape.hasRank()) {
    const int64_t biasChannels = biasShape.getDimSize(0);
    const int64_t outputChannels =
        outputShape.getDimSize(outputShape.getRank() - 1);
    if (biasChannels == ShapedType::kDynamic ||
        outputChannels == ShapedType::kDynamic)
      return success();

    if (biasChannels != outputChannels && biasChannels != 1)
      return emitOpError(
                 "bias channels expected to be equal to output channels (")
             << outputChannels << ") or 1, got " << biasChannels;
  }

  return success();
}
3731LogicalResult Conv3DOp::inferReturnTypeComponents(
3732 MLIRContext *context, ::std::optional<Location> location,
3733 Conv3DOp::Adaptor adaptor,
3734 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3735 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3737 int64_t inputWidth = ShapedType::kDynamic;
3738 int64_t inputHeight = ShapedType::kDynamic;
3739 int64_t inputDepth = ShapedType::kDynamic;
3741 int64_t weightWidth = ShapedType::kDynamic;
3742 int64_t weightHeight = ShapedType::kDynamic;
3743 int64_t weightDepth = ShapedType::kDynamic;
3746 ShapeAdaptor inputShape(adaptor.getInput().getType());
3747 if (inputShape.hasRank()) {
3748 outputShape[0] = inputShape.getDimSize(0);
3749 inputDepth = inputShape.getDimSize(1);
3750 inputHeight = inputShape.getDimSize(2);
3751 inputWidth = inputShape.getDimSize(3);
3755 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3756 if (weightShape.hasRank()) {
3757 outputShape[4] = weightShape.getDimSize(0);
3758 weightDepth = weightShape.getDimSize(1);
3759 weightHeight = weightShape.getDimSize(2);
3760 weightWidth = weightShape.getDimSize(3);
3764 ShapeAdaptor biasShape(adaptor.getBias().getType());
3765 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3766 outputShape[4] = biasShape.getDimSize(0);
3769 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3770 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3771 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3773 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3774 int32_t inputSize = inputDepth + pad[0] + pad[1];
3775 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3776 int32_t unstridedResult = inputSize - filterSize + 1;
3777 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3780 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3781 int32_t inputSize = inputHeight + pad[2] + pad[3];
3782 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3783 int32_t unstridedResult = inputSize - filterSize + 1;
3784 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3787 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3788 int32_t inputSize = inputWidth + pad[4] + pad[5];
3789 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3790 int32_t unstridedResult = inputSize - filterSize + 1;
3791 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3794 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3798LogicalResult Conv3DOp::verify() {
3805LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3806 MLIRContext *context, ::std::optional<Location> location,
3807 AvgPool2dOp::Adaptor adaptor,
3808 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3809 ShapeAdaptor inputShape(adaptor.getInput().getType());
3810 const Properties &prop = adaptor.getProperties();
3812 inferredReturnShapes);
3815LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3816 MLIRContext *context, ::std::optional<Location> location,
3817 MaxPool2dOp::Adaptor adaptor,
3818 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3819 ShapeAdaptor inputShape(adaptor.getInput().getType());
3820 const Properties &prop = adaptor.getProperties();
3822 inferredReturnShapes);
3825LogicalResult MaxPool2dOp::verify() {
3836LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3837 MLIRContext *context, ::std::optional<Location> location,
3838 DepthwiseConv2DOp::Adaptor adaptor,
3839 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3840 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3842 int64_t inputWidth = ShapedType::kDynamic;
3843 int64_t inputHeight = ShapedType::kDynamic;
3844 int64_t inputChannels = ShapedType::kDynamic;
3846 int64_t weightWidth = ShapedType::kDynamic;
3847 int64_t weightHeight = ShapedType::kDynamic;
3848 int64_t depthChannels = ShapedType::kDynamic;
3851 ShapeAdaptor inputShape(adaptor.getInput().getType());
3852 if (inputShape.hasRank()) {
3853 outputShape[0] = inputShape.getDimSize(0);
3854 inputHeight = inputShape.getDimSize(1);
3855 inputWidth = inputShape.getDimSize(2);
3856 inputChannels = inputShape.getDimSize(3);
3860 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3861 if (weightShape.hasRank()) {
3862 weightHeight = weightShape.getDimSize(0);
3863 weightWidth = weightShape.getDimSize(1);
3864 inputChannels = ShapedType::isDynamic(inputChannels)
3865 ? weightShape.getDimSize(2)
3867 depthChannels = weightShape.getDimSize(3);
3872 if (ShapedType::isStatic(inputChannels) &&
3873 ShapedType::isStatic(depthChannels)) {
3874 outputShape[3] = inputChannels * depthChannels;
3878 ShapeAdaptor biasShape(adaptor.getBias().getType());
3879 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
3880 int64_t bc = biasShape.getDimSize(0);
3881 if (bc != ShapedType::kDynamic && bc != 1)
3882 outputShape[3] = bc;
3885 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3886 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3887 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3889 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3890 int64_t inputSize = inputHeight + padding[0] + padding[1];
3891 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3892 int64_t unstridedResult = inputSize - filterSize + 1;
3893 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3896 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3897 int64_t inputSize = inputWidth + padding[2] + padding[3];
3898 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3899 int64_t unstridedResult = inputSize - filterSize + 1;
3900 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3903 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
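// Illustrative note on the channel computation above: the depthwise output
// channel count is inputChannels * depthChannels, e.g. an input of
// tensor<1x16x16x8xf32> with a weight of tensor<3x3x8x2xf32> yields
// outputShape[3] = 8 * 2 = 16.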
3907LogicalResult DepthwiseConv2DOp::verify() {
3914LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3915 MLIRContext *context, ::std::optional<Location> location,
3916 TransposeConv2DOp::Adaptor adaptor,
3917 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3918 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3920 int64_t inputWidth = ShapedType::kDynamic;
3921 int64_t inputHeight = ShapedType::kDynamic;
3922 int64_t weightWidth = ShapedType::kDynamic;
3923 int64_t weightHeight = ShapedType::kDynamic;
3926 ShapeAdaptor inputShape(adaptor.getInput().getType());
3927 if (inputShape.hasRank()) {
3928 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3929 ? inputShape.getDimSize(0)
3931 inputHeight = inputShape.getDimSize(1);
3932 inputWidth = inputShape.getDimSize(2);
3936 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3937 if (weightShape.hasRank()) {
3938 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3939 ? weightShape.getDimSize(0)
3941 weightHeight = weightShape.getDimSize(1);
3942 weightWidth = weightShape.getDimSize(2);
3946 ShapeAdaptor biasShape(adaptor.getBias().getType());
3947 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
3948 int64_t bc = biasShape.getDimSize(0);
3949 if (bc != ShapedType::kDynamic && bc != 1)
3950 outputShape[3] = bc;
3953 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3954 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3956 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3957 int64_t calculateSize =
3958 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3960 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3963 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3964 int64_t calculateSize =
3965 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3967 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3970 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
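// Worked example of the transpose-conv size arithmetic above (illustrative
// values): inputHeight = 4, stride = [2, 2], out_pad = [0, 0, 0, 0],
// weightHeight = 3: outputShape[1] = (4 - 1) * 2 + 0 + 0 + 3 = 9.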
3974LogicalResult TransposeConv2DOp::verify() {
3978 const llvm::ArrayRef<int64_t> strides = getStride();
3979 const int64_t strideY = strides[0];
3980 const int64_t strideX = strides[1];
  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
             llvm::StringRef kernelDimName) -> LogicalResult {
    if (padValue <= -kernelDimSize)
      return emitOpError("expected ")
             << padName << " > -" << kernelDimName << ", but got: " << padName
             << "=" << padValue << " and " << kernelDimName << "="
             << kernelDimSize;
    return success();
  };
3997 const llvm::ArrayRef<int64_t> padding = getOutPad();
3998 const int64_t outPadTop = padding[0];
3999 const int64_t outPadBottom = padding[1];
4000 const int64_t outPadLeft = padding[2];
4001 const int64_t outPadRight = padding[3];
4003 const auto weightType =
4004 llvm::dyn_cast<RankedTensorType>(getWeight().
getType());
4007 const int64_t kernelHeight = weightType.getDimSize(1);
4008 if (ShapedType::isStatic(kernelHeight)) {
4009 if (
failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4010 "out_pad_top",
"KH")))
4013 if (
failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4014 "out_pad_bottom",
"KH")))
4018 const int64_t kernelWidth = weightType.getDimSize(2);
4019 if (ShapedType::isStatic(kernelWidth)) {
4020 if (
failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4021 "out_pad_left",
"KW")))
4024 if (
failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4025 "out_pad_right",
"KW")))
4031 const auto outputType =
4032 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
4036 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
4037 if (inputType && weightType) {
4038 const int64_t inputHeight = inputType.getDimSize(1);
4039 const int64_t kernelHeight = weightType.getDimSize(1);
4040 const int64_t outputHeight = outputType.getDimSize(1);
4042 if (ShapedType::isStatic(inputHeight) &&
4043 ShapedType::isStatic(outputHeight)) {
4045 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4047 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4048 "+ out_pad_top + out_pad_bottom + KH, but got ")
4049 << outputHeight <<
" != (" << inputHeight <<
" - 1) * "
4050 << strideY <<
" + " << outPadTop <<
" + " << outPadBottom
4051 <<
" + " << kernelHeight;
4054 const int64_t inputWidth = inputType.getDimSize(2);
4055 const int64_t kernelWidth = weightType.getDimSize(2);
4056 const int64_t outputWidth = outputType.getDimSize(2);
4058 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4060 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4062 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4063 "+ out_pad_left + out_pad_right + KW, but got ")
4064 << outputWidth <<
" != (" << inputWidth <<
" - 1) * " << strideX
4065 <<
" + " << outPadLeft <<
" + " << outPadRight <<
" + "
  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  // Skip further checks if the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // The multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // The multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    return emitOpError("expect shape of { ")
           << numChannels << " } for multiplier input, got { "
           << multiplierShape[0] << " }";
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    return emitOpError("expect shape of { ")
           << numChannels << " } for shift input, got { " << shiftShape[0]
           << " }";
  }

  return success();
}
4209LogicalResult RescaleOp::inferReturnTypeComponents(
4210 MLIRContext *context, ::std::optional<Location> location,
4211 RescaleOp::Adaptor adaptor,
4212 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4213 ShapeAdaptor inputShape(adaptor.getInput().getType());
4214 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4218LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4219 MLIRContext *context, ::std::optional<Location> location,
4220 CastFromBlockScaledOp::Adaptor adaptor,
4221 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4222 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4223 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4227LogicalResult CastFromBlockScaledOp::verify() {
4228 const Type inputDataType = getInputData().getType();
4229 const Type outputDataType = getResult().getType();
4231 return emitOpError() <<
"require compatible shapes for input_data ("
4232 << inputDataType <<
") and " <<
"output_data ("
4233 << outputDataType <<
")";
4235 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4237 if (inputDataShape.
hasRank()) {
4238 const unsigned int blockSize =
4240 const int64_t inputDataLastDim =
4242 if (inputDataLastDim % blockSize != 0)
4243 return emitOpError() <<
"expect last dimension of input_data ("
4245 <<
") to be divisible by block_size (" << blockSize
4248 const Type inputScaleType = getInputScale().getType();
4249 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4251 if (inputScaleShape.
hasRank()) {
4252 SmallVector<int64_t> inputDataDims, inputScaleDims;
4253 inputDataShape.
getDims(inputDataDims);
4254 inputScaleShape.
getDims(inputScaleDims);
4256 if (inputDataDims.size() != inputScaleDims.size() ||
4258 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4259 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4261 <<
"require compatible shapes for input_data (" << inputDataType
4262 <<
") and " <<
"input_scale (" << inputScaleType
4263 <<
") except for the last dimension";
4265 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4266 inputScaleDims.back()};
4267 if (ShapedType::isStatic(inputDataLastDim) &&
4270 <<
"expect last dimension of input_scale ("
4271 << inputScaleDims.back()
4272 <<
") to be equal to last dimension of input_data / block_size ("
4273 << inputDataDims.back() / blockSize <<
")";
4280LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4281 MLIRContext *context, ::std::optional<Location> location,
4282 CastToBlockScaledOp::Adaptor adaptor,
4283 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4284 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4285 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4286 if (!inputShape.hasRank())
4290 SmallVector<int64_t> outputScaleShape;
4291 inputShape.getDims(outputScaleShape);
4292 const int64_t lastDimLoc = inputShape.getRank() - 1;
4293 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4294 if (ShapedType::isStatic(lastDimSize)) {
4295 const unsigned int blockSize =
4296 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4297 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4299 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
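// Illustrative example for the scale-shape inference above: the output_scale
// shape keeps the leading dimensions and divides the last one by the block
// size, e.g. input_data of shape [4, 64] with block_size 32 infers an
// output_scale shape of [4, 2].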
4303LogicalResult CastToBlockScaledOp::verify() {
4304 const Type inputDataType = getInputData().getType();
4305 const Type outputDataType = getResult(0).getType();
4307 return emitOpError() <<
"require compatible shapes for input_data ("
4308 << inputDataType <<
") and " <<
"output_data ("
4309 << outputDataType <<
")";
4311 const unsigned int blockSize =
4313 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4314 if (inputDataShape.
hasRank()) {
4315 const int64_t inputDataLastDim =
4317 if (ShapedType::isStatic(inputDataLastDim) &&
4318 inputDataLastDim % blockSize != 0)
4319 return emitOpError() <<
"expect last dimension of input_data ("
4321 <<
") to be divisible by block_size (" << blockSize
4325 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4326 const Type outputScaleType = getResult(1).getType();
4327 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4329 SmallVector<int64_t> outputDataDims, outputScaleDims;
4330 outputDataShape.
getDims(outputDataDims);
4331 outputScaleShape.
getDims(outputScaleDims);
4333 if (outputDataDims.size() != outputScaleDims.size() ||
4335 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4336 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4337 return emitOpError() <<
"require compatible shapes for output_data ("
4338 << outputDataType <<
") and " <<
"output_scale ("
4340 <<
") except for the last dimension";
4342 const int64_t outputDataLastDim = outputDataDims.back();
4343 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4344 outputScaleDims.back()};
4345 if (ShapedType::isStatic(outputDataLastDim) &&
4348 <<
"expect last dimension of output_scale ("
4349 << outputScaleDims.back()
4350 <<
") to be equal to last dimension of output_data / block_size ("
4351 << outputDataDims.back() / blockSize <<
")";
4357LogicalResult IfOp::inferReturnTypeComponents(
4358 MLIRContext *context, ::std::optional<Location> location,
4359 IfOp::Adaptor adaptor,
4360 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4361 llvm::SmallVector<tosa::YieldOp> yieldOps;
4362 for (Region *region : adaptor.getRegions()) {
4363 for (
auto &block : *region)
4364 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4365 yieldOps.push_back(returnOp);
4368 if (yieldOps.empty())
4372 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4373 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4374 for (
auto operand : yieldOps.front().getOperands()) {
4375 resultKnowledge.push_back(
4379 for (
auto yieldOp : yieldOps) {
4380 if (resultKnowledge.size() != yieldOp.getNumOperands())
4383 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4384 int32_t index = it.index();
4386 resultKnowledge[index],
4390 resultKnowledge[index] = meet;
4394 for (
const ValueKnowledge &
result : resultKnowledge) {
4395 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
4401LogicalResult WhileOp::inferReturnTypeComponents(
4402 MLIRContext *context, ::std::optional<Location> location,
4403 WhileOp::Adaptor adaptor,
4404 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4405 llvm::SmallVector<tosa::YieldOp> yieldOps;
4406 for (
auto &block : adaptor.getBodyGraph())
4407 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4408 yieldOps.push_back(returnOp);
4412 if (yieldOps.empty())
4416 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4417 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4418 for (
auto operand : yieldOps.front().getOperands()) {
4419 resultKnowledge.push_back(
4423 for (
auto yieldOp : yieldOps) {
4424 if (resultKnowledge.size() != yieldOp.getNumOperands())
4427 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4428 int32_t index = it.index();
4430 resultKnowledge[index],
4432 resultKnowledge[index] = meet;
4437 for (
const ValueKnowledge &
result : resultKnowledge) {
4438 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
4444std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4445 if (
auto vt = llvm::dyn_cast<VectorType>(
getType()))
4446 return llvm::to_vector<4>(vt.getShape());
4447 return std::nullopt;
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}
4467ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
4469 result.regions.reserve(2);
4470 Region *thenRegion =
result.addRegion();
4471 Region *elseRegion =
result.addRegion();
4473 OpAsmParser::UnresolvedOperand cond;
4478 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4479 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4482 OptionalParseResult listResult =
4490 "expected type for condition operand");
4496 "expected type for condition operand");
4504 FunctionType functionType;
4508 <<
"expected list of types for block arguments "
4509 <<
"followed by arrow type and list of return types";
4511 result.addTypes(functionType.getResults());
4513 if (functionType.getNumInputs() != operands.size()) {
4515 <<
"expected as many input types as operands " <<
"(expected "
4516 << operands.size() <<
" got " << functionType.getNumInputs()
4547void IfOp::print(OpAsmPrinter &p) {
4548 p <<
" " << getCondition();
4551 getInputList(),
" ");
4553 p << getCondition().getType();
4555 if (!getInputList().empty()) {
4557 llvm::interleaveComma(getInputList().getTypes(), p);
4566 auto &elseRegion = getElseGraph();
4567 if (!elseRegion.
empty()) {
4575LogicalResult IfOp::verify() {
4577 "'then_graph' arguments", getInputList(),
4583 "'else_graph' arguments", getInputList(),
4589 if (getThenGraph().front().mightHaveTerminator()) {
4591 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4593 *
this, thenYield.getInputs(),
"'then_graph' results",
4594 getOutputList(),
"'output_list'")
4600 if (getElseGraph().front().mightHaveTerminator()) {
4602 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4604 *
this, elseYield.getInputs(),
"'else_graph' results",
4605 getOutputList(),
"'output_list'")
4610 auto condType = getCondition().getType();
4612 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
4618LogicalResult WhileOp::verify() {
4620 getOutputList(),
"'output_list'")
4625 "'cond_graph' arguments", getInputList(),
4631 "'body_graph' arguments", getInputList(),
4636 if (getBodyGraph().front().mightHaveTerminator()) {
4638 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4640 "'body_graph' results",
4641 getInputList(),
"'input_list'")
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (failed(errorIfShapeNotSizeOne(*this, condOutType)))
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
4671LogicalResult ReverseOp::verify() {
4676 TensorType inputType = getInput1().getType();
4677 TensorType outputType = getOutput().getType();
4678 int32_t reverseAxis = getAxis();
4680 if (reverseAxis < 0)
4681 return emitOpError(
"expected non-negative reverse axis");
4683 int64_t inputRank = inputType.getRank();
4686 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4688 << inputRank <<
") to be larger than reverse axis (" << reverseAxis
4692 int64_t outputRank = outputType.getRank();
4693 if (inputType.
hasRank() && outputRank != inputType.getRank())
4695 "expect output tensor rank to be equal to input tensor rank");
4696 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4698 << outputRank <<
") to be larger than reverse axis ("
4699 << reverseAxis <<
")";
4704LogicalResult tosa::SelectOp::verify() {
4715 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().
getType());
4716 if (!predicateType) {
4717 return emitOpError(
"expect shaped tensor for input1, got ")
4718 << getInput1().getType();
4720 auto predicateElementType = predicateType.getElementType();
4721 if (!predicateElementType.isInteger(1)) {
4722 return emitOpError(
"expect element type of bool for input1, got ")
4723 << predicateElementType;
4729LogicalResult tosa::VariableReadOp::verify() {
4737LogicalResult tosa::VariableWriteOp::verify() {
4746ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
4747 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4748 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4749 Region *cond =
result.addRegion();
4750 Region *body =
result.addRegion();
4752 OptionalParseResult listResult =
4757 FunctionType functionType;
4762 result.addTypes(functionType.getResults());
4764 if (functionType.getNumInputs() != operands.size()) {
4766 <<
"expected as many input types as operands " <<
"(expected "
4767 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
4777 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
4778 regionArgs[i].type = functionType.getInput(i);
4780 return failure(parser.
parseRegion(*cond, regionArgs) ||
4785void WhileOp::print(OpAsmPrinter &parser) {
4787 getInputList(),
" ");
4790 getResults().getTypes());
4804 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4805 if (llvm::isa<FloatType>(srcElemType)) {
4807 zpType, builder.
getFloatAttr(srcElemType,
static_cast<double>(zp)));
4808 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4810 if (llvm::isa<IntegerType>(srcElemType)) {
4813 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
4824 return mlir::isa<tosa::shapeType>(t);
4831 return emitError() <<
"invalid rank (must be >= 0): " << rank;
4837 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4838 Operation *definingOp = v.getDefiningOp();
4840 return op->
emitOpError(
"shape operand is not compile time resolvable");
4853 auto getRank = [](
const Type type) {
4854 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4860 for (
auto type : operandTypes) {
4861 if (getRank(type) != rank) {
4862 return op->
emitOpError(
"operands don't have matching ranks");
4865 for (
auto type : resultTypes) {
4866 if (getRank(type) != rank) {
4867 return op->
emitOpError(
"result shape has different rank than operands");
LogicalResult tosa::ConstShapeOp::verify() {
  // Check that the values attribute is one-dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // Check that the number of elements matches the rank of the result shape.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (count != rank && (count != 1 || rank != 0)) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
4893LogicalResult tosa::DimOp::verify() {
4894 const tosa::shapeType outShapeType =
4895 cast<tosa::shapeType>(getResult().
getType());
4896 if (outShapeType.getRank() != 1)
4897 return emitOpError(
"expect output shape type to contain one element, got ")
4902 const int64_t inputRank = inputType.getRank();
4903 const int64_t axis = getAxisAttr().getInt();
4904 if (axis < 0 || axis >= inputRank)
4905 return emitOpError(
"expect axis to be in the range [0, ")
4906 << inputRank <<
"), got " << axis;
4911LogicalResult tosa::ConcatShapeOp::verify() {
4912 const tosa::shapeType outShapeType =
4913 cast<tosa::shapeType>(getResult().
getType());
4914 const int64_t outputRank = outShapeType.getRank();
4916 const int64_t inputsRank =
4917 llvm::accumulate(inputList, 0, [](int64_t acc,
const Value &input) {
4918 const tosa::shapeType inShapeType =
4919 cast<tosa::shapeType>(input.
getType());
4920 return acc + inShapeType.getRank();
4922 if (outputRank != inputsRank)
4923 return emitOpError(
"requires output shape rank to be equal to the sum of "
4924 "the input shape ranks (")
4925 << inputsRank <<
"), got " << outputRank;
4930LogicalResult tosa::SliceShapeOp::verify() {
4931 std::optional<int32_t> start;
4932 DenseIntElementsAttr startAttr;
4934 start = startAttr.getValues<int32_t>()[0];
4935 if (start && start.value() < 0)
4936 return emitOpError(
"expected non-negative start index, got ")
4939 std::optional<int32_t> size;
4940 DenseIntElementsAttr sizeAttr;
4942 size = sizeAttr.getValues<int32_t>()[0];
4943 if (size && size.value() <= 0)
4944 return emitOpError(
"expected positive size, got ") << size.value();
4949 const tosa::shapeType outShapeType =
4950 cast<tosa::shapeType>(getResult().
getType());
4951 const int64_t outputRank = outShapeType.getRank();
4952 if (outputRank != size)
4954 "expected output type size to be equal to size attribute, got ")
4955 << outputRank <<
" vs " << size.value();
4960 const tosa::shapeType inShapeType =
4961 cast<tosa::shapeType>(getInput().
getType());
4962 const int64_t inputRank = inShapeType.getRank();
4963 const int64_t sliceSize = start.value() + size.value();
4964 if (sliceSize > inputRank)
4965 return emitOpError(
"expected start + size to be less than or equal to "
4966 "input shape rank (")
4967 << inputRank <<
"), got " << sliceSize;
4976#define GET_ATTRDEF_CLASSES
4977#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4982#define GET_TYPEDEF_CLASSES
4983#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4989#define GET_OP_CLASSES
4990#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)