29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  // Attributes.
  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  // Types.
  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {}

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};
131 return {&getBodyGraph()};
140 return dim == -1 ? ShapedType::kDynamic : dim;
146 Type elementType = variableOp.getType();
149 return RankedTensorType::get(
shape, elementType);
void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
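// Constant materialization: fold results become tosa.const_shape ops for
// !tosa.shape types backed by DenseIntElementsAttr, and tosa.const ops for
// ElementsAttr values, as shown below.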
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return tosa::ConstShapeOp::create(builder, loc, type,
186 llvm::cast<DenseIntElementsAttr>(value));
188 if (llvm::isa<ElementsAttr>(value))
189 return tosa::ConstOp::create(builder, loc, type,
190 llvm::cast<ElementsAttr>(value));
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   DenseElementsAttr &varShapeAttr,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
    typeAttr = TypeAttr::get(elementType);
    // The var_shape attribute is derived from the parsed shape here.
    return success();
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
    Attribute &initialValueAttr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(initialValueAttr)))
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                    typeAttr);
    }
    return parser.emitError(parser.getCurrentLocation())
           << "expected Typed attr";
  }

  initialValueAttr = nullptr;
  Type parsedType;
  if (failed(parser.parseColonType(parsedType)))
    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}
                                                   TypeAttr typeAttr,
                                                   Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
        RankedTensorType::get(shape, elementType);
    auto tensorTypeAttr = TypeAttr::get(tensorType);

  if (initialValueAttr) {
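// Illustrative textual forms handled by the variable parse/print helpers
// above (a sketch; the exact assembly is defined by the op declaration):
//   tosa.variable @stored = dense<0.0> : tensor<1x8xf32>   // with initial value
//   tosa.variable @stored : tensor<1x8xf32>                // type only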
278template <
typename EnumType>
279ParseResult parseAttrEntryWithEnumHandling(
OpAsmParser &parser,
281 llvm::StringRef name;
288 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
289 if (name ==
"rounding_mode" &&
291 auto sym = symbolizeRoundingMode(kw);
294 <<
"invalid rounding_mode value: " << kw;
295 auto attr = RoundingModeAttr::get(parser.
getContext(), sym.value());
301 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
303 auto sym = symbolizeResizeMode(kw);
306 <<
"invalid resize mode value: " << kw;
307 auto attr = ResizeModeAttr::get(parser.
getContext(), sym.value());
314 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
316 auto sym = symbolizeNanPropagationMode(kw);
319 <<
"invalid nan_mode value: " << kw;
320 auto attr = NanPropagationModeAttr::get(parser.
getContext(), sym.value());
327 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
329 auto sym = symbolizeBlockSize(kw);
332 <<
"invalid block_size value: " << kw;
333 auto attr = BlockSizeAttr::get(parser.
getContext(), sym.value());
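// The entry parser above lets these enum-valued attributes be written as bare
// keywords in the assembly (for example something like
// `rounding_mode = SINGLE_ROUND`) and converts the keyword into the
// corresponding typed attribute via the symbolize* helpers.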
345template <
typename EnumType>
350 [&]() { return parser.parseOperand(operands.emplace_back()); }))
358 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
375 result.addTypes(fnTy.getResults());
376 result.addAttributes(attrs);
382 parser << namedAttr.
getName().strref() <<
" = ";
384 if (
auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
385 parser << roundingModeAttr.getValue();
386 }
else if (
auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
387 parser << resizeModeAttr.getValue();
388 }
else if (
auto nanPropagationModeAttr =
389 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
390 parser << nanPropagationModeAttr.getValue();
391 }
else if (
auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
392 parser << blockSizeAttr.getValue();
  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;

  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
    if (nanAttr.getValue() == kDefaultNanValue) {
      toPrint.erase(attr.getName());

  if (!toPrint.empty()) {
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);

  llvm::interleaveComma(op->getAttrs(), parser,
                        [&](const NamedAttribute namedAttr) {
                          printNamedAttr(parser, namedAttr);
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
                                        OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))

                                 Value valZp, StringRef name) {

      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
                                Value src, int32_t val) {
  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
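// For example, with an f32 source element type this materializes a rank-1,
// single-element tensor<1xf32> tosa.const holding the pad value; integer
// sources get the analogous single-element integer constant.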
621 if (dyn_cast<tosa::mxint8Type>(type))
630 const StringRef operandName,
631 const StringRef dimName) {
632 if (ShapedType::isDynamic(currDim)) {
635 }
else if (ShapedType::isStatic(newDim) && currDim != newDim) {
637 << dimName <<
" of " << operandName <<
" to match size " << currDim
638 <<
", got " << newDim;
                                 const int64_t stride, const int64_t dilation,
                                 const llvm::StringRef dimName,
                                 const llvm::StringRef dimAxis,
                                 const llvm::StringRef padBeforeName,
                                 const llvm::StringRef padAfterName) {
  if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
    return success();

  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      stride);
  if (!calculatedOutSizeMinusOne.has_value())
    return op.emitOpError("expected input_")
           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
           << dimAxis << " to be wholly divisible by stride_" << dimAxis
           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
           << ") / " << stride;

  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
    return op.emitOpError("calculated output ")
           << dimName << " did not match expected: "
           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    return op.emitOpError(
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      return op.emitOpError(
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
    return op.emitOpError(
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;
  }

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();
LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }
787 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
789 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
792 auto accType = op.getAccType();
793 if (inputEType.isInteger(8) && !accType.isInteger(32))
794 return op.emitOpError(
"accumulator type for i8 tensor is not i32, got ")
797 if (inputEType.isInteger(16) && !accType.isInteger(48))
798 return op.emitOpError(
"accumulator type for i16 tensor is not i48, got ")
801 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
802 !(accType.isF16() || accType.isF32()))
803 return op.emitOpError(
"accumulator type for f8 tensor is not f16/f32, got ")
806 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
807 return op.emitOpError(
808 "accumulator type for f16 tensor is not f16/f32, got ")
811 if (inputEType.isBF16() && !accType.isF32())
812 return op.emitOpError(
"accumulator type for bf16 tensor is not f32, got ")
815 if (inputEType.isF32() && !accType.isF32())
816 return op.emitOpError(
"accumulator type for f32 tensor is not f32, got ")
820 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
822 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
          op, inputType.getDimSize(1), weightType.getDimSize(1),
          outputType.getDimSize(1), padding[0], padding[1], strides[0],
          dilations[0], "height", "y", "top", "bottom")))

          op, inputType.getDimSize(2), weightType.getDimSize(2),
          outputType.getDimSize(2), padding[2], padding[3], strides[1],
          dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
          op, inputType.getDimSize(1), weightType.getDimSize(0),
          outputType.getDimSize(1), padding[0], padding[1], strides[0],
          dilations[0], "height", "y", "top", "bottom")))

          op, inputType.getDimSize(2), weightType.getDimSize(1),
          outputType.getDimSize(2), padding[2], padding[3], strides[1],
          dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
          op, inputType.getDimSize(1), weightType.getDimSize(1),
          outputType.getDimSize(1), padding[0], padding[1], strides[0],
          dilations[0], "depth", "d", "front", "back")))

          op, inputType.getDimSize(2), weightType.getDimSize(2),
          outputType.getDimSize(2), padding[2], padding[3], strides[1],
          dilations[1], "height", "y", "top", "bottom")))

          op, inputType.getDimSize(3), weightType.getDimSize(3),
          outputType.getDimSize(3), padding[4], padding[5], strides[2],
          dilations[2], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                                StringRef name1, Type type2,
                                                StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return success();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :
990 op->template getParentWithTrait<OpTrait::SymbolTable>();
997 const auto varOp = symTable.
lookup<tosa::VariableOp>(op.getName());
1001 return op->emitOpError(
"'")
1002 << op.getName() <<
"' has not been declared by 'tosa.variable'";
1014template <
typename T>
1016 StringRef aName =
"input",
1017 StringRef bName =
"output") {
1018 auto aTType = llvm::dyn_cast<TensorType>(aType);
1019 auto bTType = llvm::dyn_cast<TensorType>(bType);
1021 op.emitOpError(
"expect shaped tensor for") << aName <<
", got " << aType;
1025 op.emitOpError(
"expect shaped tensor for") << bName <<
", got" << bType;
1028 auto aElementType = aTType.getElementType();
1029 auto bElementType = bTType.getElementType();
1031 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1033 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1034 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1035 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1036 aElementType != bElementType) {
1041 op.emitOpError(
"expect ")
1042 << aName <<
" and " << bName <<
" to have same element type, got "
1043 << aElementType <<
" and " << bElementType;
LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
           << expectedOutputShape << "', got '" << outputShape << "'";
1080template <
typename T>
1083 if (llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1084 return op.emitOpError(
"expect all kernel values to be >= 1, got ")
1088 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1089 return op.emitOpError(
"expect all stride values to be >= 1, got ")
1093 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1094 return op.emitOpError(
"expect all padding values to be >= 0, got ")
1098 const int64_t kernelX = kernel[1];
1099 const int64_t padLeft = padding[2];
1100 const int64_t padRight = padding[3];
1101 if (padRight >= kernelX || padLeft >= kernelX)
1102 return op.emitOpError(
"expected left/right padding to be less than the "
1103 "width of the kernel, got pad_left=")
1104 << padLeft <<
", pad_right=" << padRight <<
", kernel_x=" << kernelX;
1106 const int64_t kernelY = kernel[0];
1107 const int64_t padTop = padding[0];
1108 const int64_t padBottom = padding[1];
1109 if (padTop >= kernelY || padBottom >= kernelY)
1110 return op.emitOpError(
"expected top/bottom padding to be less than the "
1111 "height of the kernel, got pad_top=")
1112 << padTop <<
", pad_bottom=" << padBottom
1113 <<
", kernel_y=" << kernelY;
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: " << "calculated="
             << calculatedOutSize << ", expected=" << outputSize;

    return success();
  };

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();
LogicalResult tosa::AvgPool2dOp::verify() {

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();
LogicalResult tosa::ClampOp::verify() {
  auto inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  auto outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
1287 result.addOperands({input, weight, bias, zps.first, zps.second});
1288 result.addAttribute(
"pad", pad);
1289 result.addAttribute(
"stride", stride);
1290 result.addAttribute(
"dilation", dilation);
1291 result.addAttribute(
"acc_type", accType);
1292 Type finalOutputType = outputType;
1298 result.addTypes(finalOutputType);
1309 result.addOperands({input, weight, bias, zps.first, zps.second});
1310 result.addAttribute(
"out_pad", outpad);
1311 result.addAttribute(
"stride", stride);
1312 result.addAttribute(
"acc_type", accType);
1313 Type finalOutputType = outputType;
1319 result.addTypes(finalOutputType);
1330 result.addOperands({a,
b, zps.first, zps.second});
1332 Type finalOutputType{outputType};
1335 auto inputBits = eType.getIntOrFloatBitWidth();
1337 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1338 assert(outputShapedType &&
"Output must be a shaped type");
1340 IntegerType accElementType;
1341 if (inputBits == 16)
1346 finalOutputType = outputShapedType.clone(accElementType);
1348 result.addTypes(finalOutputType);
1357 DenseArrayAttr kernel, DenseArrayAttr stride,
1358 DenseArrayAttr pad, TypeAttr accType) {
1363 if (
auto quantAttr =
1365 inputZp = quantAttr.getInputZp();
1366 outputZp = quantAttr.getOutputZp();
1368 const std::optional<Value> inputZpOp =
1373 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1375 const std::optional<Value> outputZpOp =
1378 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1379 "quantized AVG_POOL2D op");
1382 if (inputZpOp && outputZpOp) {
1383 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1388 result.addOperands({input});
1390 result.addAttribute(
"kernel", kernel);
1391 result.addAttribute(
"stride", stride);
1392 result.addAttribute(
"pad", pad);
1393 result.addAttribute(
"acc_type", accType);
1394 result.types.push_back(outputType);
1408 input1Zp = quantAttr.getInputZp();
1409 outputZp = quantAttr.getOutputZp();
1411 const std::optional<Value> input1ZpOp =
1415 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1418 const std::optional<Value> outputZpOp =
1422 loc,
"Failed to create output zero point for quantized NEGATE op");
1425 if (input1ZpOp && outputZpOp) {
1426 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1431 result.addOperands({input});
1434 result.types.push_back(outputType);
1447 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1450 result.addOperands({input, paddings, padConstOp});
1451 result.types.push_back(outputType);
1455 StringRef name,
Type variableType,
1460 auto shapedType = dyn_cast<ShapedType>(variableType);
1462 (
void)
emitError(loc,
"variable type must be a shaped type");
1465 if (!shapedType.hasRank()) {
1466 (
void)
emitError(loc,
"variable type must be a ranked type");
1470 auto elementType = shapedType.getElementType();
1471 auto elementTypeAttr = TypeAttr::get(elementType);
1475 result.addAttribute(
"sym_name", nameAttr);
1476 result.addAttribute(
"var_shape", varShapeAttr);
1477 result.addAttribute(
"type", elementTypeAttr);
1478 result.addAttribute(
"initial_value", initialValue);
1491 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1495 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1501 for (
int i = 0, e = operands.size(); i != e; ++i) {
1503 if (!
shape.hasRank()) {
1508 outRank = std::max<int64_t>(outRank,
shape.getRank());
1511 outShape.resize(outRank, 1);
1513 for (
int i = 0, e = operands.size(); i != e; ++i) {
1515 auto rankDiff = outShape.size() -
shape.getRank();
1517 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1518 auto dim1 = outShape[i + rankDiff];
1519 auto dim2 =
shape.getDimSize(i);
1521 const FailureOr<int64_t> maybeResolvedDim =
1523 if (failed(maybeResolvedDim))
1525 const int64_t resolvedDim = *maybeResolvedDim;
1526 outShape[i + rankDiff] = resolvedDim;
1533LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1534 MLIRContext *context, ::std::optional<Location> location,
1535 ArgMaxOp::Adaptor adaptor,
1538 IntegerAttr axis = adaptor.getProperties().axis;
1539 int32_t axisVal = axis.getValue().getSExtValue();
1541 if (!inputShape.hasRank()) {
1547 outShape.reserve(inputShape.getRank() - 1);
1548 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
1551 outShape.push_back(inputShape.getDimSize(i));
1558LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1559 MLIRContext *context, ::std::optional<Location> location,
1560 RFFT2dOp::Adaptor adaptor,
1562 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1564 if (!inputShape.hasRank())
1568 outputShape.resize(3, ShapedType::kDynamic);
1569 outputShape[0] = inputShape.getDimSize(0);
1570 outputShape[1] = inputShape.getDimSize(1);
1571 int64_t inWidth = inputShape.getDimSize(2);
1575 if (inWidth != ShapedType::kDynamic)
1576 outputShape[2] = inWidth / 2 + 1;
1585 const llvm::StringRef dimName) {
1586 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1589 << dimName <<
" to be a power of two, got " << dimSize;
1594LogicalResult tosa::RFFT2dOp::verify() {
1595 const auto outputTypes = getResultTypes();
1597 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
1599 const auto inputType =
1600 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1604 const int64_t height = inputType.getDimSize(1);
1605 if (ShapedType::isStatic(height) &&
1609 const int64_t width = inputType.getDimSize(2);
1610 if (ShapedType::isStatic(width) &&
1614 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1620 outputType.getShape().drop_back())))
1621 return emitOpError(
"expected batch and height dimensions of input/output "
1622 "to match, got input=")
1623 << inputType <<
" output=" << outputType;
1626 const int64_t outputWidth = outputType.getDimSize(2);
1627 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1628 (outputWidth != (width / 2) + 1))
1630 "expected output width to be equal to input_width / 2 + 1, got ")
1636LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1637 MLIRContext *context, ::std::optional<Location> location,
1638 FFT2dOp::Adaptor adaptor,
1640 inferredReturnShapes.push_back(
1642 inferredReturnShapes.push_back(
1647LogicalResult tosa::FFT2dOp::verify() {
1648 const auto inputRealType =
1649 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1650 const auto inputImagType =
1651 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1652 if (!inputRealType || !inputImagType)
1655 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1656 return ShapedType::isDynamic(a) ? a :
b;
1659 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1660 inputImagType.getDimSize(1));
1661 if (ShapedType::isStatic(height) &&
1665 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1666 inputImagType.getDimSize(2));
1667 if (ShapedType::isStatic(width) &&
1674LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1675 MLIRContext *context, ::std::optional<Location> location,
1676 ConcatOp::Adaptor adaptor,
1679 const Properties &prop = adaptor.getProperties();
1680 int32_t axis = prop.axis.getValue().getSExtValue();
1682 bool hasRankedInput =
false;
1683 for (
auto operand : adaptor.getOperands()) {
1685 if (!operandShape.hasRank())
1689 if (!hasRankedInput)
1690 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1693 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
1694 if (i == axis || operandShape.isDynamicDim(i))
1696 if (outputShape[i] == ShapedType::kDynamic)
1697 outputShape[i] = operandShape.getDimSize(i);
1698 if (outputShape[i] != operandShape.getDimSize(i))
1700 "Cannot concat tensors with different sizes"
1701 " on the non-axis dimension ",
1705 hasRankedInput =
true;
1708 if (adaptor.getInput1().empty())
1712 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1713 if (!hasRankedInput) {
1720 for (
auto operand : adaptor.getOperands()) {
1725 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1726 concatDimSize = ShapedType::kDynamic;
1730 concatDimSize += operandShape.getDimSize(axis);
1733 outputShape[axis] = concatDimSize;
1739LogicalResult tosa::ConcatOp::verify() {
1741 auto outType = getOutput().getType();
1745 if (inputList.empty())
1748 if (!llvm::all_of(inputList, [&](
auto input) {
1750 *
this, input.getType(), outType));
1755 const int32_t axis = getAxis();
1757 for (
const auto &input : inputList) {
1758 const Type inputType = input.getType();
1760 if (currShape.hasRank()) {
1761 firstRankedInputShape = currShape;
1763 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1764 return emitOpError(
"expect axis to be within range 0 < axis < "
1765 "rank(input1[firstRankedTensorIdx]), got ")
1771 const auto allOperandsHasRank = [](
const Value input) {
1774 if (llvm::all_of(inputList, allOperandsHasRank)) {
1777 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1779 const int64_t inputRank = inputShape.getRank();
1780 const size_t operandNum =
index + 1;
1783 if (inputRank != firstInputRank)
1785 "expect all operands to have the same rank, but got ")
1786 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1790 for (
int i = 0; i < inputRank; i++) {
1791 const int64_t inputDim = inputShape.getDimSize(i);
1793 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1794 inputShape.isDynamicDim(i))
1796 if (inputDim != firstInputDim)
1797 return emitOpError(
"expect all operand shapes to have the same sizes "
1798 "on non-axis dimensions, but got ")
1799 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1800 <<
" on operands 0 and " << operandNum;
1805 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1806 return emitOpError(
"expect output rank to match inputs rank, got ")
1807 << outputShape.getRank() <<
" vs " << firstInputRank;
1811 for (
const auto &input : inputList) {
1813 if (inputShape.isDynamicDim(axis)) {
1818 axisSum += inputShape.getDimSize(axis);
1821 if (axisSum >= 0 && outputShape.hasRank() &&
1822 !outputShape.isDynamicDim(axis) &&
1823 axisSum != outputShape.getDimSize(axis))
1824 return emitOpError(
"requires sum of axis dimensions of input1 "
1825 "equal to output axis dimension, got ")
1826 << axisSum <<
" and " << outputShape.getDimSize(axis);
1832LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1833 MLIRContext *context, ::std::optional<Location> location,
1837 auto elementType = IntegerType::get(context, 1);
1850 if (l.size() != r.size() || l.size() != 1)
1855LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1856 MLIRContext *context, ::std::optional<Location> location,
1857 MatMulOp::Adaptor adaptor,
1864 outShape.resize(3, ShapedType::kDynamic);
1866 if (lhsShape.hasRank()) {
1867 outShape[0] = lhsShape.getDimSize(0);
1868 outShape[1] = lhsShape.getDimSize(1);
1871 if (rhsShape.hasRank()) {
1872 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1874 outShape[2] = rhsShape.getDimSize(2);
1881LogicalResult MatMulOp::verify() {
1882 auto aType = llvm::dyn_cast<ShapedType>(getA().
getType());
1883 auto bType = llvm::dyn_cast<ShapedType>(getB().
getType());
1887 return emitOpError(
"expect a shaped tensor for input a, got ")
1888 << getA().getType();
1891 return emitOpError(
"expect a shaped tensor for input b, got ")
1892 << getB().getType();
1894 auto aElementType = aType.getElementType();
1895 auto bElementType = bType.getElementType();
1897 auto aQuantizedEType =
1898 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1899 auto bQuantizedEType =
1900 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1902 if (aQuantizedEType || bQuantizedEType) {
1903 if (!aQuantizedEType || !bQuantizedEType) {
1904 return emitOpError(
"expect operands to be both quantized or both not "
1906 << aElementType <<
" and " << bElementType;
1909 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1910 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1911 if (aQuantWidth != bQuantWidth) {
1912 return emitOpError(
"expect quantized operands to have same widths, got ")
1913 << aQuantWidth <<
" and " << bQuantWidth;
1920 if (aEType != aZpEType) {
1921 return emitOpError(
"expect input a and a_zp have the same "
1922 "element type, got ")
1923 << aEType <<
" and " << aZpEType;
1928 if (bEType != bZpEType) {
1929 return emitOpError(
"expect input b and b_zp have the same "
1930 "element type, got ")
1931 << bEType <<
" and " << bZpEType;
1934 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1935 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1938 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1939 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1945LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1946 MLIRContext *context, ::std::optional<Location> location,
1947 MatmulTBlockScaledOp::Adaptor adaptor,
1951 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1952 if (aDataShape.hasRank()) {
1953 outShape[0] = aDataShape.getDimSize(0);
1954 outShape[1] = aDataShape.getDimSize(1);
1957 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1958 if (aScaleShape.hasRank()) {
1959 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1961 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1966 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1967 if (bDataShape.hasRank()) {
1968 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1969 if (bDataBatchSize != 1)
1971 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1972 outShape[2] = bDataShape.getDimSize(1);
1975 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1976 if (bScaleShape.hasRank()) {
1977 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1978 if (bScaleBatchSize != 1)
1980 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1981 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1989LogicalResult MatmulTBlockScaledOp::verify() {
1991 const Type aDataType = getAData().getType();
1992 const Type bDataType = getBData().getType();
1998 int64_t N = ShapedType::kDynamic;
1999 int64_t D = ShapedType::kDynamic;
2000 int64_t H = ShapedType::kDynamic;
2003 int64_t multiplesOfC = ShapedType::kDynamic;
2015 "a_scale",
"batch")) ||
2017 "a_scale",
"height")))
2025 "b_data",
"batch")) ||
2027 "b_data",
"channels")))
2035 "b_scale",
"batch")) ||
2037 "b_scale",
"width")) ||
2045 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2046 return emitOpError(
"expect B matrix batch size to be broadcast compatible "
2048 << D <<
" vs N=" << N;
2051 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2052 if (ShapedType::isStatic(C) && C % blockSize != 0)
2053 return emitOpError(
"expect C to be a multiple of block size, got C=")
2054 <<
C <<
", block_size=" << blockSize;
2057 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2058 multiplesOfC != C / blockSize)
2060 "expect scale operands dimension 2 to equal C/block_size (")
2061 <<
C <<
"/" << blockSize <<
")" <<
", got " << multiplesOfC;
2064 N = ShapedType::isDynamic(N) ? D : N;
2066 const auto outputType = cast<ShapedType>(getResult().
getType());
2067 if (outputType.hasRank() &&
2071 auto stringifyDim = [&](
int64_t d) {
2072 if (ShapedType::isDynamic(d))
2077 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2078 opError <<
" to be compatible with expected output shape ";
2079 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2086LogicalResult tosa::PadOp::inferReturnTypeComponents(
2087 MLIRContext *context, ::std::optional<Location> location,
2088 PadOp::Adaptor adaptor,
2090 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2092 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2097 if (!inputShape.hasRank()) {
2098 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2107 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2112 outputShape.reserve(inputShape.getRank());
2113 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2114 if (inputShape.isDynamicDim(i)) {
2115 outputShape.push_back(ShapedType::kDynamic);
2118 auto padFront = paddingValues[i * 2];
2119 auto padBack = paddingValues[i * 2 + 1];
2120 if (padFront < 0 || padBack < 0) {
2122 outputShape.push_back(ShapedType::kDynamic);
2126 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2133LogicalResult tosa::PadOp::verify() {
2140 if (
auto padConst = getPadConst()) {
2148 RankedTensorType inputType =
2149 llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2150 RankedTensorType outputType =
2151 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
2152 if (!inputType || !outputType)
2159 auto inputRank = inputType.getRank();
2164 auto paddingValues = paddingAttr.getValues<APInt>();
2165 if (paddingValues.size() !=
static_cast<size_t>(inputRank * 2))
2166 return emitOpError() <<
"padding tensor must have " << inputRank
2167 <<
" * 2 = " << inputRank * 2 <<
" elements, but got "
2168 << paddingValues.size();
2170 auto inputShape = inputType.getShape();
2171 auto outputShape = outputType.getShape();
2173 for (
int64_t i = 0; i < inputRank; ++i) {
2174 int64_t padStart = paddingValues[i * 2].getSExtValue();
2175 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2177 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2179 <<
"invalid padding values at dimension " << i
2180 <<
": values must be non-negative or -1 for dynamic padding, got ["
2181 << padStart <<
", " << padEnd <<
"]";
2185 if (inputShape[i] == ShapedType::kDynamic ||
2186 outputShape[i] == ShapedType::kDynamic)
2189 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2190 return emitOpError() <<
"mismatch in output shape at dimension " << i
2191 <<
": expected " << inputShape[i] <<
" + "
2192 << padStart <<
" + " << padEnd <<
" = "
2193 << (inputShape[i] + padStart + padEnd)
2194 <<
", but got " << outputShape[i];
2201LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2202 MLIRContext *context, ::std::optional<Location> location,
2203 SliceOp::Adaptor adaptor,
2212 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2220 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2223 if (inputShape.hasRank()) {
2224 for (
size_t i = 0; i < size.size(); i++) {
2225 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2226 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2227 start[i] < inputShape.getDimSize(i))) {
2229 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2232 outputShape[i] = size[i];
2236 if (size[i] == -1) {
2237 outputShape[i] = inputShape.getDimSize(i) - start[i];
2238 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2240 outputShape[i] = size[i];
2252LogicalResult tosa::SliceOp::verify() {
2259 if (inputShape.hasRank()) {
2260 const auto inputRank = inputShape.getRank();
2262 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2264 "expect input1 and output to have the same ranks, got ")
2265 << inputRank <<
" and " << outputShape.getRank();
2267 const auto startShapeRank =
2268 llvm::cast<tosa::shapeType>(getStart().
getType()).getRank();
2269 if (inputRank != startShapeRank)
2270 return emitOpError(
"length of start is not equal to rank of input shape");
2272 const auto sizeShapeRank =
2273 llvm::cast<tosa::shapeType>(getSize().
getType()).getRank();
2274 if (inputRank != sizeShapeRank)
2275 return emitOpError(
"length of size is not equal to rank of input shape");
2281LogicalResult tosa::MulOp::inferReturnTypeComponents(
2282 MLIRContext *context, ::std::optional<Location> location,
2297LogicalResult tosa::MulOp::verify() {
2298 const Value output = getOutput();
2303 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2304 IntegerType lhsIntType =
2306 IntegerType rhsIntType =
2308 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2309 return emitOpError(
"requires the same element type for all operands");
2314 if (lhsIntType.getWidth() > resIntType.getWidth())
2315 return emitOpError(
"invalid data type size for operands or result");
2320 for (
int i = 0; i < 2; ++i) {
2323 "requires the same element type for all operands and results");
2327 ElementsAttr shiftElem;
2329 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2331 return emitOpError() <<
"require shift to be 0 for float type";
2339 TypeRange operandTypes = getOperandTypes();
2340 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2341 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2343 const bool aHasRank = aType.hasRank();
2344 const bool bHasRank = bType.hasRank();
2345 if (aHasRank && bHasRank) {
2346 const int64_t aRank = aType.getRank();
2347 const int64_t bRank = bType.getRank();
2349 return emitOpError(
"a and b operands don't have matching ranks, got ")
2350 << aRank <<
" and " << bRank;
2355 aType.getShape(), bType.getShape(), resultShape))
2356 return emitOpError(
"a and b operands don't have broadcast-compatible "
2358 << aType <<
" and " << bType;
2361 ShapedType resultType = cast<ShapedType>(output.
getType());
2362 if (!resultType.hasRank())
2365 const int64_t resultRank = resultType.getRank();
2366 if (aHasRank && resultRank != aType.getRank())
2367 return emitOpError(
"result type has different rank than a, got ")
2368 << resultRank <<
" vs " << aType.getRank();
2369 if (bHasRank && resultRank != bType.getRank())
2370 return emitOpError(
"result type has different rank than b, got ")
2371 << resultRank <<
" vs " << bType.getRank();
2376LogicalResult tosa::TableOp::inferReturnTypeComponents(
2377 MLIRContext *context, ::std::optional<Location> location,
2378 TableOp::Adaptor adaptor,
2380 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2382 if (!inputShape.hasRank()) {
2387 inferredReturnShapes.resize(1);
2388 inputShape.getDims(inferredReturnShapes[0]);
2392LogicalResult tosa::TableOp::verify() {
2393 const TensorType inputType = getInput1().getType();
2394 const TensorType outputType = getOutput().getType();
2403 auto inputDims = inputType.
getShape();
2404 auto outputDims = outputType.
getShape();
2405 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2407 auto [inputDim, outputDim] = it.value();
2408 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2409 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2410 <<
" doesn't match dim(input, " << dim
2411 <<
") = " << inputDim;
2424 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2425 [](
const APInt &val) { return val.getSExtValue(); });
2429LogicalResult tosa::TileOp::inferReturnTypeComponents(
2430 MLIRContext *context, ::std::optional<Location> location,
2431 TileOp::Adaptor adaptor,
2438 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2446 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2448 if (!inputShape.hasRank()) {
2449 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2450 inferredReturnShapes.push_back(
2453 }
else if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2457 outputShape.reserve(multiples.size());
2458 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2459 if (multiples[i] == ShapedType::kDynamic) {
2460 outputShape.push_back(ShapedType::kDynamic);
2462 int64_t dim = inputShape.getDimSize(i);
2463 if (dim != ShapedType::kDynamic)
2464 dim *= multiples[i];
2465 outputShape.push_back(dim);
2473LogicalResult tosa::TileOp::verify() {
2479 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2480 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2482 shapeType multiplesType =
2483 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2485 auto multiplesRank = multiplesType.getRank();
2487 if (inputType.hasRank()) {
2488 if (inputType.getRank() != multiplesRank)
2489 return emitOpError(
"expect 'multiples' to have rank ")
2490 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2491 if (outputType.hasRank() &&
2495 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2496 return emitOpError(
"expect 'multiples' array to have length ")
2497 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
2500 if (getConstantMultiples(multiples).succeeded() &&
2501 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2503 "expect element of 'multiples' to be positive integer or -1.");
2509 if (l.size() != r.size() || l.size() != 1)
2514LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2515 MLIRContext *context, ::std::optional<Location> location,
2516 ReshapeOp::Adaptor adaptor,
2518 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2523 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2533 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2534 inferredReturnShapes.push_back(
2542 int64_t numElements = inputShape.getNumElements();
2544 for (
auto val : newShapeValue) {
2545 if (ShapedType::isStatic(val)) {
2551 for (
auto &val : newShapeValue) {
2552 if (ShapedType::isDynamic(val))
2553 val = numElements / staticMul;
2556 inferredReturnShapes.push_back(
2561llvm::LogicalResult tosa::ReshapeOp::verify() {
2567 TensorType inputType = getInput1().getType();
2572 return mlir::success();
2575 int missingDims = llvm::count(shapeValues, -1);
2576 if (missingDims > 1)
2577 return emitOpError() <<
"expected at most one target dimension to be -1";
2579 const auto outputType = dyn_cast<RankedTensorType>(
getType());
2583 if ((
int64_t)shapeValues.size() != outputType.getRank())
2584 return emitOpError() <<
"new shape does not match result rank";
2586 for (
auto [newShapeDim, outputShapeDim] :
2587 zip(shapeValues, outputType.getShape())) {
2588 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2589 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2590 return emitOpError() <<
"new shape is inconsistent with result shape";
2592 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2593 return emitOpError() <<
"new shape has invalid tensor dimension size "
2597 if (inputType.hasStaticShape()) {
2598 int64_t inputElementsNum = inputType.getNumElements();
2599 if (outputType.hasStaticShape()) {
2600 int64_t outputElementsNum = outputType.getNumElements();
2601 if (inputElementsNum != outputElementsNum) {
2602 return emitOpError() <<
"cannot reshape " << inputElementsNum
2603 <<
" elements into " << outputElementsNum;
2609 return (dim > 0) ?
acc * dim :
acc;
2611 bool isStaticNewShape =
2612 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
2613 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2614 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2615 return emitOpError() <<
"cannot reshape " << inputElementsNum
2616 <<
" elements into " << newShapeElementsNum;
2620 return mlir::success();
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();

template <typename T>
static LogicalResult verifyZeroPoint(T op, Value zpVal, int64_t zp,
                                     const std::string &operand) {

  if (!zpElemType.isInteger(8) && zp != 0) {
    // The TOSA spec only allows a non-zero zero point for int8 tensors.
    std::string lower = operand;
    llvm::transform(lower, lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";
  }

                                     const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;
  }
  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
  }

#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

#undef ZERO_POINT_HELPER
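// Each ZERO_POINT_HELPER instantiation expands to a pair of methods on the
// named op: get<Operand>ZeroPoint(), which extracts the constant zero point
// (sign- or zero-extending it according to SIGN_EXTEND), and
// verify<Operand>ZeroPoint(), which applies the verifyZeroPoint checks above.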
2722LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2723 MLIRContext *context, ::std::optional<Location> location,
2724 TransposeOp::Adaptor adaptor,
2726 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2735 const auto inputRank = inputShape.
getRank();
2739 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
2745 if (inputRank == 0) {
2751 bool allTheSame =
true;
2752 for (
int i = 1, s = inputRank; i < s; i++) {
2762 outputShape.resize(inputRank, inputShape.
getDimSize(0));
2767 outputShape.resize(inputRank, ShapedType::kDynamic);
2770 if (llvm::any_of(adaptor.getPerms(),
2771 [inputRank](
const auto i) { return i >= inputRank; }))
2774 outputShape.reserve(inputRank);
2775 for (
int i = 0, s = inputRank; i < s; i++) {
2776 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
LogicalResult tosa::TransposeOp::verify() {
  const ShapeAdaptor inputShape(getInput1().getType());
  const ShapeAdaptor outputShape(getOutput().getType());
  const llvm::ArrayRef<int32_t> constantPerms = getPerms();

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::map_to_vector(
          constantPerms, [](int32_t v) -> int64_t { return v; })))
    return emitOpError() << "expected valid permutation indices";

  // Verify that the input and output have the same number of elements.
  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                            "of elements, got "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  // Verify that the shapes agree under the permutation.
  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }

  return success();
}
LogicalResult TransposeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  const llvm::ArrayRef<int32_t> transposePerms = getPerms();

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
      returnedDims[dim] =
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
              .getResult();
    else
      returnedDims[dim] =
          builder.getIndexAttr(inputType.getDimSize(dimInInput));
  }

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
  return success();
}
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
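// Sketch (assumed IR syntax): tosa.gather takes values of shape [N, K, C] and
// indices of shape [N, W], producing an output of shape [N, W, C], e.g.
//   %0 = tosa.gather %values, %indices
//        : (tensor<2x8x4xf32>, tensor<2x5xi32>) -> tensor<2x5x4xf32>
// The inference above fills N and C from the values operand and W from the
// indices operand whenever either operand is ranked.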
LogicalResult tosa::GatherOp::verify() {
  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  int64_t n = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    n = valuesShape.getDimSize(0);
    c = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << n
                           << ", got " << outputN;

    if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        w != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << w
                           << ", got " << outputW;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << c
                           << ", got " << outputC;
  }

  return success();
}
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  SmallVector<int64_t> scaleInt, offsetInt, borderInt;

  // output_size = (input_size - 1) * scale_n - offset + border, divided by
  // scale_d, plus 1.
  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ",
        outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
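// Worked example (values assumed for illustration): with input height 16,
// scale_y_n = 2, scale_y_d = 1, offset_y = 0 and border_y = 0 the formula
// above gives
//   output_height = ((16 - 1) * 2 - 0 + 0) / 1 + 1 = 31.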
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  SmallVector<int64_t> scaleValues;
  SmallVector<int64_t> offsetValues;
  SmallVector<int64_t> borderValues;

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues[0] << ", " << scaleValues[1] << ", "
           << scaleValues[2] << ", " << scaleValues[3];

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  if (!inputType || !outputType)
    return success();

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  // Don't check with an input height of 1, as the calculation is ambiguous.
  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  // Don't check with an input width of 1, as the calculation is ambiguous.
  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult tosa::ScatterOp::verify() {
  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  int64_t n = ShapedType::kDynamic;
  int64_t k = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    n = valuesInShape.getDimSize(0);
    k = valuesInShape.getDimSize(1);
    c = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (n == ShapedType::kDynamic)
      n = inputN;
    else if (inputN != ShapedType::kDynamic && n != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << n
                           << ", got " << inputN;
    if (w == ShapedType::kDynamic)
      w = inputW;
    else if (inputW != ShapedType::kDynamic && w != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << w
                           << ", got " << inputW;

    if (c == ShapedType::kDynamic)
      c = inputC;
    else if (inputC != ShapedType::kDynamic && c != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << c
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << n << ", got " << outputN;
    if (k == ShapedType::kDynamic)
      k = outputK;
    else if (outputK != ShapedType::kDynamic && k != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << k << ", got " << outputK;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << c << ", got " << outputC;
  }
  if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
    return emitOpError() << "requires dimensions K >= W, got K=" << k
                         << " and W=" << w;

  return success();
}
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }
  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER
#undef COMPATIBLE_RETURN_TYPES
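// The reduce macros above are instantiated once per reduction op, for example
// REDUCE_SHAPE_INFER(tosa::ReduceSumOp) (shown here for illustration only).
// Each reduction keeps the input rank and forces the reduced axis to size 1:
// reducing a tensor<2x3x4xf32> over axis 1 yields a tensor<2x1x4xf32>.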
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  // All TOSA reduce ops have input, output and axis.
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output shape has rank 0 and the
  // axis is also 0.
  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
    op.emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
    return failure();
  }
  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  // The reduced dimension size can only be checked when the output rank is
  // non-zero.
  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
      return failure();
    }
  }

  return success();
}
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  // Verify the same shape for input1 and output.
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

  const SmallVector<Type, 2> types = {input1Type, outputType};
  if (failed(verifyCompatibleShapes(types)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
  outputShape.resize(4, ShapedType::kDynamic);

  if (ShapedType::isStatic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (ShapedType::isStatic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
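// Worked example (values assumed): a 1x8x8x3 input with kernel = [2, 2],
// stride = [2, 2] and pad = [0, 0, 0, 0] gives
//   padded = 8 + 0 + 0 - 2 = 6 and output dim = 6 / 2 + 1 = 4,
// i.e. a 1x4x4x3 output.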
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
      verifyConvOpErrorIf(*this).failed())
    return failure();
  return success();
}
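// Worked example of the output-size computation above (values assumed):
// inputHeight = 32, pad = [1, 1, ...], weightHeight = 3, dilation = [1, ...],
// stride = [2, ...]:
//   inputSize  = 32 + 1 + 1 = 34
//   filterSize = (3 - 1) * 1 + 1 = 3
//   unstrided  = 34 - 3 + 1 = 32
//   output H   = (32 - 1) / 2 + 1 = 16 (integer division)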
LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input data/scale shapes describe the batch and input height/width.
  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
  if (inputDataShape.hasRank()) {
    outShape[0] = inputDataShape.getDimSize(0);
    inputHeight = inputDataShape.getDimSize(1);
    inputWidth = inputDataShape.getDimSize(2);
  }
  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
  if (inputScaleShape.hasRank()) {
    outShape[0] = ShapedType::isDynamic(outShape[0])
                      ? inputScaleShape.getDimSize(0)
                      : outShape[0];
    inputHeight = ShapedType::isDynamic(inputHeight)
                      ? inputScaleShape.getDimSize(1)
                      : inputHeight;
    inputWidth = ShapedType::isDynamic(inputWidth)
                     ? inputScaleShape.getDimSize(2)
                     : inputWidth;
  }

  // Weight data/scale shapes describe the output channels and kernel size.
  const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
  if (weightDataShape.hasRank()) {
    outShape[3] = weightDataShape.getDimSize(0);
    weightHeight = weightDataShape.getDimSize(1);
    weightWidth = weightDataShape.getDimSize(2);
  }
  const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
  if (weightScaleShape.hasRank()) {
    outShape[3] = ShapedType::isDynamic(outShape[3])
                      ? weightScaleShape.getDimSize(0)
                      : outShape[3];
    weightHeight = ShapedType::isDynamic(weightHeight)
                       ? weightScaleShape.getDimSize(1)
                       : weightHeight;
    weightWidth = ShapedType::isDynamic(weightWidth)
                      ? weightScaleShape.getDimSize(2)
                      : weightWidth;
  }

  // The bias shape can describe the output channels, unless it broadcasts.
  const ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    const int64_t biasSize = biasShape.getDimSize(0);
    if (biasSize != 1) {
      outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
    }
  }

  SmallVector<int64_t> padValues;
  SmallVector<int64_t> strideValues;
  SmallVector<int64_t> dilationValues;
  // Without constant pad/stride/dilation values the spatial dimensions cannot
  // be computed.
  if (padValues.empty() || strideValues.empty() || dilationValues.empty()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
    return success();
  }

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
    const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
    const int64_t unstridedResult = inputSize - filterSize + 1;
    outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
    const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
    const int64_t unstridedResult = inputSize - filterSize + 1;
    outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
LogicalResult Conv2DBlockScaledOp::verify() {
  if (failed(verifySameElementTypes(*this, getInputData().getType(),
                                    getWeightData().getType(), "input_data",
                                    "weight_data")) ||
      failed(verifySameElementTypes(*this, getInputScale().getType(),
                                    getWeightScale().getType(), "input_scale",
                                    "weight_scale")) ||
      failed(verifySameElementTypes(*this, getBias().getType(),
                                    getOutput().getType(), "bias", "output")))
    return failure();

  int64_t N = ShapedType::kDynamic;
  int64_t IH = ShapedType::kDynamic;
  int64_t IW = ShapedType::kDynamic;
  int64_t IC = ShapedType::kDynamic;
  int64_t multiplesOfIC = ShapedType::kDynamic;
  int64_t OC = ShapedType::kDynamic;
  int64_t KH = ShapedType::kDynamic;
  int64_t KW = ShapedType::kDynamic;

  const ShapeAdaptor inputDataShape(getInputData().getType());
  if (inputDataShape.hasRank()) {
    N = inputDataShape.getDimSize(0);
    IH = inputDataShape.getDimSize(1);
    IW = inputDataShape.getDimSize(2);
    IC = inputDataShape.getDimSize(3);
  }

  const ShapeAdaptor inputScaleShape(getInputScale().getType());
  if (inputScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
                                     "input_scale", "batch size")) ||
        failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
                                     "input_scale", "input height")) ||
        failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
                                     "input_scale", "input width")))
      return failure();
    multiplesOfIC = inputScaleShape.getDimSize(3);
  }

  const ShapeAdaptor weightDataShape(getWeightData().getType());
  if (weightDataShape.hasRank()) {
    OC = weightDataShape.getDimSize(0);
    KH = weightDataShape.getDimSize(1);
    KW = weightDataShape.getDimSize(2);
    if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
                                     "weight_data", "input channels")))
      return failure();
  }

  const ShapeAdaptor weightScaleShape(getWeightScale().getType());
  if (weightScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
                                     "weight_scale", "output channels")) ||
        failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
                                     "weight_scale", "kernel height")) ||
        failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
                                     "weight_scale", "kernel width")) ||
        failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
                                     weightScaleShape.getDimSize(3),
                                     "weight_scale", "input channel blocks")))
      return failure();
  }

  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
    return emitOpError("expect IC to be a multiple of block size, got IC=")
           << IC << ", block_size=" << blockSize;

  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
      multiplesOfIC != IC / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal IC/block_size (")
           << IC << "/" << blockSize << ")"
           << ", got " << multiplesOfIC;

  SmallVector<int64_t> padValues;
  if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
    return emitOpError("expect all padding values to be >= 0, got ")
           << padValues[0] << ", " << padValues[1] << ", " << padValues[2]
           << ", " << padValues[3];

  SmallVector<int64_t> strideValues;
  if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
    return emitOpError("expect all stride values to be >= 1, got ")
           << strideValues[0] << ", " << strideValues[1];

  SmallVector<int64_t> dilationValues;
  if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
    return emitOpError("expect all dilation values to be >= 1, got ")
           << dilationValues[0] << ", " << dilationValues[1];

  const ShapeAdaptor outputShape(getOutput().getType());
  if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
      outputShape.hasRank()) {
    if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
                                    padValues[0], padValues[1], strideValues[0],
                                    dilationValues[0], "height", "y", "top",
                                    "bottom")) ||
        failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
                                    padValues[2], padValues[3], strideValues[1],
                                    dilationValues[1], "width", "x", "left",
                                    "right")))
      return failure();
  }

  const ShapeAdaptor biasShape(getBias().getType());
  if (biasShape.hasRank() && outputShape.hasRank()) {
    const int64_t biasChannels = biasShape.getDimSize(0);
    const int64_t outputChannels =
        outputShape.getDimSize(outputShape.getRank() - 1);
    if (biasChannels == ShapedType::kDynamic ||
        outputChannels == ShapedType::kDynamic)
      return success();

    if (biasChannels != outputChannels && biasChannels != 1)
      return emitOpError(
                 "bias channels expected to be equal to output channels (")
             << outputChannels << ") or 1, got " << biasChannels;
  }

  return success();
}
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input width/height/depth and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter size and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();

  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
      verifyConvOpErrorIf(*this).failed())
    return failure();
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::verify() {
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are known, the output channels
  // can be determined.
  if (ShapedType::isStatic(inputChannels) &&
      ShapedType::isStatic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
    int64_t bc = biasShape.getDimSize(0);
    if (bc != ShapedType::kDynamic && bc != 1)
      outputShape[3] = bc;
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult DepthwiseConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
      verifyConvOpErrorIf(*this).failed())
    return failure();
  return success();
}
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
    int64_t bc = biasShape.getDimSize(0);
    if (bc != ShapedType::kDynamic && bc != 1)
      outputShape[3] = bc;
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
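// Worked example of the transpose-conv output size above (values assumed):
// inputHeight = 4, stride_y = 2, out_pad_top = 0, out_pad_bottom = 0 and
// kernelHeight = 3 gives OH = (4 - 1) * 2 + 0 + 0 + 3 = 9.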
LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
             llvm::StringRef kernelDimName) -> LogicalResult {
    if (padValue <= -kernelDimSize)
      return emitOpError("expected ")
             << padName << " > -" << kernelDimName << ", but got: " << padName
             << "=" << padValue << " and " << kernelDimName << "="
             << kernelDimSize;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // The remaining checks require a ranked output type.
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType = getStorageElementTypeOrSelf(inputType);
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType = getStorageElementTypeOrSelf(outputType);
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // The multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // The multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // The multiplier and shift must hold one value per channel when
  // per_channel is true, otherwise a single value.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}
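// Illustration of the per-channel check above (assumed shapes): with
// per_channel = true and an input of shape tensor<1x8x16xi8>, the multiplier
// and shift operands are expected to be rank-1 tensors holding 16 elements,
// one per channel of the last input dimension; with per_channel = false they
// hold a single element.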
LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    CastFromBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}
LogicalResult CastFromBlockScaledOp::verify() {
  const Type inputDataType = getInputData().getType();
  const Type outputDataType = getResult().getType();
  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
    return emitOpError() << "require compatible shapes for input_data ("
                         << inputDataType << ") and " << "output_data ("
                         << outputDataType << ")";

  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);

  if (inputDataShape.hasRank()) {
    const unsigned int blockSize =
        BlockSizeAttr::getBlockSizeValue(getBlockSize());
    const int64_t inputDataLastDim =
        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
    if (inputDataLastDim % blockSize != 0)
      return emitOpError() << "expect last dimension of input_data ("
                           << inputDataLastDim
                           << ") to be divisible by block_size (" << blockSize
                           << ")";

    const Type inputScaleType = getInputScale().getType();
    const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);

    if (inputScaleShape.hasRank()) {
      SmallVector<int64_t> inputDataDims, inputScaleDims;
      inputDataShape.getDims(inputDataDims);
      inputScaleShape.getDims(inputScaleDims);

      if (inputDataDims.size() != inputScaleDims.size() ||
          failed(verifyCompatibleShape(
              ArrayRef<int64_t>(inputDataDims).drop_back(1),
              ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
        return emitOpError()
               << "require compatible shapes for input_data (" << inputDataType
               << ") and " << "input_scale (" << inputScaleType
               << ") except for the last dimension";

      const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
                                                inputScaleDims.back()};
      if (ShapedType::isStatic(inputDataLastDim) &&
          failed(verifyCompatibleDims(dimsToCheck)))
        return emitOpError()
               << "expect last dimension of input_scale ("
               << inputScaleDims.back()
               << ") to be equal to last dimension of input_data / block_size ("
               << inputDataDims.back() / blockSize << ")";
    }
  }

  return success();
}
LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    CastToBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  if (!inputShape.hasRank())
    return success();

  // The output scale shape matches the input except for the last dimension,
  // which is divided by the block size.
  SmallVector<int64_t> outputScaleShape;
  inputShape.getDims(outputScaleShape);
  const int64_t lastDimLoc = inputShape.getRank() - 1;
  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
  if (ShapedType::isStatic(lastDimSize)) {
    const unsigned int blockSize =
        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
  }
  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
  return success();
}
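// Sketch (block_size assumed to be 32): casting a tensor<4x64xf32> input
// produces output_data of shape 4x64 and output_scale of shape 4x2, since
// 64 / 32 = 2 scale values are emitted per row.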
LogicalResult CastToBlockScaledOp::verify() {
  const Type inputDataType = getInputData().getType();
  const Type outputDataType = getResult(0).getType();
  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
    return emitOpError() << "require compatible shapes for input_data ("
                         << inputDataType << ") and " << "output_data ("
                         << outputDataType << ")";

  const unsigned int blockSize =
      BlockSizeAttr::getBlockSizeValue(getBlockSize());
  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
  if (inputDataShape.hasRank()) {
    const int64_t inputDataLastDim =
        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
    if (ShapedType::isStatic(inputDataLastDim) &&
        inputDataLastDim % blockSize != 0)
      return emitOpError() << "expect last dimension of input_data ("
                           << inputDataLastDim
                           << ") to be divisible by block_size (" << blockSize
                           << ")";
  }

  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
  const Type outputScaleType = getResult(1).getType();
  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
    SmallVector<int64_t> outputDataDims, outputScaleDims;
    outputDataShape.getDims(outputDataDims);
    outputScaleShape.getDims(outputScaleDims);

    if (outputDataDims.size() != outputScaleDims.size() ||
        failed(verifyCompatibleShape(
            ArrayRef<int64_t>(outputDataDims).drop_back(1),
            ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
      return emitOpError() << "require compatible shapes for output_data ("
                           << outputDataType << ") and " << "output_scale ("
                           << outputScaleType
                           << ") except for the last dimension";

    const int64_t outputDataLastDim = outputDataDims.back();
    const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
                                              outputScaleDims.back()};
    if (ShapedType::isStatic(outputDataLastDim) &&
        failed(verifyCompatibleDims(dimsToCheck)))
      return emitOpError()
             << "expect last dimension of output_scale ("
             << outputScaleDims.back()
             << ") to be equal to last dimension of output_data / block_size ("
             << outputDataDims.back() / blockSize << ")";
  }

  return success();
}
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Seed the result knowledge from the first yield.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Seed the result knowledge from the first yield.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
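// The helper below prints the loop-carried initialization list of IfOp/WhileOp
// in the form "<prefix>(%arg0 = %init0, %arg1 = %init1, ...)"; for a single
// carried value this produces, for example, " (%arg0 = %0)" (example output
// assumed for illustration).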
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ')';
}
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then' and 'else'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  OpAsmParser::UnresolvedOperand cond;

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  // Parse the condition operand and its type.
  if (parser.parseOperand(cond) || parser.parseColon())
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  Type condType;
  if (parser.parseType(condType) ||
      parser.resolveOperand(cond, condType, result.operands))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  FunctionType functionType;
  if (parser.parseType(functionType))
    return parser.emitError(parser.getCurrentLocation())
           << "expected list of types for block arguments "
           << "followed by arrow type and list of return types";

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands " << "(expected "
           << operands.size() << " got " << functionType.getNumInputs()
           << ")";
  }
void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();
  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");
  p << " : ";
  p << getCondition().getType();
  if (!getInputList().empty()) {
    p << ", ";
    llvm::interleaveComma(getInputList().getTypes(), p);
  }
  p.printArrowTypeList(getResultTypes());

  p << ' ';
  p.printRegion(getThenGraph());

  // Print the 'else' region only if it is non-empty.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion);
  }
}
LogicalResult IfOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getThenGraph().front().mightHaveTerminator()) {
    auto thenYield =
        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
    if (thenYield && errorIfTypeOrShapeMismatch(
                         *this, thenYield.getInputs(), "'then_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  if (getElseGraph().front().mightHaveTerminator()) {
    auto elseYield =
        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
    if (elseYield && errorIfTypeOrShapeMismatch(
                         *this, elseYield.getInputs(), "'else_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;

  return success();
}
LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield && errorIfTypeOrShapeMismatch(
                         *this, bodyYield.getInputs(), "'body_graph' results",
                         getInputList(), "'input_list'")
                         .failed())
      return failure();
  }

  // Check the condition graph yields a single boolean value.
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output shape has rank 0 and the
  // axis is also 0.
  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
LogicalResult tosa::SelectOp::verify() {
  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }

  return success();
}
LogicalResult tosa::VariableReadOp::verify() {
  return verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'");
}

LogicalResult tosa::VariableWriteOp::verify() {
  return verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'");
}
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  FunctionType functionType;
  if (parser.parseColonType(functionType))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands " << "(expected "
           << operands.size() << " got " << functionType.getNumInputs() << ")";
  }

  // Resolve the input operands against the parsed function type.
  if (parser.resolveOperands(operands, functionType.getInputs(),
                             parser.getCurrentLocation(), result.operands))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
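// Sketch of the IR produced by the helper above (syntax assumed): for an i8
// element type and zp = 0 it materializes a rank-1 splat constant such as
//   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// Float element types go through the getFloatAttr path, and unsupported
// element types yield std::nullopt.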
bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult
mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
                              int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp)
        return op->emitOpError("shape operand is not compile time resolvable");
    }
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}
LogicalResult tosa::ConstShapeOp::verify() {
  // Check that the number of elements in the values attribute matches the
  // rank of the result shape type.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");

  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (count != rank && (count != 1 || rank != 0)) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
LogicalResult tosa::DimOp::verify() {
  const tosa::shapeType outShapeType =
      cast<tosa::shapeType>(getResult().getType());
  if (outShapeType.getRank() != 1)
    return emitOpError("expect output shape type to contain one element, got ")
           << outShapeType;

  const int64_t inputRank = cast<ShapedType>(getInput().getType()).getRank();
  const int64_t axis = getAxisAttr().getInt();
  if (axis < 0 || axis >= inputRank)
    return emitOpError("expect axis to be in the range [0, ")
           << inputRank << "), got " << axis;

  return success();
}
LogicalResult tosa::ConcatShapeOp::verify() {
  const tosa::shapeType outShapeType =
      cast<tosa::shapeType>(getResult().getType());
  const int64_t outputRank = outShapeType.getRank();

  const auto inputList = getOperands();
  if (inputList.size() == 0)
    return emitOpError("requires at least one input shape");

  if (llvm::any_of(inputList, [](Value v) {
        return cast<tosa::shapeType>(v.getType()).getRank() == 0;
      }))
    return emitOpError("requires all inputs shapes have a rank greater than 0");

  const int64_t inputsRank =
      llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
        const tosa::shapeType inShapeType =
            cast<tosa::shapeType>(input.getType());
        return acc + inShapeType.getRank();
      });
  if (outputRank != inputsRank)
    return emitOpError("requires output shape rank to be equal to the sum of "
                       "the input shape ranks (")
           << inputsRank << "), got " << outputRank;

  return success();
}
LogicalResult tosa::SliceShapeOp::verify() {
  std::optional<int32_t> start;
  DenseIntElementsAttr startAttr;
  if (matchPattern(getStart(), m_Constant(&startAttr)))
    start = startAttr.getValues<int32_t>()[0];
  if (start && start.value() < 0)
    return emitOpError("expected non-negative start index, got ")
           << start.value();

  std::optional<int32_t> size;
  DenseIntElementsAttr sizeAttr;
  if (matchPattern(getSize(), m_Constant(&sizeAttr)))
    size = sizeAttr.getValues<int32_t>()[0];
  if (size && size.value() <= 0)
    return emitOpError("expected positive size, got ") << size.value();

  if (size) {
    const tosa::shapeType outShapeType =
        cast<tosa::shapeType>(getResult().getType());
    const int64_t outputRank = outShapeType.getRank();
    if (outputRank != size)
      return emitOpError(
                 "expected output type size to be equal to size attribute, got ")
             << outputRank << " vs " << size.value();
  }

  if (start && size) {
    const tosa::shapeType inShapeType =
        cast<tosa::shapeType>(getInput().getType());
    const int64_t inputRank = inShapeType.getRank();
    const int64_t sliceSize = start.value() + size.value();
    if (sliceSize > inputRank)
      return emitOpError("expected start + size to be less than or equal to "
                         "input shape rank (")
             << inputRank << "), got " << sliceSize;
  }

  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)