29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
// Inliner interface for the TOSA dialect.
// NOTE(review): this extract is elided -- the fused original line numbers
// jump 57 -> 65 -> 71, so the hook signatures that own the two
// `IRMapping &map` parameters below are missing from this view. The only
// visible policy is that the second hook permits inlining when the
// destination region is nested inside a tosa::IfOp or tosa::WhileOp.
// Confirm against the full file before editing in place.
56struct TosaInlinerInterface :
public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
65 IRMapping &map)
const final {
71 IRMapping &map)
const final {
// Only allow inlining into the regions of TOSA control-flow ops.
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
// Bytecode (de)serialization interface for the TOSA dialect. The visible
// bodies delegate to the generated free functions ::readAttribute /
// ::writeAttribute and ::readType / ::writeType (from
// TosaDialectBytecode.cpp.inc, included above). readVersion emits
// "Dialect does not support versioning", i.e. the dialect declares no
// bytecode version.
// NOTE(review): several hook bodies are elided from this extract (fused
// line numbers jump, e.g. 85 -> 89 and 106 -> 110); do not restructure
// without the complete file.
79 TosaDialectBytecodeInterface(Dialect *dialect)
80 : BytecodeDialectInterface(dialect) {}
85 Attribute readAttribute(DialectBytecodeReader &reader)
const override {
89 LogicalResult writeAttribute(Attribute attr,
90 DialectBytecodeWriter &writer)
const override {
91 return ::writeAttribute(attr, writer);
97 Type readType(DialectBytecodeReader &reader)
const override {
101 LogicalResult writeType(Type type,
102 DialectBytecodeWriter &writer)
const override {
103 return ::writeType(type, writer);
106 void writeVersion(DialectBytecodeWriter &writer)
const final {
// No versioning scheme: reading a version is always an error.
110 std::unique_ptr<DialectVersion>
111 readVersion(DialectBytecodeReader &reader)
const final {
113 reader.
emitError(
"Dialect does not support versioning");
117 LogicalResult upgradeFromVersion(Operation *topLevelOp,
118 const DialectVersion &version)
const final {
131 return {&getBodyGraph()};
140 return dim == -1 ? ShapedType::kDynamic : dim;
146 Type elementType = variableOp.getType();
149 return RankedTensorType::get(
shape, elementType);
// Dialect initialization: registers the generated types, ops and attributes,
// installs the bytecode + inliner interfaces defined above, and promises a
// shard::ShardingInterface implementation for the listed elementwise/compute
// ops (resolved lazily by the sharding extension).
// NOTE(review): the addOperations<...>/addTypes<...>/addAttributes<...>
// wrapper lines are elided from this extract (original lines 157, 160-162,
// 164-165 and 168 are missing); the #define/#include pairs below only make
// sense inside those calls -- verify against the full file.
156void TosaDialect::initialize() {
158#define GET_TYPEDEF_LIST
159#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
163#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
166#define GET_ATTRDEF_LIST
167#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
169 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
170 declarePromisedInterfaces<
171 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
172 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
173 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
174 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
175 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
176 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
177 GreaterEqualOp, MatMulOp>();
// Fragment of TosaDialect::materializeConstant: a folded
// DenseIntElementsAttr is materialized as tosa.const_shape when the
// requested type is the TOSA shape type; any other ElementsAttr becomes a
// tosa.const.
// NOTE(review): the enclosing function signature and the final fallthrough
// (presumably `return nullptr;`) lie outside this extract -- TODO confirm.
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return tosa::ConstShapeOp::create(builder, loc, type,
186 llvm::cast<DenseIntElementsAttr>(value));
188 if (llvm::isa<ElementsAttr>(value))
189 return tosa::ConstOp::create(builder, loc, type,
190 llvm::cast<ElementsAttr>(value));
200ParseResult getShapeAndElementType(
OpAsmParser &parser,
Type parsedType,
202 TypeAttr &typeAttr) {
203 if (
auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
206 <<
"expected ranked type";
208 auto elementType = shapedType.getElementType();
209 typeAttr = TypeAttr::get(elementType);
216 <<
"expected shaped type";
233 <<
"expected attribute";
235 if (
auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
240 <<
"expected Typed attr";
243 initialValueAttr =
nullptr;
247 <<
"expected type after colon";
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
254 TypeAttr typeAttr,
Attribute initialValueAttr) {
255 bool needsSpace =
false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
262 auto tensorTypeAttr = TypeAttr::get(tensorType);
267 if (initialValueAttr) {
278template <
typename EnumType>
279ParseResult parseAttrEntryWithEnumHandling(
OpAsmParser &parser,
281 llvm::StringRef name;
288 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
289 if (name ==
"rounding_mode" &&
291 auto sym = symbolizeRoundingMode(kw);
294 <<
"invalid rounding_mode value: " << kw;
295 auto attr = RoundingModeAttr::get(parser.
getContext(), sym.value());
301 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
303 auto sym = symbolizeResizeMode(kw);
306 <<
"invalid resize mode value: " << kw;
307 auto attr = ResizeModeAttr::get(parser.
getContext(), sym.value());
314 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
316 auto sym = symbolizeNanPropagationMode(kw);
319 <<
"invalid nan_mode value: " << kw;
320 auto attr = NanPropagationModeAttr::get(parser.
getContext(), sym.value());
327 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
329 auto sym = symbolizeBlockSize(kw);
332 <<
"invalid block_size value: " << kw;
333 auto attr = BlockSizeAttr::get(parser.
getContext(), sym.value());
345template <
typename EnumType>
350 [&]() { return parser.parseOperand(operands.emplace_back()); }))
358 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
375 result.addTypes(fnTy.getResults());
376 result.addAttributes(attrs);
382 parser << namedAttr.
getName().strref() <<
" = ";
384 if (
auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
385 parser << roundingModeAttr.getValue();
386 }
else if (
auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
387 parser << resizeModeAttr.getValue();
388 }
else if (
auto nanPropagationModeAttr =
389 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
390 parser << nanPropagationModeAttr.getValue();
391 }
else if (
auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
392 parser << blockSizeAttr.getValue();
405 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
407 if (
auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
408 if (nanAttr.getValue() == kDefaultNanValue) {
410 toPrint.erase(attr.getName());
416 if (!toPrint.empty()) {
418 llvm::interleaveComma(toPrint, parser, [&](
const NamedAttribute namedAttr) {
419 printNamedAttr(parser, namedAttr);
435 llvm::interleaveComma(op->
getAttrs(), parser,
437 printNamedAttr(parser, namedAttr);
449 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
453 printWithEnumHandling(parser, *
this);
457 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
461 printWithEnumHandling(parser, *
this);
465 return parseWithEnumHandling<tosa::ResizeMode>(parser,
result);
469 printWithEnumHandling(parser, *
this);
473 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
477 printWithNanPropagationHandling(parser, *
this);
481 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
485 printWithNanPropagationHandling(parser, *
this);
489 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
493 printWithNanPropagationHandling(parser, *
this);
497 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
501 printWithNanPropagationHandling(parser, *
this);
505 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
509 printWithNanPropagationHandling(parser, *
this);
513 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
517 printWithNanPropagationHandling(parser, *
this);
521 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
525 printWithNanPropagationHandling(parser, *
this);
528ParseResult MatmulTBlockScaledOp::parse(
OpAsmParser &parser,
530 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
534 printWithEnumHandling(parser, *
this);
537ParseResult CastFromBlockScaledOp::parse(
OpAsmParser &parser,
539 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
542void CastFromBlockScaledOp::print(
OpAsmPrinter &parser) {
543 printWithEnumHandling(parser, *
this);
546ParseResult CastToBlockScaledOp::parse(
OpAsmParser &parser,
548 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
552 printWithEnumHandling(parser, *
this);
555ParseResult Conv2DBlockScaledOp::parse(
OpAsmParser &parser,
557 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
561 printWithEnumHandling(parser, *
this);
576 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
586 Value valZp, StringRef name) {
591 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
595 if (!bothInts || !sameBitWidth) {
597 <<
"expected " << name <<
" and " << name
598 <<
"_zp to both be integer of the same bitwidth, but got " << eType
599 <<
" vs. " << eZpType;
606 Value src, int32_t val) {
609 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
610 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
611 const auto padConstAttr{
612 llvm::isa<FloatType>(srcElemType)
617 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
621 if (dyn_cast<tosa::mxint8Type>(type))
630 const StringRef operandName,
631 const StringRef dimName) {
632 if (ShapedType::isDynamic(currDim)) {
635 }
else if (ShapedType::isStatic(newDim) && currDim != newDim) {
637 << dimName <<
" of " << operandName <<
" to match size " << currDim
638 <<
", got " << newDim;
646 const int64_t stride,
const int64_t dilation,
const llvm::StringRef dimName,
647 const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
648 const llvm::StringRef padAfterName) {
649 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
654 const std::optional<int64_t> calculatedOutSizeMinusOne =
idivCheck(
655 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
657 if (!calculatedOutSizeMinusOne.has_value())
659 << dimName <<
" - 1 + pad_" << padBeforeName <<
" + pad_"
660 << padAfterName <<
" - (kernel_" << dimName <<
" - 1) * dilation_"
661 << dimAxis <<
" to be wholly divisible by stride_" << dimAxis
662 <<
", got (" << inputSize <<
" - 1 + " << padBefore <<
" + "
663 << padAfter <<
" - (" << kernelSize <<
" - 1) * " << dilation
666 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
667 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
669 << dimName <<
" did not match expected: "
670 <<
"calculated=" << calculatedOutSize <<
", expected=" << outputSize;
681 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
682 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
684 auto inputEType = inputType.getElementType();
685 auto weightEType = weightType.getElementType();
687 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
689 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
690 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
691 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
693 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
696 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
699 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
702 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
705 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
709 "expect both bias and result to have same element type, got ")
710 << biasEType <<
" and " << resultEType;
714 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
715 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
716 if (inputEType != weightEType) {
718 "expect both input and weight to have same element type, got ")
719 << inputEType <<
" and " << weightEType;
724 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
725 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
728 if (inputIsFloat != weightIsFloat) {
730 "expect both input and weight to be float or not together, got ")
731 << inputEType <<
" and " << weightEType;
736 if (inputEType != inputZpEType) {
737 return op.emitOpError(
"expect both input and its zero point are the same "
738 "element type, got ")
739 << inputEType <<
" and " << inputZpEType;
743 if (weightEType != weightZpEType) {
744 return op.emitOpError(
"expect both weight and its zero point are the same "
745 "element type, got ")
746 << weightEType <<
" and " << weightZpEType;
749 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
750 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
753 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
754 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
// Verifier for tosa.const (fragment -- the `return failure()` /
// `return success()` lines between the visible checks are elided, as is a
// branch body for the quantized-output case at original lines 772-775).
// Visible checks: both the values attribute and the result must be tensor
// types, and (unless the output element type is quantized) their element
// types must match.
760LogicalResult tosa::ConstOp::verify() {
762 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().
getType());
763 auto outputType = llvm::dyn_cast<TensorType>(getOutput().
getType());
765 if (!attrType || !outputType) {
766 emitOpError(
"expected tensors for attr/result type");
770 if (
auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
771 outputType.getElementType())) {
776 if (attrType.getElementType() != outputType.getElementType()) {
777 emitOpError(
"expected same attr/result element types");
787 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
789 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
792 auto accType = op.getAccType();
793 if (inputEType.isInteger(8) && !accType.isInteger(32))
794 return op.emitOpError(
"accumulator type for i8 tensor is not i32, got ")
797 if (inputEType.isInteger(16) && !accType.isInteger(48))
798 return op.emitOpError(
"accumulator type for i16 tensor is not i48, got ")
801 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
802 !(accType.isF16() || accType.isF32()))
803 return op.emitOpError(
"accumulator type for f8 tensor is not f16/f32, got ")
806 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
807 return op.emitOpError(
808 "accumulator type for f16 tensor is not f16/f32, got ")
811 if (inputEType.isBF16() && !accType.isF32())
812 return op.emitOpError(
"accumulator type for bf16 tensor is not f32, got ")
815 if (inputEType.isF32() && !accType.isF32())
816 return op.emitOpError(
"accumulator type for f32 tensor is not f32, got ")
820 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
822 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
836 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
837 return op.emitOpError(
"expect all padding values to be >= 0, got ")
841 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
842 return op.emitOpError(
"expect all stride values to be >= 1, got ")
846 if (llvm::any_of(dilations, [](
int64_t d) {
return d < 1; }))
847 return op.emitOpError(
"expect all dilation values to be >= 1, got ")
850 const RankedTensorType outputType =
851 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
856 const RankedTensorType inputType =
857 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
858 const RankedTensorType weightType =
859 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
861 if (inputType && weightType) {
863 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
865 op, inputType.getDimSize(1), weightType.getDimSize(1),
866 outputType.getDimSize(1), padding[0], padding[1], strides[0],
867 dilations[0],
"height",
"y",
"top",
"bottom")))
871 op, inputType.getDimSize(2), weightType.getDimSize(2),
872 outputType.getDimSize(2), padding[2], padding[3], strides[1],
873 dilations[1],
"width",
"x",
"left",
"right")))
878 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
880 op, inputType.getDimSize(1), weightType.getDimSize(0),
881 outputType.getDimSize(1), padding[0], padding[1], strides[0],
882 dilations[0],
"height",
"y",
"top",
"bottom")))
886 op, inputType.getDimSize(2), weightType.getDimSize(1),
887 outputType.getDimSize(2), padding[2], padding[3], strides[1],
888 dilations[1],
"width",
"x",
"left",
"right")))
893 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
895 op, inputType.getDimSize(1), weightType.getDimSize(1),
896 outputType.getDimSize(1), padding[0], padding[1], strides[0],
897 dilations[0],
"depth",
"d",
"front",
"back")))
901 op, inputType.getDimSize(2), weightType.getDimSize(2),
902 outputType.getDimSize(2), padding[2], padding[3], strides[1],
903 dilations[1],
"height",
"y",
"top",
"bottom")))
907 op, inputType.getDimSize(3), weightType.getDimSize(3),
908 outputType.getDimSize(3), padding[4], padding[5], strides[2],
909 dilations[2],
"width",
"x",
"left",
"right")))
914 const RankedTensorType biasType =
915 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
920 const int64_t biasChannels = biasType.getDimSize(0);
922 outputType.getDimSize(outputType.getRank() - 1);
923 if (biasChannels == ShapedType::kDynamic ||
924 outputChannels == ShapedType::kDynamic)
928 if (biasChannels != outputChannels && biasChannels != 1)
929 return op.emitOpError(
930 "bias channels expected to be equal to output channels (")
931 << outputChannels <<
") or 1, got " << biasChannels;
938 StringRef name1,
Type type2,
940 auto shapeType1 = dyn_cast<ShapedType>(type1);
941 auto shapeType2 = dyn_cast<ShapedType>(type2);
942 if (!shapeType1 || !shapeType2)
945 auto elemType1 = shapeType1.getElementType();
946 auto elemType2 = shapeType2.getElementType();
947 if (elemType1 != elemType2)
949 <<
"require same element type for " << name1 <<
" (" << elemType1
950 <<
") and " << name2 <<
" (" << elemType2 <<
")";
954 <<
"require same shapes for " << name1 <<
" (" << type1 <<
") and "
955 << name2 <<
" (" << type2 <<
")";
965 if (list1.size() != list2.size())
967 <<
"require same number of values in " << name1 <<
" ("
968 << list1.size() <<
") and " << name2 <<
" (" << list2.size() <<
")";
970 for (
auto [type1, type2] :
990 op->template getParentWithTrait<OpTrait::SymbolTable>();
997 const auto varOp = symTable.
lookup<tosa::VariableOp>(op.getName());
1001 return op->emitOpError(
"'")
1002 << op.getName() <<
"' has not been declared by 'tosa.variable'";
1014template <
typename T>
1016 StringRef aName =
"input",
1017 StringRef bName =
"output") {
1018 auto aTType = llvm::dyn_cast<TensorType>(aType);
1019 auto bTType = llvm::dyn_cast<TensorType>(bType);
1021 op.emitOpError(
"expect shaped tensor for") << aName <<
", got " << aType;
1025 op.emitOpError(
"expect shaped tensor for") << bName <<
", got" << bType;
1028 auto aElementType = aTType.getElementType();
1029 auto bElementType = bTType.getElementType();
1031 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1033 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1034 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1035 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1036 aElementType != bElementType) {
1041 op.emitOpError(
"expect ")
1042 << aName <<
" and " << bName <<
" to have same element type, got "
1043 << aElementType <<
" and " << bElementType;
// Verifier for tosa.argmax (fragment -- the early-exit returns after the
// hasRank() checks and the construction of expectedOutputShape/outputShape
// at original lines 1059-1061 and 1067-1074 are elided from this extract).
// Visible checks: the result element type must be integer/index; the axis
// attribute must lie in [0, input rank); the expected output shape is the
// input shape with `axis` erased, and a mismatch is diagnosed.
1049LogicalResult tosa::ArgMaxOp::verify() {
1050 const ShapedType resultType = llvm::cast<ShapedType>(
getType());
1053 if (
const auto resultETy = resultType.getElementType();
1054 !resultETy.isIntOrIndex())
1055 return emitOpError(
"result tensor is not of integer type");
1057 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
1058 if (!inputType.hasRank())
1062 const int64_t axis = getAxisAttr().getInt();
1063 if (((axis < 0) || axis >= inputType.getRank()))
1064 return emitOpError(
"specified axis is outside the rank of the tensor");
1066 if (!resultType.hasRank())
1072 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1075 << expectedOutputShape <<
"', got '" << outputShape <<
"'";
1080template <
typename T>
1083 if (llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1084 return op.emitOpError(
"expect all kernel values to be >= 1, got ")
1088 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1089 return op.emitOpError(
"expect all stride values to be >= 1, got ")
1093 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1094 return op.emitOpError(
"expect all padding values to be >= 0, got ")
1098 const int64_t kernelX = kernel[1];
1099 const int64_t padLeft = padding[2];
1100 const int64_t padRight = padding[3];
1101 if (padRight >= kernelX || padLeft >= kernelX)
1102 return op.emitOpError(
"expected left/right padding to be less than the "
1103 "width of the kernel, got pad_left=")
1104 << padLeft <<
", pad_right=" << padRight <<
", kernel_x=" << kernelX;
1106 const int64_t kernelY = kernel[0];
1107 const int64_t padTop = padding[0];
1108 const int64_t padBottom = padding[1];
1109 if (padTop >= kernelY || padBottom >= kernelY)
1110 return op.emitOpError(
"expected top/bottom padding to be less than the "
1111 "height of the kernel, got pad_top=")
1112 << padTop <<
", pad_bottom=" << padBottom
1113 <<
", kernel_y=" << kernelY;
1115 const auto inputType =
1116 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1117 const auto outputType =
1118 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1119 if (!inputType || !outputType)
1122 const auto verifyOutputSize =
1126 const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
1127 const llvm::StringRef padBeforeName,
1128 const llvm::StringRef padAfterName) -> LogicalResult {
1129 if (ShapedType::isDynamic(inputSize))
1132 const std::optional<int64_t> calculatedOutSizeMinusOne =
1133 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1134 if (!calculatedOutSizeMinusOne.has_value())
1135 return op.emitOpError(
"expected input_")
1136 << dimName <<
" + pad_" << padBeforeName <<
" + pad_"
1137 << padAfterName <<
" - kernel_" << dimAxis
1138 <<
" to be wholly divisible by stride_" << dimAxis <<
", got ("
1139 << inputSize <<
" + " << padBefore <<
" + " << padAfter <<
" - "
1140 << kernelSize <<
") / " << strideSize;
1142 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1143 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1144 return op.emitOpError(
"calculated output ")
1145 << dimName <<
" did not match expected: " <<
"calculated="
1146 << calculatedOutSize <<
", expected=" << outputSize;
1151 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1152 kernel[0], strides[0], padding[0], padding[1],
1153 "height",
"y",
"top",
"bottom")))
1156 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1157 kernel[1], strides[1], padding[2], padding[3],
1158 "width",
"x",
"left",
"right")))
// Verifier for tosa.avg_pool2d (fragment -- the derivations of inputETy,
// inputZpETy, resultETy and outputZpETy at original lines 1165-1172 and the
// failure/success returns between checks are elided from this extract).
// Visible checks: acc_type must be i32 for integer inputs, f16/f32 for f16
// inputs, f32 for bf16 and f32 inputs; input/result element types must match
// their zero-point element types; zero-point *values* are then validated via
// verifyInputZeroPoint / verifyOutputZeroPoint.
1164LogicalResult tosa::AvgPool2dOp::verify() {
1173 auto accType = getAccType();
1174 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1175 return emitOpError(
"accumulator type for integer tensor is not i32");
1177 if (inputETy.
isF16() && !(accType.isF16() || accType.isF32()))
1178 return emitOpError(
"accumulator type for f16 tensor is not f16/f32");
1180 if (inputETy.
isBF16() && !accType.isF32())
1181 return emitOpError(
"accumulator type for bf16 tensor is not f32");
1183 if (inputETy.
isF32() && !accType.isF32())
1184 return emitOpError(
"accumulator type for f32 tensor is not f32");
1186 if (inputETy != inputZpETy)
1187 return emitOpError(
"expect both input and its zero point are the same "
1188 "element type, got ")
1189 << inputETy <<
" and " << inputZpETy;
1191 if (resultETy != outputZpETy)
1192 return emitOpError(
"expect both output and its zero point are the same "
1193 "element type, got ")
1194 << resultETy <<
" and " << outputZpETy;
1196 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1197 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1200 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1201 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1207LogicalResult tosa::ClampOp::verify() {
1209 llvm::cast<ShapedType>(getInput().
getType()).getElementType();
1210 if (
auto quantType =
1211 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1215 llvm::cast<ShapedType>(getOutput().
getType()).getElementType();
1216 if (
auto quantType =
1217 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1220 if (inputETy != outputETy)
1221 return emitOpError(
"input/output element types are incompatible.");
1223 auto maxValAttr = getMaxValAttr();
1224 auto minValAttr = getMinValAttr();
1228 if (inputETy.
isInteger(dataTypeBitWidth)) {
1232 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1233 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1234 if (!intMaxValAttr || !intMinValAttr ||
1235 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1236 (intMaxValAttr.getType() != inputETy))
1237 return emitOpError(
"min/max attributes types are incompatible with "
1238 "input/output element types.");
1241 const bool isBoolean = inputETy.
isInteger(1);
1242 const APInt minVal = intMinValAttr.getValue();
1243 const APInt maxVal = intMaxValAttr.getValue();
1244 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1245 return emitOpError(
"expected min_val <= max_val, got min_val=")
1246 << minValAttr <<
", max_val=" << maxValAttr;
1251 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1252 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1253 if (!floatMaxValAttr || !floatMinValAttr ||
1254 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1255 (floatMaxValAttr.getType() != inputETy))
1256 return emitOpError(
"min/max attributes types are incompatible with "
1257 "input/output element types.");
1259 const APFloat minVal = floatMinValAttr.getValue();
1260 const APFloat maxVal = floatMaxValAttr.getValue();
1261 if (minVal.isNaN() || maxVal.isNaN())
1262 return emitOpError(
"min/max attributes should not be 'NaN', got min_val=")
1263 << minValAttr <<
", max_val=" << maxValAttr;
1265 if (maxVal < minVal)
1266 return emitOpError(
"expected min_val <= max_val, got min_val=")
1267 << minValAttr <<
", max_val=" << maxValAttr;
1287 result.addOperands({input, weight, bias, zps.first, zps.second});
1288 result.addAttribute(
"pad", pad);
1289 result.addAttribute(
"stride", stride);
1290 result.addAttribute(
"dilation", dilation);
1291 result.addAttribute(
"acc_type", accType);
1292 Type finalOutputType = outputType;
1298 result.addTypes(finalOutputType);
1309 result.addOperands({input, weight, bias, zps.first, zps.second});
1310 result.addAttribute(
"out_pad", outpad);
1311 result.addAttribute(
"stride", stride);
1312 result.addAttribute(
"acc_type", accType);
1313 Type finalOutputType = outputType;
1319 result.addTypes(finalOutputType);
1330 result.addOperands({a,
b, zps.first, zps.second});
1332 Type finalOutputType{outputType};
1335 auto inputBits = eType.getIntOrFloatBitWidth();
1337 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1338 assert(outputShapedType &&
"Output must be a shaped type");
1340 IntegerType accElementType;
1341 if (inputBits == 16)
1346 finalOutputType = outputShapedType.clone(accElementType);
1348 result.addTypes(finalOutputType);
1357 DenseArrayAttr kernel, DenseArrayAttr stride,
1358 DenseArrayAttr pad, TypeAttr accType) {
1363 if (
auto quantAttr =
1365 inputZp = quantAttr.getInputZp();
1366 outputZp = quantAttr.getOutputZp();
1368 const std::optional<Value> inputZpOp =
1373 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1375 const std::optional<Value> outputZpOp =
1378 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1379 "quantized AVG_POOL2D op");
1382 if (inputZpOp && outputZpOp) {
1383 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1388 result.addOperands({input});
1390 result.addAttribute(
"kernel", kernel);
1391 result.addAttribute(
"stride", stride);
1392 result.addAttribute(
"pad", pad);
1393 result.addAttribute(
"acc_type", accType);
1394 result.types.push_back(outputType);
1408 input1Zp = quantAttr.getInputZp();
1409 outputZp = quantAttr.getOutputZp();
1411 const std::optional<Value> input1ZpOp =
1415 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1418 const std::optional<Value> outputZpOp =
1422 loc,
"Failed to create output zero point for quantized NEGATE op");
1425 if (input1ZpOp && outputZpOp) {
1426 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1431 result.addOperands({input});
1434 result.types.push_back(outputType);
1447 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1450 result.addOperands({input, paddings, padConstOp});
1451 result.types.push_back(outputType);
1455 StringRef name,
Type variableType,
1460 auto shapedType = dyn_cast<ShapedType>(variableType);
1462 (
void)
emitError(loc,
"variable type must be a shaped type");
1465 if (!shapedType.hasRank()) {
1466 (
void)
emitError(loc,
"variable type must be a ranked type");
1470 auto elementType = shapedType.getElementType();
1471 auto elementTypeAttr = TypeAttr::get(elementType);
1475 result.addAttribute(
"sym_name", nameAttr);
1476 result.addAttribute(
"var_shape", varShapeAttr);
1477 result.addAttribute(
"type", elementTypeAttr);
1478 result.addAttribute(
"initial_value", initialValue);
1491 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1495 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1501 for (
int i = 0, e = operands.size(); i != e; ++i) {
1503 if (!
shape.hasRank()) {
1508 outRank = std::max<int64_t>(outRank,
shape.getRank());
1511 outShape.resize(outRank, 1);
1513 for (
int i = 0, e = operands.size(); i != e; ++i) {
1515 auto rankDiff = outShape.size() -
shape.getRank();
1517 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1518 auto dim1 = outShape[i + rankDiff];
1519 auto dim2 =
shape.getDimSize(i);
1521 const FailureOr<int64_t> maybeResolvedDim =
1523 if (failed(maybeResolvedDim))
1525 const int64_t resolvedDim = *maybeResolvedDim;
1526 outShape[i + rankDiff] = resolvedDim;
1533LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1534 MLIRContext *context, ::std::optional<Location> location,
1535 ArgMaxOp::Adaptor adaptor,
1538 IntegerAttr axis = adaptor.getProperties().axis;
1539 int32_t axisVal = axis.getValue().getSExtValue();
1541 if (!inputShape.hasRank()) {
1547 outShape.reserve(inputShape.getRank() - 1);
1548 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
1551 outShape.push_back(inputShape.getDimSize(i));
1558LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1559 MLIRContext *context, ::std::optional<Location> location,
1560 RFFT2dOp::Adaptor adaptor,
1562 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1564 if (!inputShape.hasRank())
1568 outputShape.resize(3, ShapedType::kDynamic);
1569 outputShape[0] = inputShape.getDimSize(0);
1570 outputShape[1] = inputShape.getDimSize(1);
1571 int64_t inWidth = inputShape.getDimSize(2);
1575 if (inWidth != ShapedType::kDynamic)
1576 outputShape[2] = inWidth / 2 + 1;
1585 const llvm::StringRef dimName) {
1586 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1589 << dimName <<
" to be a power of two, got " << dimSize;
// Verifier for tosa.rfft2d (fragment -- the two-results-same-type guard,
// the power-of-two checks on height/width behind the isStatic guards, and
// several return statements are elided from this extract).
// Visible checks: the batch and height dims of input/output must agree
// (prefix compare via getShape().drop_back()), and a static output width
// must equal input_width / 2 + 1.
1594LogicalResult tosa::RFFT2dOp::verify() {
1595 const auto outputTypes = getResultTypes();
1597 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
1599 const auto inputType =
1600 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1604 const int64_t height = inputType.getDimSize(1);
1605 if (ShapedType::isStatic(height) &&
1609 const int64_t width = inputType.getDimSize(2);
1610 if (ShapedType::isStatic(width) &&
1614 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1620 outputType.getShape().drop_back())))
1621 return emitOpError(
"expected batch and height dimensions of input/output "
1622 "to match, got input=")
1623 << inputType <<
" output=" << outputType;
1626 const int64_t outputWidth = outputType.getDimSize(2);
1627 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1628 (outputWidth != (width / 2) + 1))
1630 "expected output width to be equal to input_width / 2 + 1, got ")
1636LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1637 MLIRContext *context, ::std::optional<Location> location,
1638 FFT2dOp::Adaptor adaptor,
1640 inferredReturnShapes.push_back(
1642 inferredReturnShapes.push_back(
// Verifier for tosa.fft2d (fragment -- the power-of-two dimension checks
// that follow each isStatic guard are elided from this extract).
// NOTE(review): `trySelectStaticDim` as shown returns `a` when `a` is
// dynamic (`isDynamic(a) ? a : b`), i.e. it propagates the dynamic dim and
// never prefers a static `a` -- the helper's name suggests `? b : a` was
// intended so that a static dim from either input can be validated. This
// extract splits the expression across lines, so confirm against the full
// file before concluding this is a real bug.
1647LogicalResult tosa::FFT2dOp::verify() {
1648 const auto inputRealType =
1649 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1650 const auto inputImagType =
1651 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1652 if (!inputRealType || !inputImagType)
1655 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1656 return ShapedType::isDynamic(a) ? a :
b;
1659 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1660 inputImagType.getDimSize(1));
1661 if (ShapedType::isStatic(height) &&
1665 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1666 inputImagType.getDimSize(2));
1667 if (ShapedType::isStatic(width) &&
1674LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1675 MLIRContext *context, ::std::optional<Location> location,
1676 ConcatOp::Adaptor adaptor,
1679 const Properties &prop = adaptor.getProperties();
1680 int32_t axis = prop.axis.getValue().getSExtValue();
1682 bool hasRankedInput =
false;
1683 for (
auto operand : adaptor.getOperands()) {
1685 if (!operandShape.hasRank())
1689 if (!hasRankedInput)
1690 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1693 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
1694 if (i == axis || operandShape.isDynamicDim(i))
1696 if (outputShape[i] == ShapedType::kDynamic)
1697 outputShape[i] = operandShape.getDimSize(i);
1698 if (outputShape[i] != operandShape.getDimSize(i))
1700 "Cannot concat tensors with different sizes"
1701 " on the non-axis dimension ",
1705 hasRankedInput =
true;
1708 if (adaptor.getInput1().empty())
1712 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1713 if (!hasRankedInput) {
1720 for (
auto operand : adaptor.getOperands()) {
1725 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1726 concatDimSize = ShapedType::kDynamic;
1730 concatDimSize += operandShape.getDimSize(axis);
1733 outputShape[axis] = concatDimSize;
1739LogicalResult tosa::ConcatOp::verify() {
1741 auto outType = getOutput().getType();
1745 if (inputList.empty())
1748 if (!llvm::all_of(inputList, [&](
auto input) {
1750 *
this, input.getType(), outType));
1755 const int32_t axis = getAxis();
1757 for (
const auto &input : inputList) {
1758 const Type inputType = input.getType();
1760 if (currShape.hasRank()) {
1761 firstRankedInputShape = currShape;
1763 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1764 return emitOpError(
"expect axis to be within range 0 < axis < "
1765 "rank(input1[firstRankedTensorIdx]), got ")
1771 const auto allOperandsHasRank = [](
const Value input) {
1774 if (llvm::all_of(inputList, allOperandsHasRank)) {
1777 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1779 const int64_t inputRank = inputShape.getRank();
1780 const size_t operandNum =
index + 1;
1783 if (inputRank != firstInputRank)
1785 "expect all operands to have the same rank, but got ")
1786 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1790 for (
int i = 0; i < inputRank; i++) {
1791 const int64_t inputDim = inputShape.getDimSize(i);
1793 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1794 inputShape.isDynamicDim(i))
1796 if (inputDim != firstInputDim)
1797 return emitOpError(
"expect all operand shapes to have the same sizes "
1798 "on non-axis dimensions, but got ")
1799 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1800 <<
" on operands 0 and " << operandNum;
1805 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1806 return emitOpError(
"expect output rank to match inputs rank, got ")
1807 << outputShape.getRank() <<
" vs " << firstInputRank;
1811 for (
const auto &input : inputList) {
1813 if (inputShape.isDynamicDim(axis)) {
1818 axisSum += inputShape.getDimSize(axis);
1821 if (axisSum >= 0 && outputShape.hasRank() &&
1822 !outputShape.isDynamicDim(axis) &&
1823 axisSum != outputShape.getDimSize(axis))
1824 return emitOpError(
"requires sum of axis dimensions of input1 "
1825 "equal to output axis dimension, got ")
1826 << axisSum <<
" and " << outputShape.getDimSize(axis);
1832LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1833 MLIRContext *context, ::std::optional<Location> location,
1837 auto elementType = IntegerType::get(context, 1);
1850 if (l.size() != r.size() || l.size() != 1)
1855LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1856 MLIRContext *context, ::std::optional<Location> location,
1857 MatMulOp::Adaptor adaptor,
1864 outShape.resize(3, ShapedType::kDynamic);
1866 if (lhsShape.hasRank()) {
1867 outShape[0] = lhsShape.getDimSize(0);
1868 outShape[1] = lhsShape.getDimSize(1);
1871 if (rhsShape.hasRank()) {
1872 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1874 outShape[2] = rhsShape.getDimSize(2);
1881LogicalResult MatMulOp::verify() {
1882 auto aType = llvm::dyn_cast<ShapedType>(getA().
getType());
1883 auto bType = llvm::dyn_cast<ShapedType>(getB().
getType());
1887 return emitOpError(
"expect a shaped tensor for input a, got ")
1888 << getA().getType();
1891 return emitOpError(
"expect a shaped tensor for input b, got ")
1892 << getB().getType();
1894 auto aElementType = aType.getElementType();
1895 auto bElementType = bType.getElementType();
1897 auto aQuantizedEType =
1898 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1899 auto bQuantizedEType =
1900 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1902 if (aQuantizedEType || bQuantizedEType) {
1903 if (!aQuantizedEType || !bQuantizedEType) {
1904 return emitOpError(
"expect operands to be both quantized or both not "
1906 << aElementType <<
" and " << bElementType;
1909 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1910 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1911 if (aQuantWidth != bQuantWidth) {
1912 return emitOpError(
"expect quantized operands to have same widths, got ")
1913 << aQuantWidth <<
" and " << bQuantWidth;
1920 if (aEType != aZpEType) {
1921 return emitOpError(
"expect input a and a_zp have the same "
1922 "element type, got ")
1923 << aEType <<
" and " << aZpEType;
1928 if (bEType != bZpEType) {
1929 return emitOpError(
"expect input b and b_zp have the same "
1930 "element type, got ")
1931 << bEType <<
" and " << bZpEType;
1934 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1935 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1938 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1939 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1945LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1946 MLIRContext *context, ::std::optional<Location> location,
1947 MatmulTBlockScaledOp::Adaptor adaptor,
1951 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1952 if (aDataShape.hasRank()) {
1953 outShape[0] = aDataShape.getDimSize(0);
1954 outShape[1] = aDataShape.getDimSize(1);
1957 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1958 if (aScaleShape.hasRank()) {
1959 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1961 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1966 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1967 if (bDataShape.hasRank()) {
1968 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1969 if (bDataBatchSize != 1)
1971 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1972 outShape[2] = bDataShape.getDimSize(1);
1975 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1976 if (bScaleShape.hasRank()) {
1977 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1978 if (bScaleBatchSize != 1)
1980 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1981 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
// Verifier for tosa.matmul_t_block_scaled: checks broadcast compatibility of
// the B batch, that C is a multiple of the block size, that the scale
// operands' dim 2 equals C/block_size, and that the result shape is
// compatible with the expected [N, H, W] output.
// NOTE(review): several statements between the visible fragments were lost
// in extraction (marked "[elided]" below); only comments were added here,
// the surviving code is unchanged.
LogicalResult MatmulTBlockScaledOp::verify() {
  // [elided] element-type checks presumably consume these operand types.
  const Type aDataType = getAData().getType();
  const Type bDataType = getBData().getType();
  // Dimensions of interest; kDynamic means "not yet known".
  int64_t N = ShapedType::kDynamic;
  int64_t D = ShapedType::kDynamic;
  int64_t H = ShapedType::kDynamic;
  int64_t multiplesOfC = ShapedType::kDynamic;
  // [elided] dimension-matching helper calls; only the operand/dim-name
  // argument pairs survive:
                                         "a_scale", "batch")) ||
                                         "a_scale", "height")))
                                         "b_data", "batch")) ||
                                         "b_data", "channels")))
                                         "b_scale", "batch")) ||
                                         "b_scale", "width")) ||
  // B's batch (D) must equal N or be the broadcastable size 1.
  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
    return emitOpError("expect B matrix batch size to be broadcast compatible "
           << D << " vs N=" << N;
  // The channel count must divide evenly into scale blocks.
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(C) && C % blockSize != 0)
    return emitOpError("expect C to be a multiple of block size, got C=")
           << C << ", block_size=" << blockSize;
  // Scale operands carry one entry per block: dim 2 == C / block_size.
  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
      multiplesOfC != C / blockSize)
    // [elided] return emitOpError(
        "expect scale operands dimension 2 to equal C/block_size (")
        << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
  // Resolve N for the expected output shape (fall back to B's batch).
  N = ShapedType::isDynamic(N) ? D : N;
  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&
  // [elided] comparison against the expected output shape; on mismatch an
  // InFlightDiagnostic `opError` is built with stringified shapes:
  auto stringifyDim = [&](int64_t d) {
    if (ShapedType::isDynamic(d))
  llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
  opError << " to be compatible with expected output shape ";
  llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
  // [elided] remainder of the function (diagnostic return, success path).
2086LogicalResult tosa::PadOp::inferReturnTypeComponents(
2087 MLIRContext *context, ::std::optional<Location> location,
2088 PadOp::Adaptor adaptor,
2090 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2092 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2097 if (!inputShape.hasRank()) {
2098 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2107 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2112 outputShape.reserve(inputShape.getRank());
2113 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2114 if (inputShape.isDynamicDim(i)) {
2115 outputShape.push_back(ShapedType::kDynamic);
2118 auto padFront = paddingValues[i * 2];
2119 auto padBack = paddingValues[i * 2 + 1];
2120 if (padFront < 0 || padBack < 0) {
2122 outputShape.push_back(ShapedType::kDynamic);
2126 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2133LogicalResult tosa::PadOp::verify() {
2140 if (
auto padConst = getPadConst()) {
2148 RankedTensorType inputType =
2149 llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2150 RankedTensorType outputType =
2151 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
2152 if (!inputType || !outputType)
2159 auto inputRank = inputType.getRank();
2164 auto paddingValues = paddingAttr.getValues<APInt>();
2165 if (paddingValues.size() !=
static_cast<size_t>(inputRank * 2))
2166 return emitOpError() <<
"padding tensor must have " << inputRank
2167 <<
" * 2 = " << inputRank * 2 <<
" elements, but got "
2168 << paddingValues.size();
2170 auto inputShape = inputType.getShape();
2171 auto outputShape = outputType.getShape();
2173 for (
int64_t i = 0; i < inputRank; ++i) {
2174 int64_t padStart = paddingValues[i * 2].getSExtValue();
2175 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2177 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2179 <<
"invalid padding values at dimension " << i
2180 <<
": values must be non-negative or -1 for dynamic padding, got ["
2181 << padStart <<
", " << padEnd <<
"]";
2185 if (inputShape[i] == ShapedType::kDynamic ||
2186 outputShape[i] == ShapedType::kDynamic)
2189 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2190 return emitOpError() <<
"mismatch in output shape at dimension " << i
2191 <<
": expected " << inputShape[i] <<
" + "
2192 << padStart <<
" + " << padEnd <<
" = "
2193 << (inputShape[i] + padStart + padEnd)
2194 <<
", but got " << outputShape[i];
2201LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2202 MLIRContext *context, ::std::optional<Location> location,
2203 SliceOp::Adaptor adaptor,
2212 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2220 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2223 if (inputShape.hasRank()) {
2224 for (
size_t i = 0; i < size.size(); i++) {
2225 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2226 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2227 start[i] < inputShape.getDimSize(i))) {
2229 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2232 outputShape[i] = size[i];
2236 if (size[i] == -1) {
2237 outputShape[i] = inputShape.getDimSize(i) - start[i];
2238 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2240 outputShape[i] = size[i];
// Verifier for tosa.slice: element types, rank agreement between input1,
// output, start and size, legal start/size values, output-vs-size shape
// agreement, and start+size bounds against static input dims.
// NOTE(review): statements between the visible fragments were lost in
// extraction (marked "[elided]" below); only comments were added here, the
// surviving code is unchanged.
LogicalResult tosa::SliceOp::verify() {
  const Value input = getInput1();
  const Value output = getOutput();
  // [elided] element-type compatibility check between input and output.
  const Value start = getStart();
  const Value size = getSize();
  // [elided] construction of the inputShape/outputShape adaptors.
  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();
    // input1 and output must agree in rank when both are ranked.
    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      // [elided] return emitOpError(
      "expect input1 and output to have the same ranks, got ")
          << inputRank << " and " << outputShape.getRank();

    // The !tosa.shape operands carry one entry per input dimension.
    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(start.getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(size.getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
  // [elided] closing brace; extraction of constant startValues/sizeValues.
  if (startValues.size()) {
    if (llvm::any_of(startValues, [](const int64_t v) {
      // [elided] predicate body (negative start test, per the message below).
      return emitOpError("start values must be non-negative, got [")
          << startValues << "]";
  // [elided] guard for constant size values.
  if (llvm::any_of(sizeValues, [](const int64_t v) {
    // [elided] predicate body (non-positive, non -1 test, per the message).
    return emitOpError("size values must be > 0, got [") << sizeValues << "]";
  if (outputShape.hasRank()) {
    // [elided] declaration of outputDims.
    outputShape.getDims(outputDims);
    const bool hasNoInferableDims = llvm::all_of(
    // [elided] predicate arguments (checks for -1 "inferable" sizes).
    if (hasNoInferableDims &&
    // [elided] comparison of the output dims against the size values.
      return emitOpError("expected output shape to match size values, got ")
          << output.getType() << " vs [" << sizeValues << "]";
  // With constant values and a ranked input, check slice bounds per dim.
  if (inputShape.hasRank() && startValues.size()) {
    // [elided] declaration of inputDims.
    inputShape.getDims(inputDims);
    for (const auto &[index, vals] :
         llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
      const auto &[start, size, inputDim] = vals;
      // [elided] skip condition head; dynamic dims are exempt:
      ShapedType::isDynamic(inputDim))
      if (start + size > inputDim)
        return emitOpError("start + size must be less than or equal to input "
                           "dimension size, got start=")
            << start << ", size=" << size
            << " vs input dim size=" << inputDim << " at dimension "
  // [elided] remainder of the function (dimension index, closings, success).
2333LogicalResult tosa::MulOp::inferReturnTypeComponents(
2334 MLIRContext *context, ::std::optional<Location> location,
2349LogicalResult tosa::MulOp::verify() {
2350 const Value output = getOutput();
2355 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2356 IntegerType lhsIntType =
2358 IntegerType rhsIntType =
2360 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2361 return emitOpError(
"requires the same element type for all operands");
2366 if (lhsIntType.getWidth() > resIntType.getWidth())
2367 return emitOpError(
"invalid data type size for operands or result");
2372 for (
int i = 0; i < 2; ++i) {
2375 "requires the same element type for all operands and results");
2379 ElementsAttr shiftElem;
2381 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2383 return emitOpError() <<
"require shift to be 0 for float type";
2391 TypeRange operandTypes = getOperandTypes();
2392 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2393 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2395 const bool aHasRank = aType.hasRank();
2396 const bool bHasRank = bType.hasRank();
2397 if (aHasRank && bHasRank) {
2398 const int64_t aRank = aType.getRank();
2399 const int64_t bRank = bType.getRank();
2401 return emitOpError(
"a and b operands don't have matching ranks, got ")
2402 << aRank <<
" and " << bRank;
2407 aType.getShape(), bType.getShape(), resultShape))
2408 return emitOpError(
"a and b operands don't have broadcast-compatible "
2410 << aType <<
" and " << bType;
2413 ShapedType resultType = cast<ShapedType>(output.
getType());
2414 if (!resultType.hasRank())
2417 const int64_t resultRank = resultType.getRank();
2418 if (aHasRank && resultRank != aType.getRank())
2419 return emitOpError(
"result type has different rank than a, got ")
2420 << resultRank <<
" vs " << aType.getRank();
2421 if (bHasRank && resultRank != bType.getRank())
2422 return emitOpError(
"result type has different rank than b, got ")
2423 << resultRank <<
" vs " << bType.getRank();
2428LogicalResult tosa::TableOp::inferReturnTypeComponents(
2429 MLIRContext *context, ::std::optional<Location> location,
2430 TableOp::Adaptor adaptor,
2432 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2434 if (!inputShape.hasRank()) {
2439 inferredReturnShapes.resize(1);
2440 inputShape.getDims(inferredReturnShapes[0]);
2444LogicalResult tosa::TableOp::verify() {
2445 const TensorType inputType = getInput1().getType();
2446 const TensorType outputType = getOutput().getType();
2455 auto inputDims = inputType.
getShape();
2456 auto outputDims = outputType.
getShape();
2457 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2459 auto [inputDim, outputDim] = it.value();
2460 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2461 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2462 <<
" doesn't match dim(input, " << dim
2463 <<
") = " << inputDim;
2476 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2477 [](
const APInt &val) { return val.getSExtValue(); });
2481LogicalResult tosa::TileOp::inferReturnTypeComponents(
2482 MLIRContext *context, ::std::optional<Location> location,
2483 TileOp::Adaptor adaptor,
2490 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2498 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2500 if (!inputShape.hasRank()) {
2501 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2502 inferredReturnShapes.push_back(
2505 }
else if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2509 outputShape.reserve(multiples.size());
2510 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2511 if (multiples[i] == ShapedType::kDynamic) {
2512 outputShape.push_back(ShapedType::kDynamic);
2514 int64_t dim = inputShape.getDimSize(i);
2515 if (dim != ShapedType::kDynamic)
2516 dim *= multiples[i];
2517 outputShape.push_back(dim);
2525LogicalResult tosa::TileOp::verify() {
2531 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2532 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2534 shapeType multiplesType =
2535 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2537 auto multiplesRank = multiplesType.getRank();
2539 if (inputType.hasRank()) {
2540 if (inputType.getRank() != multiplesRank)
2541 return emitOpError(
"expect 'multiples' to have rank ")
2542 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2543 if (outputType.hasRank() &&
2547 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2548 return emitOpError(
"expect 'multiples' array to have length ")
2549 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
2552 if (getConstantMultiples(multiples).succeeded() &&
2553 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2555 "expect element of 'multiples' to be positive integer or -1.");
2561 if (l.size() != r.size() || l.size() != 1)
2566LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2567 MLIRContext *context, ::std::optional<Location> location,
2568 ReshapeOp::Adaptor adaptor,
2570 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2575 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2585 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2586 inferredReturnShapes.push_back(
2594 int64_t numElements = inputShape.getNumElements();
2596 for (
auto val : newShapeValue) {
2597 if (ShapedType::isStatic(val)) {
2603 for (
auto &val : newShapeValue) {
2604 if (ShapedType::isDynamic(val))
2605 val = numElements / staticMul;
2608 inferredReturnShapes.push_back(
2613llvm::LogicalResult tosa::ReshapeOp::verify() {
2619 TensorType inputType = getInput1().getType();
2624 return mlir::success();
2628 if (missingDims > 1)
2629 return emitOpError() <<
"expected at most one target dimension to be "
2632 const auto outputType = dyn_cast<RankedTensorType>(
getType());
2636 if ((
int64_t)shapeValues.size() != outputType.getRank())
2637 return emitOpError() <<
"new shape does not match result rank";
2639 for (
auto [newShapeDim, outputShapeDim] :
2640 zip(shapeValues, outputType.getShape())) {
2642 newShapeDim != ShapedType::kDynamic &&
2643 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2644 return emitOpError() <<
"new shape is inconsistent with result shape";
2647 return emitOpError() <<
"new shape has invalid tensor dimension size "
2651 if (inputType.hasStaticShape()) {
2652 int64_t inputElementsNum = inputType.getNumElements();
2653 if (outputType.hasStaticShape()) {
2654 int64_t outputElementsNum = outputType.getNumElements();
2655 if (inputElementsNum != outputElementsNum) {
2656 return emitOpError() <<
"cannot reshape " << inputElementsNum
2657 <<
" elements into " << outputElementsNum;
2663 return (dim > 0) ?
acc * dim :
acc;
2665 bool isStaticNewShape =
2666 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
2667 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2668 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2669 return emitOpError() <<
"cannot reshape " << inputElementsNum
2670 <<
" elements into " << newShapeElementsNum;
2674 return mlir::success();
2681 ElementsAttr zpAttr;
2686 Type zpElemType = zpAttr.getElementType();
2688 if (llvm::isa<FloatType>(zpElemType)) {
2689 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2696 if (llvm::isa<IntegerType>(zpElemType)) {
2698 return zpAttr.getValues<APInt>()[0].getSExtValue();
2700 return zpAttr.getValues<APInt>()[0].getZExtValue();
2707template <
typename T>
2709 const std::string &operand) {
2712 if (!zpElemType.
isInteger(8) && zp != 0) {
2714 std::string lower = operand;
2715 llvm::transform(lower, lower.begin(), ::tolower);
2716 return op.emitOpError()
2717 << lower <<
" zero point must be zero for non-int8 integer types";
2725 const std::string &operand) {
2726 bool isInputZp = (operand ==
"Input");
2728 bool tensorUnsigned =
2729 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2730 StringRef tensorName = isInputZp ?
"input" :
"output";
2736 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
2737 return op.emitOpError()
2738 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
2740 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
2741 return op.emitOpError() <<
"expect " << tensorName
2742 <<
"_zp of 0 or 32768 for unsigned int16 "
2743 << tensorName <<
", got " << zp;
// Generates the paired get<Operand>ZeroPoint() / verify<Operand>ZeroPoint()
// accessors for ops carrying quantization zero-point operands, forwarding to
// the getZeroPoint/verifyZeroPoint helpers above.
// NOTE(review): the closing braces of both generated functions and the list
// of ZERO_POINT_HELPER(Op, Operand, signExtend) instantiations between the
// #define and the #undef were lost in extraction; only comments were added
// here, the surviving code is unchanged.
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
#undef ZERO_POINT_HELPER
2776LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2777 MLIRContext *context, ::std::optional<Location> location,
2778 TransposeOp::Adaptor adaptor,
2780 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2789 const auto inputRank = inputShape.
getRank();
2793 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
2799 if (inputRank == 0) {
2805 bool allTheSame =
true;
2806 for (
int i = 1, s = inputRank; i < s; i++) {
2816 outputShape.resize(inputRank, inputShape.
getDimSize(0));
2821 outputShape.resize(inputRank, ShapedType::kDynamic);
2824 if (llvm::any_of(adaptor.getPerms(),
2825 [inputRank](
const auto i) { return i >= inputRank; }))
2828 outputShape.reserve(inputRank);
2829 for (
int i = 0, s = inputRank; i < s; i++) {
2830 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
2837LogicalResult tosa::TransposeOp::verify() {
2849 if (inputShape.hasRank() &&
2850 constantPerms.size() !=
static_cast<size_t>(inputShape.getRank()))
2851 return emitOpError() <<
"expected perms attribute to have size "
2852 << inputShape.getRank()
2853 <<
" (input rank) but got size "
2854 << constantPerms.size();
2856 if (inputShape.hasRank() && outputShape.hasRank() &&
2857 inputShape.getRank() != outputShape.getRank())
2859 <<
"expected input tensor rank to equal result tensor rank";
2861 if (outputShape.hasRank() &&
2862 constantPerms.size() !=
static_cast<size_t>(outputShape.getRank()))
2863 return emitOpError() <<
"expected perms attribute to have size "
2864 << outputShape.getRank()
2865 <<
" (output rank) but got size "
2866 << constantPerms.size();
2868 if (!llvm::all_of(constantPerms,
2869 [&constantPerms](int32_t s) {
2871 static_cast<size_t>(s) < constantPerms.size();
2874 constantPerms, [](int32_t v) ->
int64_t {
return v; })))
2875 return emitOpError() <<
"expected valid permutation indices";
2878 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2879 inputShape.getNumElements() != outputShape.getNumElements())
2880 return emitOpError() <<
"expected input1 and output to have same numbers "
2882 << inputShape.getNumElements() <<
" and "
2883 << outputShape.getNumElements();
2887 if (inputShape.hasRank() && outputShape.hasRank()) {
2888 for (
auto i = 0; i < outputShape.getRank(); i++) {
2889 if (inputShape.isDynamicDim(constantPerms[i]) ||
2890 outputShape.isDynamicDim(i))
2893 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2895 <<
"expected output tensor dim " << i <<
" to match "
2896 <<
"input dim " << constantPerms[i] <<
" with value of "
2897 << inputShape.getDimSize(constantPerms[i]);
// Reifies the result dims of tosa.transpose: each output dim is taken from the
// input dim selected via the perms attribute. Dynamic input dims are
// materialized as tensor.dim ops; static dims become index attributes.
// NOTE(review): the fused original-line numbering jumps (2904 -> 2907, 2917 ->
// 2921, 2924 -> end) indicate the parameter list (builder / reifiedReturnShapes),
// the else-branch assignment targets, and the trailing `return success();`
// were dropped by the extraction — verify against upstream before editing.
2904LogicalResult TransposeOp::reifyResultShapes(
2907 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2909 Value input = getInput1();
2910 auto inputType = cast<TensorType>(input.
getType());
// One OpFoldResult slot per output dimension.
2912 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2913 for (
auto dim : transposePerms) {
2914 int32_t dimInInput = transposePerms[dim];
2915 if (inputType.isDynamicDim(dimInInput))
// Dynamic dim: build a tensor.dim op querying the input at runtime.
2917 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2921 builder.
getIndexAttr(inputType.getDimSize(dimInInput));
2924 reifiedReturnShapes.emplace_back(std::move(returnedDims));
// Shape inference for tosa.gather. Output is rank-3 [N, W, C]:
// N and C come from `values` ([N, K, C]); N (if still unknown) and W come
// from `indices` ([N, W]). Unresolvable dims remain ShapedType::kDynamic.
// NOTE(review): the extraction dropped the closing lines (numbering ends at
// 2949) — the trailing `return success();` and braces are missing here.
2928LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2929 MLIRContext *context, ::std::optional<Location> location,
2930 GatherOp::Adaptor adaptor,
2931 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2932 llvm::SmallVector<int64_t> outputShape;
// Start fully dynamic; refine each dim from whichever operand is ranked.
2933 outputShape.resize(3, ShapedType::kDynamic);
2935 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2936 if (valuesShape.hasRank()) {
2937 outputShape[0] = valuesShape.getDimSize(0);
2938 outputShape[2] = valuesShape.getDimSize(2);
2941 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2942 if (indicesShape.hasRank()) {
// Indices only fill in dims `values` could not determine.
2943 if (outputShape[0] == ShapedType::kDynamic)
2944 outputShape[0] = indicesShape.getDimSize(0);
2945 if (outputShape[1] == ShapedType::kDynamic)
2946 outputShape[1] = indicesShape.getDimSize(1);
2949 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifier for tosa.gather: cross-checks the N/W/C dimensions of values
// ([N, K, C]), indices ([N, W]) and output ([N, W, C]) wherever they are
// static, emitting a descriptive error on the first mismatch.
// NOTE(review): the extraction dropped several lines (e.g. the indicesShape
// declaration around original line ~2958, the `n = indicesN;` assignment at
// 2976, and the final return/braces) — verify against upstream before editing.
2953LogicalResult tosa::GatherOp::verify() {
2960 const ShapeAdaptor valuesShape(getValues().
getType());
2962 const ShapeAdaptor outputShape(getOutput().
getType());
// N/W/C start unknown and are filled in from the first ranked operand.
2964 int64_t n = ShapedType::kDynamic;
2965 int64_t w = ShapedType::kDynamic;
2966 int64_t c = ShapedType::kDynamic;
2968 if (valuesShape.hasRank()) {
2969 n = valuesShape.getDimSize(0);
2970 c = valuesShape.getDimSize(2);
2972 if (indicesShape.hasRank()) {
2973 const int64_t indicesN = indicesShape.getDimSize(0);
2974 w = indicesShape.getDimSize(1);
2975 if (n == ShapedType::kDynamic)
2977 else if (indicesN != ShapedType::kDynamic && n != indicesN)
2978 return emitOpError() <<
"requires indices dimension 0 to have size " << n
2979 <<
", got " << indicesN;
2981 if (outputShape.hasRank()) {
2982 const int64_t outputN = outputShape.getDimSize(0);
2983 const int64_t outputW = outputShape.getDimSize(1);
2984 const int64_t outputC = outputShape.getDimSize(2);
// Only compare dims when both sides are static.
2985 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2987 return emitOpError() <<
"requires output dimension 0 to have size " << n
2988 <<
", got " << outputN;
2990 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2992 return emitOpError() <<
"requires output dimension 1 to have size " << w
2993 <<
", got " << outputW;
2994 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2996 return emitOpError() <<
"requires output dimension 2 to have size " << c
2997 <<
", got " << outputC;
// Shape inference for tosa.resize (NHWC). Batch and channels pass through
// from the input; H/W are computed from the scale (numerator/denominator
// pairs), offset and border operands when the input H/W are static.
// NOTE(review): numbering jumps 3022 -> 3033 show the lines extracting
// scaleInt/offsetInt/borderInt from the op's shape operands were dropped,
// along with the divisors closing the height/width expressions — confirm
// against upstream before editing.
3002LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3003 MLIRContext *context, ::std::optional<Location> location,
3004 ResizeOp::Adaptor adaptor,
3005 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3006 llvm::SmallVector<int64_t, 4> outputShape;
3007 outputShape.resize(4, ShapedType::kDynamic);
3009 ShapeAdaptor inputShape(adaptor.getInput().getType());
// Unranked input: nothing can be refined.
3010 if (!inputShape.hasRank())
3013 outputShape[0] = inputShape.getDimSize(0);
3014 outputShape[3] = inputShape.getDimSize(3);
3015 int64_t inputHeight = inputShape.getDimSize(1);
3016 int64_t inputWidth = inputShape.getDimSize(2);
// Dynamic spatial dims: leave H/W dynamic in the result.
3018 if ((inputHeight == ShapedType::kDynamic) ||
3019 (inputWidth == ShapedType::kDynamic))
3022 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
// OH = (IH - 1) * scale_y_n - offset_y + border_y, divided by scale_y_d, + 1.
3033 const int64_t outputHeight =
3034 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3038 const int64_t outputWidth =
3039 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
// A negative computed size indicates inconsistent resize parameters.
3043 if (outputHeight < 0 || outputWidth < 0) {
3046 "calculated output height and width must be non-negative, "
3048 outputHeight,
", width = ", outputWidth);
3051 outputShape[1] = outputHeight;
3052 outputShape[2] = outputWidth;
3053 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifier for tosa.resize: checks scale values are positive and, when the
// input spatial dims are static and > 1, that (dim - 1) * scale_n - offset +
// border is wholly divisible by scale_d (via idivCheck) and that the derived
// output H/W match any static output dims.
// NOTE(review): numbering gaps (3067 -> 3075, 3125 -> 3127, ...) show the
// lines reading the scale/offset/border operands into the vectors, plus
// several expression tails, were dropped by the extraction.
3057LogicalResult tosa::ResizeOp::verify() {
3058 const Value input = getInput();
3059 const Value output = getOutput();
3060 const RankedTensorType inputType =
3061 llvm::dyn_cast<RankedTensorType>(input.
getType());
3062 const RankedTensorType outputType =
3063 llvm::dyn_cast<RankedTensorType>(output.
getType());
3065 SmallVector<int64_t> scaleValues;
3066 SmallVector<int64_t> offsetValues;
3067 SmallVector<int64_t> borderValues;
// All four scale components (y_n, y_d, x_n, x_d) must be strictly positive.
3075 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
3076 return emitOpError(
"expect all scale values to be > 0, got ")
3079 const int64_t scaleYN = scaleValues[0];
3080 const int64_t scaleYD = scaleValues[1];
3081 const int64_t scaleXN = scaleValues[2];
3082 const int64_t scaleXD = scaleValues[3];
3084 const int64_t offsetY = offsetValues[0];
3085 const int64_t offsetX = offsetValues[1];
3087 const int64_t borderY = borderValues[0];
3088 const int64_t borderX = borderValues[1];
3095 const int64_t oh = outputType.getDimSize(1);
3096 const int64_t ow = outputType.getDimSize(2);
3097 const int64_t ih = inputType.getDimSize(1);
3098 const int64_t iw = inputType.getDimSize(2);
// Height check: skipped for dynamic or degenerate (size-1) input height.
3104 if (ih != ShapedType::kDynamic && ih != 1) {
3105 const std::optional<int64_t> calculatedOutHeightMinusOne =
3106 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
// idivCheck yields nullopt when the division is not exact.
3107 if (!calculatedOutHeightMinusOne.has_value())
3108 return emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
3110 <<
"to be wholly divisible by scale_y_d, got ((" << ih
3111 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
3112 <<
") / " << scaleYD;
3113 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3114 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3115 return emitOpError(
"calculated output height did not match expected: ")
3116 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
// Width check mirrors the height check with the x-axis parameters.
3123 if (iw != ShapedType::kDynamic && iw != 1) {
3124 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3125 const std::optional<int64_t> calculatedOutWidthMinusOne =
3127 if (!calculatedOutWidthMinusOne.has_value())
3128 return emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
3130 <<
"to be wholly divisible by scale_x_d, got ((" << iw
3131 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
3132 <<
") / " << scaleXD;
3133 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3134 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3135 return emitOpError(
"calculated output width did not match expected: ")
3136 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
// Shape inference for tosa.scatter. values_out is rank-3 [N, K, C]: all three
// dims come from values_in when ranked; N can also be refined from indices
// ([N, W]); N and C can be refined from input ([N, W, C]).
// NOTE(review): the trailing `return success();` (original ~3171) was
// dropped by the extraction.
3142LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3143 MLIRContext *context, ::std::optional<Location> location,
3144 ScatterOp::Adaptor adaptor,
3145 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3146 llvm::SmallVector<int64_t> outputShape;
3147 outputShape.resize(3, ShapedType::kDynamic);
3149 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3150 if (valuesInShape.hasRank()) {
3151 outputShape[0] = valuesInShape.getDimSize(0);
3152 outputShape[1] = valuesInShape.getDimSize(1);
3153 outputShape[2] = valuesInShape.getDimSize(2);
3156 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3157 if (indicesShape.hasRank()) {
// indices can only refine N (dim 0).
3158 if (outputShape[0] == ShapedType::kDynamic)
3159 outputShape[0] = indicesShape.getDimSize(0);
3162 ShapeAdaptor inputShape(adaptor.getInput().getType());
3163 if (inputShape.hasRank()) {
// input can refine N (dim 0) and C (dim 2); its dim 1 is W, not K.
3164 if (outputShape[0] == ShapedType::kDynamic)
3165 outputShape[0] = inputShape.getDimSize(0);
3166 if (outputShape[2] == ShapedType::kDynamic)
3167 outputShape[2] = inputShape.getDimSize(2);
3170 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifier for tosa.scatter: cross-checks N/K/W/C between values_in
// ([N, K, C]), indices ([N, W]), input ([N, W, C]) and values_out ([N, K, C])
// wherever static, and finally requires K >= W (enough rows to scatter into).
// NOTE(review): numbering gaps show dropped lines (the indicesShape
// declaration, the `n = ...;` / `w = ...;` / `k = ...;` / `c = ...;`
// assignments on the dynamic branches at 3202/3212/3217/3223/3237, and the
// closing braces/return) — verify against upstream before editing.
3174LogicalResult tosa::ScatterOp::verify() {
3184 const ShapeAdaptor valuesInShape(getValuesIn().
getType());
3186 const ShapeAdaptor inputShape(getInput().
getType());
3187 const ShapeAdaptor outputShape(getValuesOut().
getType());
// Dims start unknown; each ranked operand fills in or checks them.
3189 int64_t n = ShapedType::kDynamic;
3190 int64_t k = ShapedType::kDynamic;
3191 int64_t w = ShapedType::kDynamic;
3192 int64_t c = ShapedType::kDynamic;
3193 if (valuesInShape.hasRank()) {
3194 n = valuesInShape.getDimSize(0);
3195 k = valuesInShape.getDimSize(1);
3196 c = valuesInShape.getDimSize(2);
3198 if (indicesShape.hasRank()) {
3199 const int64_t indicesN = indicesShape.getDimSize(0);
3200 w = indicesShape.getDimSize(1);
3201 if (n == ShapedType::kDynamic)
3203 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3204 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3205 <<
", got " << indicesN;
3207 if (inputShape.hasRank()) {
3208 const int64_t inputN = inputShape.getDimSize(0);
3209 const int64_t inputW = inputShape.getDimSize(1);
3210 const int64_t inputC = inputShape.getDimSize(2);
3211 if (n == ShapedType::kDynamic)
3213 else if (inputN != ShapedType::kDynamic && n != inputN)
3214 return emitOpError() <<
"requires input dimension 0 to have size " << n
3215 <<
", got " << inputN;
3216 if (w == ShapedType::kDynamic)
3218 else if (inputW != ShapedType::kDynamic && w != inputW)
3219 return emitOpError() <<
"requires input dimension 1 to have size " << w
3220 <<
", got " << inputW;
3222 if (c == ShapedType::kDynamic)
3224 else if (inputC != ShapedType::kDynamic && c != inputC)
3225 return emitOpError() <<
"requires input dimension 2 to have size " << c
3226 <<
", got " << inputC;
3228 if (outputShape.hasRank()) {
3229 const int64_t outputN = outputShape.getDimSize(0);
3230 const int64_t outputK = outputShape.getDimSize(1);
3231 const int64_t outputC = outputShape.getDimSize(2);
3232 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3234 return emitOpError() <<
"requires values_out dimension 0 to have size "
3235 << n <<
", got " << outputN;
3236 if (k == ShapedType::kDynamic)
3238 else if (outputK != ShapedType::kDynamic && k != outputK)
3239 return emitOpError() <<
"requires values_out dimension 1 to have size "
3240 << k <<
", got " << outputK;
3241 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3243 return emitOpError() <<
"requires values_out dimension 2 to have size "
3244 << c <<
", got " << outputC;
// Scatter needs at least as many destination rows (K) as updates (W).
3246 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3247 return emitOpError() <<
"requires dimensions K >= W, got K=" << k
// Fragment of the shared reduce-op shape-inference helper (its signature —
// presumably ReduceInferReturnTypes, taking operandShape / axis /
// inferredReturnShapes — was dropped by the extraction; confirm upstream).
// Visible logic: copy the operand's dims and set the reduced axis to 1; the
// unranked / axis-out-of-range branch body is also missing here.
3256 int64_t axisVal = axis.getValue().getSExtValue();
3257 if (!operandShape.
hasRank() || operandShape.
getRank() <= axisVal) {
3263 operandShape.
getDims(outputShape);
// The reduced dimension collapses to size 1.
3264 outputShape[axisVal] = 1;
// COMPATIBLE_RETURN_TYPES(OP): stamps out OP::isCompatibleReturnTypes — two
// single-result type lists are compatible when element types match and the
// shapes pass verifyCompatibleShape. REDUCE_SHAPE_INFER(OP): stamps out
// OP::inferReturnTypeComponents by delegating to ReduceInferReturnTypes with
// the input's element type and the op's axis property, then also expands
// COMPATIBLE_RETURN_TYPES for the same op.
// NOTE(review): no comments are inserted between the macro lines below —
// they are backslash-continued and a comment would splice into the macro
// body. Numbering gaps (3290 -> 3298) show the per-op instantiations
// (REDUCE_SHAPE_INFER(tosa::ReduceAllOp) etc.) were dropped.
3269#define COMPATIBLE_RETURN_TYPES(OP) \
3270 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3271 if (l.size() != r.size() || l.size() != 1) \
3273 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3275 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3278#define REDUCE_SHAPE_INFER(OP) \
3279 LogicalResult OP::inferReturnTypeComponents( \
3280 MLIRContext *context, ::std::optional<Location> location, \
3281 OP::Adaptor adaptor, \
3282 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3284 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3285 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3286 const Properties &prop = adaptor.getProperties(); \
3287 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3288 inferredReturnShapes); \
3290 COMPATIBLE_RETURN_TYPES(OP)
// The helper macros are scoped to the reduce-op section only.
3298#undef REDUCE_SHAPE_INFER
3300#undef COMPATIBLE_RETURN_TYPES
// Shared verifier template for all tosa.reduce_* ops. Checks, in order:
// axis is non-negative; axis < input rank (rank-0 input with axis 0 is
// allowed); output rank equals input rank; axis < output rank; and the
// reduced output dimension, when static, has size 1.
// NOTE(review): numbering gaps show the `return failure();` lines after each
// emitOpError and the final `return success();` were dropped by the
// extraction.
3302template <
typename T>
3305 TensorType inputType = op.getInput().getType();
3306 TensorType outputType = op.getOutput().getType();
3307 int32_t reduceAxis = op.getAxis();
3309 if (reduceAxis < 0) {
3310 op.emitOpError(
"reduce axis must not be negative");
3314 int64_t inputRank = inputType.getRank();
// Special case: rank-0 input permits axis == 0.
3317 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3318 op.emitOpError(
"expect input tensor rank (")
3319 << inputRank <<
") to be larger than reduce axis (" << reduceAxis
3325 int64_t outputRank = outputType.getRank();
3326 if (inputType.
hasRank() && outputRank != inputType.getRank()) {
3328 "expect output tensor rank to be equal to input tensor rank");
3331 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3332 op.emitOpError(
"expect output tensor rank (")
3333 << outputRank <<
") to be larger than reduce axis (" << reduceAxis
3339 if (outputRank != 0) {
3340 auto outputShape = outputType.
getShape();
// A static reduced dimension must have collapsed to exactly 1.
3341 if (!outputType.isDynamicDim(reduceAxis) &&
3342 outputShape[reduceAxis] != 1) {
3343 op.emitOpError(
"expect reduced dimension size to be 1, got ")
3344 << outputShape[reduceAxis];
// Per-op verify() entry points: each reduce op forwards to the shared
// verifyReduceOp<T> template above.
3352LogicalResult tosa::ReduceAllOp::verify() {
return verifyReduceOp(*
this); }
3353LogicalResult tosa::ReduceAnyOp::verify() {
return verifyReduceOp(*
this); }
3354LogicalResult tosa::ReduceMaxOp::verify() {
return verifyReduceOp(*
this); }
3355LogicalResult tosa::ReduceMinOp::verify() {
return verifyReduceOp(*
this); }
3356LogicalResult tosa::ReduceProductOp::verify() {
return verifyReduceOp(*
this); }
3357LogicalResult tosa::ReduceSumOp::verify() {
return verifyReduceOp(*
this); }
// NARY_SHAPE_INFER(OP): stamps out the ValueShapeRange-based
// inferReturnTypeComponents for elementwise n-ary ops, delegating to
// NAryInferReturnTypes. No comments are inserted between the continued
// macro lines below (a comment would splice into the macro body).
// NOTE(review): numbering jumps 3377 -> 3417 — the macro's closing line and
// the per-op NARY_SHAPE_INFER(...) instantiations were dropped.
3371#define NARY_SHAPE_INFER(OP) \
3372 LogicalResult OP::inferReturnTypeComponents( \
3373 MLIRContext *context, ::std::optional<Location> location, \
3374 ValueShapeRange operands, DictionaryAttr attributes, \
3375 OpaqueProperties properties, RegionRange regions, \
3376 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3377 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3417#undef PRED_SHAPE_INFER
// Shape inference for tosa.negate: the result takes the shape of input1.
// NOTE(review): the body tail (pushing inputShape into
// inferredReturnShapes and returning) was dropped by the extraction.
3419LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3420 MLIRContext *context, ::std::optional<Location> location,
3421 NegateOp::Adaptor adaptor,
3423 ShapeAdaptor inputShape(adaptor.getInput1().getType());
// Verifier for tosa.negate: input1/output must have compatible shapes, each
// zero-point operand must match its tensor's element type, and the resolved
// zero-point values must pass the generated verifyInput1ZeroPoint /
// verifyOutputZeroPoint checks.
// NOTE(review): numbering gaps show dropped lines (the element-type
// extraction for input1EType/outputEType, the zero-point element-type
// initializers at 3442/3450, and the failure returns / final success).
3428LogicalResult tosa::NegateOp::verify() {
3430 const Type input1Type = getInput1().getType();
3431 const Type outputType = getOutput().getType();
3436 const SmallVector<Type, 2> types = {input1Type, outputType};
3438 return emitOpError() <<
"requires the same shape for input1 and output";
3441 const Type input1ZpEType =
3443 if (input1EType != input1ZpEType) {
3444 return emitOpError(
"expect both input1 and its zero point are the same "
3445 "element type, got ")
3446 << input1EType <<
" and " << input1ZpEType;
3449 const Type outputZpEType =
3451 if (outputEType != outputZpEType) {
3452 return emitOpError(
"expect both output and its zero point are the same "
3453 "element type, got ")
3454 << outputEType <<
" and " << outputZpEType;
// Zero points may be unresolvable (failure) — only verify resolved values.
3457 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3458 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3461 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3462 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
// Fragment of the shared pooling shape-inference helper (its signature and
// the NHWC batch/channel handling were dropped by the extraction — confirm
// upstream). Visible logic: for static H/W, output spatial size is
// (dim + pad_before + pad_after - kernel) / stride + 1.
3473 outputShape.resize(4, ShapedType::kDynamic);
3488 if (ShapedType::isStatic(height)) {
3489 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3490 outputShape[1] = padded / stride[0] + 1;
3493 if (ShapedType::isStatic(width)) {
3494 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3495 outputShape[2] = padded / stride[1] + 1;
// Fragment of the Conv2D shape-adaptor machinery used by the generic conv
// shape inference below: a small template helper that fills a dynamic dim
// with a candidate value, followed by the Conv2D adaptor — it maps NHWC
// input dims (batch -> outputShape[0], H/W -> inputSpatial), OHWI weight
// dims (output channels -> outputShape[3], kernel H/W -> weightSpatial), and
// copies pad/stride/dilation attributes into vectors.
// NOTE(review): numbering gaps show the class/struct header, method
// signatures and several declarations were dropped by the extraction.
3502template <
typename AdaptorT>
3508 if (ShapedType::isDynamic(current))
3509 current = candidate;
3518 : adaptor(adaptor) {}
3522 const ShapeAdaptor inputShape(adaptor.getInput().getType());
3530 outputShape[0] = outputBatch;
3531 inputSpatial[0] = inputHeight;
3532 inputSpatial[1] = inputWidth;
3537 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
3545 outputShape[3] = outputChannels;
3546 weightSpatial[0] = kernelHeight;
3547 weightSpatial[1] = kernelWidth;
3556 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
3557 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
3558 dilationValues.assign(adaptor.getDilation().begin(),
3559 adaptor.getDilation().end());
3564 Conv2DOp::Adaptor adaptor;
// Fragment of the Conv2DBlockScaled shape adaptor: like the Conv2D adaptor
// but reads input_data/weight_data for shapes and additionally consults
// input_scale/weight_scale rankedness (early-exit when a scale operand is
// unranked).
// NOTE(review): the struct header, method signatures and intermediate
// declarations were dropped by the extraction — confirm upstream.
3572 : adaptor(adaptor) {}
3576 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
3577 if (inputDataShape.
hasRank()) {
3582 outputShape[0] = outputBatch;
3583 inputSpatial[0] = inputHeight;
3584 inputSpatial[1] = inputWidth;
3587 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
3588 if (!inputScaleShape.
hasRank())
3602 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
3603 if (weightDataShape.
hasRank()) {
3608 outputShape[3] = outputChannels;
3609 weightSpatial[0] = kernelHeight;
3610 weightSpatial[1] = kernelWidth;
3613 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
3614 if (!weightScaleShape.
hasRank())
3643 Conv2DBlockScaledOp::Adaptor adaptor;
// Fragment of the Conv3D shape adaptor: NDHWC input (batch ->
// outputShape[0], D/H/W -> inputSpatial[0..2]), weight (output channels ->
// outputShape[4], kernel D/H/W -> weightSpatial[0..2]), plus the
// pad/stride/dilation attribute copies.
// NOTE(review): struct header and method signatures were dropped by the
// extraction — confirm upstream.
3651 : adaptor(adaptor) {}
3655 const ShapeAdaptor inputShape(adaptor.getInput().getType());
3664 outputShape[0] = outputBatch;
3665 inputSpatial[0] = inputDepth;
3666 inputSpatial[1] = inputHeight;
3667 inputSpatial[2] = inputWidth;
3672 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
3681 outputShape[4] = outputChannels;
3682 weightSpatial[0] = kernelDepth;
3683 weightSpatial[1] = kernelHeight;
3684 weightSpatial[2] = kernelWidth;
3693 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
3694 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
3695 dilationValues.assign(adaptor.getDilation().begin(),
3696 adaptor.getDilation().end());
3701 Conv3DOp::Adaptor adaptor;
// Generic convolution shape-inference template shared by the conv ops: asks
// the per-op shape adaptor for input/weight contributions, refines the
// output channel dim from a non-broadcast (size != 1) bias, then computes
// each static spatial output dim as
//   ((in + pad_lo + pad_hi) - ((w - 1) * dilation + 1)) with
//   out = (unstrided - 1) / stride + 1.
// NOTE(review): numbering gaps show dropped lines: the function signature,
// the outputShape/inputSpatial/weightSpatial initializations, the biasSize
// read, and the inputSize/filterSize declarations — confirm upstream.
3704template <
typename AdaptorT>
3710 ShapedType::kDynamic);
3712 ShapedType::kDynamic);
3714 ShapedType::kDynamic);
3716 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
3717 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
3719 const ShapeAdaptor biasShape = adaptor.getBias().getType();
// Bias of size 1 broadcasts and cannot pin the channel count.
3722 if (biasSize != 1) {
3723 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
3724 outputShape[outputChannelDim] =
3725 ShapedType::isDynamic(outputShape[outputChannelDim])
3727 : outputShape[outputChannelDim];
3734 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
// Each spatial dim is computed independently; dynamic dims are skipped.
3740 for (
int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
3741 if (!ShapedType::isStatic(inputSpatial[dim]) ||
3742 !ShapedType::isStatic(weightSpatial[dim]))
3745 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
3747 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
3748 const int64_t unstridedResult = inputSize - filterSize + 1;
3749 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
// Conv2DOp entry points — both bodies (delegation to convInferReturnTypes /
// the shared conv verifier) were dropped by the extraction; only the
// signatures remain. Confirm against upstream.
3756LogicalResult Conv2DOp::inferReturnTypeComponents(
3757 MLIRContext *context, ::std::optional<Location> location,
3758 Conv2DOp::Adaptor adaptor,
3759 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3763LogicalResult Conv2DOp::verify() {
// Conv2DBlockScaledOp shape-inference entry point — the body (delegation to
// the shared conv inference) was dropped by the extraction; only the
// signature remains.
3770LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
3771 MLIRContext *context, ::std::optional<Location> location,
3772 Conv2DBlockScaledOp::Adaptor adaptor,
3773 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Verifier for tosa.conv2d_block_scaled. Checks (in order): element-type
// agreement across operand groups; N/IH/IW/IC vs the scale operands'
// corresponding dims; IC divisible by the block size and the scale operands'
// last dim equal to IC/block_size; pad >= 0, stride >= 1, dilation >= 1;
// output spatial dims consistent with the conv parameters; and bias channels
// equal to output channels or 1.
// NOTE(review): many call heads were dropped by the extraction (the
// verifySameElementTypes / dim-comparison helper invocations, emitOpError
// heads, and failure returns) — only their argument tails remain. Verify
// against upstream before editing.
3777LogicalResult Conv2DBlockScaledOp::verify() {
3779 getWeightData().
getType(),
"input_data",
3782 getWeightScale().
getType(),
"input_scale",
3785 getOutput().
getType(),
"bias",
"output")))
// Dims start dynamic and are filled from whichever operands are ranked.
3789 int64_t N = ShapedType::kDynamic;
3790 int64_t IH = ShapedType::kDynamic;
3791 int64_t IW = ShapedType::kDynamic;
3792 int64_t IC = ShapedType::kDynamic;
3793 int64_t multiplesOfIC = ShapedType::kDynamic;
3794 int64_t OC = ShapedType::kDynamic;
3795 int64_t KH = ShapedType::kDynamic;
3796 int64_t KW = ShapedType::kDynamic;
3798 const ShapeAdaptor inputDataShape(getInputData().
getType());
3799 if (inputDataShape.hasRank()) {
3800 N = inputDataShape.getDimSize(0);
3801 IH = inputDataShape.getDimSize(1);
3802 IW = inputDataShape.getDimSize(2);
3803 IC = inputDataShape.getDimSize(3);
3806 const ShapeAdaptor inputScaleShape(getInputScale().
getType());
3807 if (inputScaleShape.hasRank()) {
3809 "input_scale",
"batch size")) ||
3811 "input_scale",
"input height")) ||
3813 "input_scale",
"input width")))
// input_scale dim 3 carries IC / block_size (checked further below).
3815 multiplesOfIC = inputScaleShape.getDimSize(3);
3818 const ShapeAdaptor weightDataShape(getWeightData().
getType());
3819 if (weightDataShape.hasRank()) {
3820 OC = weightDataShape.getDimSize(0);
3821 KH = weightDataShape.getDimSize(1);
3822 KW = weightDataShape.getDimSize(2);
3824 "weight_data",
"input channels")))
3828 const ShapeAdaptor weightScaleShape(getWeightScale().
getType());
3829 if (weightScaleShape.hasRank()) {
3831 "weight_scale",
"output channels")) ||
3833 "weight_scale",
"kernel height")) ||
3835 "weight_scale",
"kernel width")) ||
3837 weightScaleShape.getDimSize(3),
3838 "weight_scale",
"input channel blocks")))
3843 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
3844 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
3845 return emitOpError(
"expect IC to be a multiple of block size, got IC=")
3846 << IC <<
", block_size=" << blockSize;
3849 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
3850 multiplesOfIC != IC / blockSize)
3852 "expect scale operands dimension 2 to equal IC/block_size (")
3853 << IC <<
"/" << blockSize <<
")"
3854 <<
", got " << multiplesOfIC;
// Attribute sanity: pad >= 0, stride >= 1, dilation >= 1.
3857 SmallVector<int64_t> padValues;
3859 if (llvm::any_of(padValues, [](int64_t p) {
return p < 0; }))
3860 return emitOpError(
"expect all padding values to be >= 0, got ")
3864 SmallVector<int64_t> strideValues;
3866 if (llvm::any_of(strideValues, [](int64_t s) {
return s < 1; }))
3867 return emitOpError(
"expect all stride values to be >= 1, got ")
3871 SmallVector<int64_t> dilationValues;
3874 if (llvm::any_of(dilationValues, [](int64_t d) {
return d < 1; }))
3875 return emitOpError(
"expect all dilation values to be >= 1, got ")
3880 const ShapeAdaptor outputShape(getOutput().
getType());
// Output spatial dims checked per-axis against pad/stride/dilation.
3881 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
3882 outputShape.hasRank()) {
3884 padValues[0], padValues[1], strideValues[0],
3885 dilationValues[0],
"height",
"y",
"top",
3888 padValues[2], padValues[3], strideValues[1],
3889 dilationValues[1],
"width",
"x",
"left",
3895 const ShapeAdaptor biasShape(getBias().
getType());
3896 if (biasShape.hasRank() && outputShape.hasRank()) {
3897 const int64_t biasChannels = biasShape.getDimSize(0);
3898 const int64_t outputChannels =
3899 outputShape.getDimSize(outputShape.getRank() - 1);
// Either side dynamic: nothing further to check.
3900 if (biasChannels == ShapedType::kDynamic ||
3901 outputChannels == ShapedType::kDynamic)
3905 if (biasChannels != outputChannels && biasChannels != 1)
3907 "bias channels expected to be equal to output channels (")
3908 << outputChannels <<
") or 1, got " << biasChannels;
// Conv3DOp entry points — both bodies (delegation to the shared conv
// inference / verifier) were dropped by the extraction; only the signatures
// remain. Confirm against upstream.
3914LogicalResult Conv3DOp::inferReturnTypeComponents(
3915 MLIRContext *context, ::std::optional<Location> location,
3916 Conv3DOp::Adaptor adaptor,
3917 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3921LogicalResult Conv3DOp::verify() {
// Shape inference for tosa.avg_pool2d: delegates to the shared pooling
// helper with the op's kernel/stride/pad properties.
// NOTE(review): the delegating call head (around original line 3934) was
// dropped by the extraction — only its trailing argument remains.
3928LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3929 MLIRContext *context, ::std::optional<Location> location,
3930 AvgPool2dOp::Adaptor adaptor,
3931 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3932 ShapeAdaptor inputShape(adaptor.getInput().getType());
3933 const Properties &prop = adaptor.getProperties();
3935 inferredReturnShapes);
// Shape inference for tosa.max_pool2d: identical structure to AvgPool2d —
// delegates to the shared pooling helper. The verify() body below was
// dropped by the extraction (only the signature remains).
3938LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3939 MLIRContext *context, ::std::optional<Location> location,
3940 MaxPool2dOp::Adaptor adaptor,
3941 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3942 ShapeAdaptor inputShape(adaptor.getInput().getType());
3943 const Properties &prop = adaptor.getProperties();
3945 inferredReturnShapes);
3948LogicalResult MaxPool2dOp::verify() {
// Shape inference for tosa.depthwise_conv2d. Output is NHWC where the
// channel count is inputChannels * depthChannels (weight layout is
// [KH, KW, C, M]); a non-broadcast bias can also pin the channel count.
// Spatial dims use the standard conv formula with pad/stride/dilation.
// NOTE(review): the ternary else-branches at 3989 and the trailing
// `return success();` were dropped by the extraction.
3959LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3960 MLIRContext *context, ::std::optional<Location> location,
3961 DepthwiseConv2DOp::Adaptor adaptor,
3962 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3963 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3965 int64_t inputWidth = ShapedType::kDynamic;
3966 int64_t inputHeight = ShapedType::kDynamic;
3967 int64_t inputChannels = ShapedType::kDynamic;
3969 int64_t weightWidth = ShapedType::kDynamic;
3970 int64_t weightHeight = ShapedType::kDynamic;
3971 int64_t depthChannels = ShapedType::kDynamic;
3974 ShapeAdaptor inputShape(adaptor.getInput().getType());
3975 if (inputShape.hasRank()) {
3976 outputShape[0] = inputShape.getDimSize(0);
3977 inputHeight = inputShape.getDimSize(1);
3978 inputWidth = inputShape.getDimSize(2);
3979 inputChannels = inputShape.getDimSize(3);
3983 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3984 if (weightShape.hasRank()) {
3985 weightHeight = weightShape.getDimSize(0);
3986 weightWidth = weightShape.getDimSize(1);
// Weight dim 2 can supply C when the input left it dynamic.
3987 inputChannels = ShapedType::isDynamic(inputChannels)
3988 ? weightShape.getDimSize(2)
3990 depthChannels = weightShape.getDimSize(3);
// Depthwise output channels = C * M (channel multiplier).
3995 if (ShapedType::isStatic(inputChannels) &&
3996 ShapedType::isStatic(depthChannels)) {
3997 outputShape[3] = inputChannels * depthChannels;
4001 ShapeAdaptor biasShape(adaptor.getBias().getType());
4002 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4003 int64_t bc = biasShape.getDimSize(0);
// A bias of size 1 broadcasts, so it cannot pin the channel count.
4004 if (bc != ShapedType::kDynamic && bc != 1)
4005 outputShape[3] = bc;
4008 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4009 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4010 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4012 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4013 int64_t inputSize = inputHeight + padding[0] + padding[1];
4014 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4015 int64_t unstridedResult = inputSize - filterSize + 1;
4016 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4019 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4020 int64_t inputSize = inputWidth + padding[2] + padding[3];
4021 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4022 int64_t unstridedResult = inputSize - filterSize + 1;
4023 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4026 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// DepthwiseConv2DOp::verify — the entire body was dropped by the
// extraction; only the signature remains. Confirm against upstream.
4030LogicalResult DepthwiseConv2DOp::verify() {
// Shape inference for tosa.transpose_conv2d. Output is NHWC: batch from the
// input, channels from the weight's dim 0 (or a non-broadcast bias), and
// spatial dims via the deconvolution formula
//   out = (in - 1) * stride + pad_lo + pad_hi + kernel.
// NOTE(review): the ternary else-branches (4053/4063), the outputShape[1]/[2]
// assignment heads (4082/4089) and the trailing return were dropped by the
// extraction.
4037LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4038 MLIRContext *context, ::std::optional<Location> location,
4039 TransposeConv2DOp::Adaptor adaptor,
4040 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4041 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4043 int64_t inputWidth = ShapedType::kDynamic;
4044 int64_t inputHeight = ShapedType::kDynamic;
4045 int64_t weightWidth = ShapedType::kDynamic;
4046 int64_t weightHeight = ShapedType::kDynamic;
4049 ShapeAdaptor inputShape(adaptor.getInput().getType());
4050 if (inputShape.hasRank()) {
4051 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4052 ? inputShape.getDimSize(0)
4054 inputHeight = inputShape.getDimSize(1);
4055 inputWidth = inputShape.getDimSize(2);
4059 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4060 if (weightShape.hasRank()) {
// Weight dim 0 is the output channel count for transpose conv.
4061 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4062 ? weightShape.getDimSize(0)
4064 weightHeight = weightShape.getDimSize(1);
4065 weightWidth = weightShape.getDimSize(2);
4069 ShapeAdaptor biasShape(adaptor.getBias().getType());
4070 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4071 int64_t bc = biasShape.getDimSize(0);
// Bias of size 1 broadcasts and cannot pin the channel count.
4072 if (bc != ShapedType::kDynamic && bc != 1)
4073 outputShape[3] = bc;
4076 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4077 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4079 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4080 int64_t calculateSize =
4081 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4083 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4086 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4087 int64_t calculateSize =
4088 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4090 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4093 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifier for tosa.transpose_conv2d. Checks: strides >= 1; each out_pad
// value > -(corresponding kernel dim) via a local lambda; the static output
// H/W match (in - 1) * stride + out_pad_lo + out_pad_hi + kernel; and bias
// channels equal output channels or 1.
// NOTE(review): numbering gaps show dropped lines throughout (the lambda's
// emitOpError head and `return success();`, the early-return guards when
// weight/output/bias types are unranked, the failure returns after the
// checkPadAgainstKernelDim calls, and the outer comparison heads at
// 4167/4182) — verify against upstream before editing.
4097LogicalResult TransposeConv2DOp::verify() {
4101 const llvm::ArrayRef<int64_t> strides = getStride();
4102 const int64_t strideY = strides[0];
4103 const int64_t strideX = strides[1];
4105 if (strideY < 1 || strideX < 1)
4106 return emitOpError(
"expect all stride values to be >= 1, got [")
// Lambda: a pad value must satisfy pad > -kernel_dim for its axis.
4109 const auto checkPadAgainstKernelDim =
4110 [
this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4111 llvm::StringRef kernelDimName) -> LogicalResult {
4112 if (padValue <= -kernelDimSize)
4114 << padName <<
" > -" << kernelDimName <<
", but got: " << padName
4115 <<
"=" << padValue <<
" and " << kernelDimName <<
"="
4120 const llvm::ArrayRef<int64_t> padding = getOutPad();
4121 const int64_t outPadTop = padding[0];
4122 const int64_t outPadBottom = padding[1];
4123 const int64_t outPadLeft = padding[2];
4124 const int64_t outPadRight = padding[3];
4126 const auto weightType =
4127 llvm::dyn_cast<RankedTensorType>(getWeight().
getType());
4130 const int64_t kernelHeight = weightType.getDimSize(1);
4131 if (ShapedType::isStatic(kernelHeight)) {
4132 if (
failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4133 "out_pad_top",
"KH")))
4136 if (
failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4137 "out_pad_bottom",
"KH")))
4141 const int64_t kernelWidth = weightType.getDimSize(2);
4142 if (ShapedType::isStatic(kernelWidth)) {
4143 if (
failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4144 "out_pad_left",
"KW")))
4147 if (
failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4148 "out_pad_right",
"KW")))
4154 const auto outputType =
4155 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
4159 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
// Full OH/OW consistency checks need both input and weight ranked.
4160 if (inputType && weightType) {
4161 const int64_t inputHeight = inputType.getDimSize(1);
4162 const int64_t kernelHeight = weightType.getDimSize(1);
4163 const int64_t outputHeight = outputType.getDimSize(1);
4165 if (ShapedType::isStatic(inputHeight) &&
4166 ShapedType::isStatic(outputHeight)) {
4168 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4170 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4171 "+ out_pad_top + out_pad_bottom + KH, but got ")
4172 << outputHeight <<
" != (" << inputHeight <<
" - 1) * "
4173 << strideY <<
" + " << outPadTop <<
" + " << outPadBottom
4174 <<
" + " << kernelHeight;
4177 const int64_t inputWidth = inputType.getDimSize(2);
4178 const int64_t kernelWidth = weightType.getDimSize(2);
4179 const int64_t outputWidth = outputType.getDimSize(2);
4181 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4183 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4185 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4186 "+ out_pad_left + out_pad_right + KW, but got ")
4187 << outputWidth <<
" != (" << inputWidth <<
" - 1) * " << strideX
4188 <<
" + " << outPadLeft <<
" + " << outPadRight <<
" + "
4193 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().
getType());
4198 const int64_t biasChannels = biasType.getDimSize(0);
// Dynamic bias channels: nothing further to check.
4201 if (biasChannels == ShapedType::kDynamic)
4204 const int64_t outputChannels = outputType.getDimSize(3);
4205 if (!ShapedType::isDynamic(outputChannels) &&
4206 biasChannels != outputChannels && biasChannels != 1)
4208 "bias channels expected to be equal to output channels (")
4209 << outputChannels <<
") or 1, got " << biasChannels;
4214LogicalResult RescaleOp::verify() {
4215 auto inputType = llvm::dyn_cast<ShapedType>(getInput().
getType());
4217 emitOpError(
"expect shaped tensor for input, got ") << getInput().getType();
4221 auto inputElementType =
4223 if (!mlir::isa<IntegerType>(inputElementType)) {
4224 emitOpError(
"expect input to have integer element type, got ")
4225 << inputElementType;
4229 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().
getType());
4231 emitOpError(
"expect shaped tensor for output, got ")
4232 << getOutput().getType();
4236 auto outputElementType =
4238 if (!mlir::isa<IntegerType>(outputElementType)) {
4239 emitOpError(
"expect output to have integer element type, got ")
4240 << outputElementType;
4252 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4253 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4256 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4257 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4260 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().
getType());
4261 if (!multiplierType) {
4262 emitOpError(
"expect shaped tensor for multiplier, got ")
4263 << getMultiplier().getType();
4267 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().
getType());
4269 emitOpError(
"expect shaped tensor for shift, got ") << getShift().getType();
4274 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4275 emitOpError(
"expect i32 element type for multiplier for scale32=true, got ")
4276 << multiplierType.getElementType();
4281 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4283 "expect i16 element type for multiplier for scale32=false, got ")
4284 << multiplierType.getElementType();
4288 if (!inputType.hasRank())
4294 int64_t numChannels = 1;
4295 if (getPerChannel()) {
4296 if (inputType.getRank() < 1) {
4297 emitOpError(
"requires input to be at least rank 1 when per_channel is "
4298 "true, but got rank ")
4299 << inputType.getRank();
4302 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4305 if (!multiplierType.hasRank())
4308 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4310 if (multiplierShape[0] != ShapedType::kDynamic &&
4311 multiplierShape[0] != numChannels) {
4313 << numChannels <<
" } for multiplier input, got { "
4314 << multiplierShape[0] <<
" }";
4318 if (!shiftType.hasRank())
4321 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4323 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4325 << numChannels <<
" } for shift input, got { " << shiftShape[0] <<
" }";
4332LogicalResult RescaleOp::inferReturnTypeComponents(
4333 MLIRContext *context, ::std::optional<Location> location,
4334 RescaleOp::Adaptor adaptor,
4335 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4336 ShapeAdaptor inputShape(adaptor.getInput().getType());
4337 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4341LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4342 MLIRContext *context, ::std::optional<Location> location,
4343 CastFromBlockScaledOp::Adaptor adaptor,
4344 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4345 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4346 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4350LogicalResult CastFromBlockScaledOp::verify() {
4351 const Type inputDataType = getInputData().getType();
4352 const Type outputDataType = getResult().getType();
4354 return emitOpError() <<
"require compatible shapes for input_data ("
4355 << inputDataType <<
") and " <<
"output_data ("
4356 << outputDataType <<
")";
4358 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4360 if (inputDataShape.
hasRank()) {
4361 const unsigned int blockSize =
4363 const int64_t inputDataLastDim =
4365 if (inputDataLastDim % blockSize != 0)
4366 return emitOpError() <<
"expect last dimension of input_data ("
4368 <<
") to be divisible by block_size (" << blockSize
4371 const Type inputScaleType = getInputScale().getType();
4372 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4374 if (inputScaleShape.
hasRank()) {
4375 SmallVector<int64_t> inputDataDims, inputScaleDims;
4376 inputDataShape.
getDims(inputDataDims);
4377 inputScaleShape.
getDims(inputScaleDims);
4379 if (inputDataDims.size() != inputScaleDims.size() ||
4381 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4382 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4384 <<
"require compatible shapes for input_data (" << inputDataType
4385 <<
") and " <<
"input_scale (" << inputScaleType
4386 <<
") except for the last dimension";
4388 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4389 inputScaleDims.back()};
4390 if (ShapedType::isStatic(inputDataLastDim) &&
4393 <<
"expect last dimension of input_scale ("
4394 << inputScaleDims.back()
4395 <<
") to be equal to last dimension of input_data / block_size ("
4396 << inputDataDims.back() / blockSize <<
")";
4403LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4404 MLIRContext *context, ::std::optional<Location> location,
4405 CastToBlockScaledOp::Adaptor adaptor,
4406 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4407 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4408 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4409 if (!inputShape.hasRank())
4413 SmallVector<int64_t> outputScaleShape;
4414 inputShape.getDims(outputScaleShape);
4415 const int64_t lastDimLoc = inputShape.getRank() - 1;
4416 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4417 if (ShapedType::isStatic(lastDimSize)) {
4418 const unsigned int blockSize =
4419 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4420 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4422 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4426LogicalResult CastToBlockScaledOp::verify() {
4427 const Type inputDataType = getInputData().getType();
4428 const Type outputDataType = getResult(0).getType();
4430 return emitOpError() <<
"require compatible shapes for input_data ("
4431 << inputDataType <<
") and " <<
"output_data ("
4432 << outputDataType <<
")";
4434 const unsigned int blockSize =
4436 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4437 if (inputDataShape.
hasRank()) {
4438 const int64_t inputDataLastDim =
4440 if (ShapedType::isStatic(inputDataLastDim) &&
4441 inputDataLastDim % blockSize != 0)
4442 return emitOpError() <<
"expect last dimension of input_data ("
4444 <<
") to be divisible by block_size (" << blockSize
4448 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4449 const Type outputScaleType = getResult(1).getType();
4450 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4452 SmallVector<int64_t> outputDataDims, outputScaleDims;
4453 outputDataShape.
getDims(outputDataDims);
4454 outputScaleShape.
getDims(outputScaleDims);
4456 if (outputDataDims.size() != outputScaleDims.size() ||
4458 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4459 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4460 return emitOpError() <<
"require compatible shapes for output_data ("
4461 << outputDataType <<
") and " <<
"output_scale ("
4463 <<
") except for the last dimension";
4465 const int64_t outputDataLastDim = outputDataDims.back();
4466 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4467 outputScaleDims.back()};
4468 if (ShapedType::isStatic(outputDataLastDim) &&
4471 <<
"expect last dimension of output_scale ("
4472 << outputScaleDims.back()
4473 <<
") to be equal to last dimension of output_data / block_size ("
4474 << outputDataDims.back() / blockSize <<
")";
4480LogicalResult IfOp::inferReturnTypeComponents(
4481 MLIRContext *context, ::std::optional<Location> location,
4482 IfOp::Adaptor adaptor,
4483 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4484 llvm::SmallVector<tosa::YieldOp> yieldOps;
4485 for (Region *region : adaptor.getRegions()) {
4486 for (
auto &block : *region)
4487 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4488 yieldOps.push_back(returnOp);
4491 if (yieldOps.empty())
4495 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4496 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4497 for (
auto operand : yieldOps.front().getOperands()) {
4498 resultKnowledge.push_back(
4502 for (
auto yieldOp : yieldOps) {
4503 if (resultKnowledge.size() != yieldOp.getNumOperands())
4506 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4507 int32_t index = it.index();
4509 resultKnowledge[index],
4513 resultKnowledge[index] = meet;
4517 for (
const ValueKnowledge &
result : resultKnowledge) {
4518 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
4524LogicalResult WhileOp::inferReturnTypeComponents(
4525 MLIRContext *context, ::std::optional<Location> location,
4526 WhileOp::Adaptor adaptor,
4527 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4528 llvm::SmallVector<tosa::YieldOp> yieldOps;
4529 for (
auto &block : adaptor.getBodyGraph())
4530 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4531 yieldOps.push_back(returnOp);
4535 if (yieldOps.empty())
4539 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4540 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4541 for (
auto operand : yieldOps.front().getOperands()) {
4542 resultKnowledge.push_back(
4546 for (
auto yieldOp : yieldOps) {
4547 if (resultKnowledge.size() != yieldOp.getNumOperands())
4550 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4551 int32_t index = it.index();
4553 resultKnowledge[index],
4555 resultKnowledge[index] = meet;
4560 for (
const ValueKnowledge &
result : resultKnowledge) {
4561 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
4567std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4568 if (
auto vt = llvm::dyn_cast<VectorType>(
getType()))
4569 return llvm::to_vector<4>(vt.getShape());
4570 return std::nullopt;
4576 StringRef prefix =
"") {
4577 assert(blocksArgs.size() == initializers.size() &&
4578 "expected same length of arguments and initializers");
4579 if (initializers.empty())
4582 parser << prefix <<
'(';
4583 llvm::interleaveComma(
4584 llvm::zip(blocksArgs, initializers), parser,
4585 [&](
auto it) { parser << std::get<0>(it) <<
" = " << std::get<1>(it); });
4590ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
4592 result.regions.reserve(2);
4593 Region *thenRegion =
result.addRegion();
4594 Region *elseRegion =
result.addRegion();
4596 OpAsmParser::UnresolvedOperand cond;
4601 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4602 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4605 OptionalParseResult listResult =
4613 "expected type for condition operand");
4619 "expected type for condition operand");
4627 FunctionType functionType;
4631 <<
"expected list of types for block arguments "
4632 <<
"followed by arrow type and list of return types";
4634 result.addTypes(functionType.getResults());
4636 if (functionType.getNumInputs() != operands.size()) {
4638 <<
"expected as many input types as operands " <<
"(expected "
4639 << operands.size() <<
" got " << functionType.getNumInputs()
4670void IfOp::print(OpAsmPrinter &p) {
4671 p <<
" " << getCondition();
4674 getInputList(),
" ");
4676 p << getCondition().getType();
4678 if (!getInputList().empty()) {
4680 llvm::interleaveComma(getInputList().getTypes(), p);
4689 auto &elseRegion = getElseGraph();
4690 if (!elseRegion.
empty()) {
4698LogicalResult IfOp::verify() {
4700 "'then_graph' arguments", getInputList(),
4706 "'else_graph' arguments", getInputList(),
4712 if (getThenGraph().front().mightHaveTerminator()) {
4714 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4716 *
this, thenYield.getInputs(),
"'then_graph' results",
4717 getOutputList(),
"'output_list'")
4723 if (getElseGraph().front().mightHaveTerminator()) {
4725 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4727 *
this, elseYield.getInputs(),
"'else_graph' results",
4728 getOutputList(),
"'output_list'")
4733 auto condType = getCondition().getType();
4735 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
4741LogicalResult WhileOp::verify() {
4743 getOutputList(),
"'output_list'")
4748 "'cond_graph' arguments", getInputList(),
4754 "'body_graph' arguments", getInputList(),
4759 if (getBodyGraph().front().mightHaveTerminator()) {
4761 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4763 "'body_graph' results",
4764 getInputList(),
"'input_list'")
4771 if (!getCondGraph().front().mightHaveTerminator())
4775 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4779 if (condYield.getInputs().size() != 1)
4780 return emitOpError() <<
"require 'cond_graph' only have one result";
4782 auto condOutType = condYield.getInputs()[0].getType();
4784 return emitOpError() <<
"'cond_graph' result must be a size 1 tensor, got "
4788 return emitOpError() <<
"'cond_graph' result must be a boolean tensor, got "
4794LogicalResult ReverseOp::verify() {
4799 TensorType inputType = getInput1().getType();
4800 TensorType outputType = getOutput().getType();
4801 int32_t reverseAxis = getAxis();
4803 if (reverseAxis < 0)
4804 return emitOpError(
"expected non-negative reverse axis");
4806 int64_t inputRank = inputType.getRank();
4809 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4811 << inputRank <<
") to be larger than reverse axis (" << reverseAxis
4815 int64_t outputRank = outputType.getRank();
4816 if (inputType.
hasRank() && outputRank != inputType.getRank())
4818 "expect output tensor rank to be equal to input tensor rank");
4819 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4821 << outputRank <<
") to be larger than reverse axis ("
4822 << reverseAxis <<
")";
4827LogicalResult tosa::SelectOp::verify() {
4838 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().
getType());
4839 if (!predicateType) {
4840 return emitOpError(
"expect shaped tensor for input1, got ")
4841 << getInput1().getType();
4843 auto predicateElementType = predicateType.getElementType();
4844 if (!predicateElementType.isInteger(1)) {
4845 return emitOpError(
"expect element type of bool for input1, got ")
4846 << predicateElementType;
4852LogicalResult tosa::VariableReadOp::verify() {
4860LogicalResult tosa::VariableWriteOp::verify() {
4869ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
4870 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4871 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4872 Region *cond =
result.addRegion();
4873 Region *body =
result.addRegion();
4875 OptionalParseResult listResult =
4880 FunctionType functionType;
4885 result.addTypes(functionType.getResults());
4887 if (functionType.getNumInputs() != operands.size()) {
4889 <<
"expected as many input types as operands " <<
"(expected "
4890 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
4900 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
4901 regionArgs[i].type = functionType.getInput(i);
4903 return failure(parser.
parseRegion(*cond, regionArgs) ||
4908void WhileOp::print(OpAsmPrinter &parser) {
4910 getInputList(),
" ");
4913 getResults().getTypes());
4927 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4928 if (llvm::isa<FloatType>(srcElemType)) {
4930 zpType, builder.
getFloatAttr(srcElemType,
static_cast<double>(zp)));
4931 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4933 if (llvm::isa<IntegerType>(srcElemType)) {
4936 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4938 llvm::errs() <<
"zero point is not allowed for unsupported data types\n";
4939 return std::nullopt;
4947 return mlir::isa<tosa::shapeType>(t);
4954 return emitError() <<
"invalid rank (must be >= 0): " << rank;
4960 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4961 Operation *definingOp = v.getDefiningOp();
4963 return op->
emitOpError(
"shape operand is not compile time resolvable");
4976 auto getRank = [](
const Type type) {
4977 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4983 for (
auto type : operandTypes) {
4984 if (getRank(type) != rank) {
4985 return op->
emitOpError(
"operands don't have matching ranks");
4988 for (
auto type : resultTypes) {
4989 if (getRank(type) != rank) {
4990 return op->
emitOpError(
"result shape has different rank than operands");
5000LogicalResult tosa::ConstShapeOp::verify() {
5002 auto valuesRank = getValues().getType().getRank();
5003 if (valuesRank != 1)
5004 return emitOpError(
"expect elements in attribute values with rank 1");
5006 auto count = getValues().getNumElements();
5007 auto rank = (cast<tosa::shapeType>(getResult().
getType())).getRank();
5008 if (count != rank && (count != 1 || rank != 0)) {
5009 return emitOpError(
"expect number of elements in attribute values (")
5010 << count <<
") to be equal to the rank (" << rank
5011 <<
") for the result shape type";
5016LogicalResult tosa::DimOp::verify() {
5017 const tosa::shapeType outShapeType =
5018 cast<tosa::shapeType>(getResult().
getType());
5019 if (outShapeType.getRank() != 1)
5020 return emitOpError(
"expect output shape type to contain one element, got ")
5025 const int64_t inputRank = inputType.getRank();
5026 const int64_t axis = getAxisAttr().getInt();
5027 if (axis < 0 || axis >= inputRank)
5028 return emitOpError(
"expect axis to be in the range [0, ")
5029 << inputRank <<
"), got " << axis;
5034LogicalResult tosa::ConcatShapeOp::verify() {
5035 const tosa::shapeType outShapeType =
5036 cast<tosa::shapeType>(getResult().
getType());
5037 const int64_t outputRank = outShapeType.getRank();
5040 if (inputList.size() == 0)
5041 return emitOpError(
"requires at least one input shape");
5043 if (llvm::any_of(inputList, [](Value v) {
5044 return cast<tosa::shapeType>(v.
getType()).getRank() == 0;
5046 return emitOpError(
"requires all inputs shapes have a rank greater than 0");
5048 const int64_t inputsRank =
5049 llvm::accumulate(inputList, 0, [](int64_t acc,
const Value &input) {
5050 const tosa::shapeType inShapeType =
5051 cast<tosa::shapeType>(input.
getType());
5052 return acc + inShapeType.getRank();
5054 if (outputRank != inputsRank)
5055 return emitOpError(
"requires output shape rank to be equal to the sum of "
5056 "the input shape ranks (")
5057 << inputsRank <<
"), got " << outputRank;
5062LogicalResult tosa::SliceShapeOp::verify() {
5063 std::optional<int32_t> start;
5064 DenseIntElementsAttr startAttr;
5066 start = startAttr.getValues<int32_t>()[0];
5067 if (start && start.value() < 0)
5068 return emitOpError(
"expected non-negative start index, got ")
5071 std::optional<int32_t> size;
5072 DenseIntElementsAttr sizeAttr;
5074 size = sizeAttr.getValues<int32_t>()[0];
5075 if (size && size.value() <= 0)
5076 return emitOpError(
"expected positive size, got ") << size.value();
5081 const tosa::shapeType outShapeType =
5082 cast<tosa::shapeType>(getResult().
getType());
5083 const int64_t outputRank = outShapeType.getRank();
5084 if (outputRank != size)
5086 "expected output type size to be equal to size attribute, got ")
5087 << outputRank <<
" vs " << size.value();
5092 const tosa::shapeType inShapeType =
5093 cast<tosa::shapeType>(getInput().
getType());
5094 const int64_t inputRank = inShapeType.getRank();
5095 const int64_t sliceSize = start.value() + size.value();
5096 if (sliceSize > inputRank)
5097 return emitOpError(
"expected start + size to be less than or equal to "
5098 "input shape rank (")
5099 << inputRank <<
"), got " << sliceSize;
5108#define GET_ATTRDEF_CLASSES
5109#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5114#define GET_TYPEDEF_CLASSES
5115#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5121#define GET_OP_CLASSES
5122#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
true
Given two iterators into the same block, return "true" if a is before `b.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static void updateIfDynamic(int64_t ¤t, int64_t candidate)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
int64_t getOutputRank() const
int64_t getNumSpatialDims() const
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
int64_t getNumSpatialDims() const
int64_t getOutputRank() const
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
int64_t getNumSpatialDims() const
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
int64_t getOutputRank() const
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)