29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
56struct TosaInlinerInterface :
public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
65 IRMapping &map)
const final {
71 IRMapping &map)
const final {
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
78struct TosaDialectBytecodeInterface :
public BytecodeDialectInterface {
79 TosaDialectBytecodeInterface(Dialect *dialect)
80 : BytecodeDialectInterface(dialect) {}
85 Attribute readAttribute(DialectBytecodeReader &reader)
const override {
89 LogicalResult writeAttribute(Attribute attr,
90 DialectBytecodeWriter &writer)
const override {
91 return ::writeAttribute(attr, writer);
97 Type readType(DialectBytecodeReader &reader)
const override {
101 LogicalResult writeType(Type type,
102 DialectBytecodeWriter &writer)
const override {
103 return ::writeType(type, writer);
106 void writeVersion(DialectBytecodeWriter &writer)
const final {
110 std::unique_ptr<DialectVersion>
111 readVersion(DialectBytecodeReader &reader)
const final {
113 reader.
emitError(
"Dialect does not support versioning");
117 LogicalResult upgradeFromVersion(Operation *topLevelOp,
118 const DialectVersion &version)
const final {
131 return {&getBodyGraph()};
140 return dim == -1 ? ShapedType::kDynamic : dim;
146 Type elementType = variableOp.getType();
149 return RankedTensorType::get(
shape, elementType);
156void TosaDialect::initialize() {
158#define GET_TYPEDEF_LIST
159#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
163#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
166#define GET_ATTRDEF_LIST
167#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
169 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
170 declarePromisedInterfaces<
171 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
172 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
173 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
174 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
175 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
176 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
177 GreaterEqualOp, MatMulOp>();
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return tosa::ConstShapeOp::create(builder, loc, type,
186 llvm::cast<DenseIntElementsAttr>(value));
188 if (llvm::isa<ElementsAttr>(value))
189 return tosa::ConstOp::create(builder, loc, type,
190 llvm::cast<ElementsAttr>(value));
200ParseResult getShapeAndElementType(
OpAsmParser &parser,
Type parsedType,
202 TypeAttr &typeAttr) {
203 if (
auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
206 <<
"expected ranked type";
208 auto elementType = shapedType.getElementType();
209 typeAttr = TypeAttr::get(elementType);
216 <<
"expected shaped type";
233 <<
"expected attribute";
235 if (
auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
240 <<
"expected Typed attr";
243 initialValueAttr =
nullptr;
247 <<
"expected type after colon";
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
254 TypeAttr typeAttr,
Attribute initialValueAttr) {
255 bool needsSpace =
false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
262 auto tensorTypeAttr = TypeAttr::get(tensorType);
267 if (initialValueAttr) {
278template <
typename EnumType>
279ParseResult parseAttrEntryWithEnumHandling(
OpAsmParser &parser,
281 llvm::StringRef name;
288 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
289 if (name ==
"rounding_mode" &&
291 auto sym = symbolizeRoundingMode(kw);
294 <<
"invalid rounding_mode value: " << kw;
295 auto attr = RoundingModeAttr::get(parser.
getContext(), sym.value());
301 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
303 auto sym = symbolizeResizeMode(kw);
306 <<
"invalid resize mode value: " << kw;
307 auto attr = ResizeModeAttr::get(parser.
getContext(), sym.value());
314 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
316 auto sym = symbolizeNanPropagationMode(kw);
319 <<
"invalid nan_mode value: " << kw;
320 auto attr = NanPropagationModeAttr::get(parser.
getContext(), sym.value());
327 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
329 auto sym = symbolizeBlockSize(kw);
332 <<
"invalid block_size value: " << kw;
333 auto attr = BlockSizeAttr::get(parser.
getContext(), sym.value());
345template <
typename EnumType>
350 [&]() { return parser.parseOperand(operands.emplace_back()); }))
358 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
375 result.addTypes(fnTy.getResults());
376 result.addAttributes(attrs);
382 parser << namedAttr.
getName().strref() <<
" = ";
384 if (
auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
385 parser << roundingModeAttr.getValue();
386 }
else if (
auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
387 parser << resizeModeAttr.getValue();
388 }
else if (
auto nanPropagationModeAttr =
389 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
390 parser << nanPropagationModeAttr.getValue();
391 }
else if (
auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
392 parser << blockSizeAttr.getValue();
405 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
407 if (
auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
408 if (nanAttr.getValue() == kDefaultNanValue) {
410 toPrint.erase(attr.getName());
416 if (!toPrint.empty()) {
418 llvm::interleaveComma(toPrint, parser, [&](
const NamedAttribute namedAttr) {
419 printNamedAttr(parser, namedAttr);
435 llvm::interleaveComma(op->
getAttrs(), parser,
437 printNamedAttr(parser, namedAttr);
449 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
453 printWithEnumHandling(parser, *
this);
457 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
461 printWithEnumHandling(parser, *
this);
465 return parseWithEnumHandling<tosa::ResizeMode>(parser,
result);
469 printWithEnumHandling(parser, *
this);
473 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
477 printWithNanPropagationHandling(parser, *
this);
481 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
485 printWithNanPropagationHandling(parser, *
this);
489 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
493 printWithNanPropagationHandling(parser, *
this);
497 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
501 printWithNanPropagationHandling(parser, *
this);
505 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
509 printWithNanPropagationHandling(parser, *
this);
513 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
517 printWithNanPropagationHandling(parser, *
this);
521 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
525 printWithNanPropagationHandling(parser, *
this);
528ParseResult MatmulTBlockScaledOp::parse(
OpAsmParser &parser,
530 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
534 printWithEnumHandling(parser, *
this);
537ParseResult CastFromBlockScaledOp::parse(
OpAsmParser &parser,
539 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
542void CastFromBlockScaledOp::print(
OpAsmPrinter &parser) {
543 printWithEnumHandling(parser, *
this);
546ParseResult CastToBlockScaledOp::parse(
OpAsmParser &parser,
548 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
552 printWithEnumHandling(parser, *
this);
555ParseResult Conv2DBlockScaledOp::parse(
OpAsmParser &parser,
557 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
561 printWithEnumHandling(parser, *
this);
576 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
586 Value valZp, StringRef name) {
591 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
595 if (!bothInts || !sameBitWidth) {
597 <<
"expected " << name <<
" and " << name
598 <<
"_zp to both be integer of the same bitwidth, but got " << eType
599 <<
" vs. " << eZpType;
606 Value src, int32_t val) {
609 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
610 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
611 const auto padConstAttr{
612 llvm::isa<FloatType>(srcElemType)
617 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
621 if (dyn_cast<tosa::mxint8Type>(type))
630 const StringRef operandName,
631 const StringRef dimName) {
632 if (ShapedType::isDynamic(currDim)) {
635 }
else if (ShapedType::isStatic(newDim) && currDim != newDim) {
637 << dimName <<
" of " << operandName <<
" to match size " << currDim
638 <<
", got " << newDim;
646 const int64_t stride,
const int64_t dilation,
const llvm::StringRef dimName,
647 const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
648 const llvm::StringRef padAfterName) {
649 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
654 const std::optional<int64_t> calculatedOutSizeMinusOne =
idivCheck(
655 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
657 if (!calculatedOutSizeMinusOne.has_value())
659 << dimName <<
" - 1 + pad_" << padBeforeName <<
" + pad_"
660 << padAfterName <<
" - (kernel_" << dimName <<
" - 1) * dilation_"
661 << dimAxis <<
" to be wholly divisible by stride_" << dimAxis
662 <<
", got (" << inputSize <<
" - 1 + " << padBefore <<
" + "
663 << padAfter <<
" - (" << kernelSize <<
" - 1) * " << dilation
666 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
667 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
669 << dimName <<
" did not match expected: "
670 <<
"calculated=" << calculatedOutSize <<
", expected=" << outputSize;
681 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
682 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
684 auto inputEType = inputType.getElementType();
685 auto weightEType = weightType.getElementType();
687 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
689 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
690 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
691 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
693 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
696 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
699 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
702 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
705 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
709 "expect both bias and result to have same element type, got ")
710 << biasEType <<
" and " << resultEType;
714 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
715 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
716 if (inputEType != weightEType) {
718 "expect both input and weight to have same element type, got ")
719 << inputEType <<
" and " << weightEType;
724 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
725 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
728 if (inputIsFloat != weightIsFloat) {
730 "expect both input and weight to be float or not together, got ")
731 << inputEType <<
" and " << weightEType;
736 if (inputEType != inputZpEType) {
737 return op.emitOpError(
"expect both input and its zero point are the same "
738 "element type, got ")
739 << inputEType <<
" and " << inputZpEType;
743 if (weightEType != weightZpEType) {
744 return op.emitOpError(
"expect both weight and its zero point are the same "
745 "element type, got ")
746 << weightEType <<
" and " << weightZpEType;
749 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
750 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
753 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
754 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
760LogicalResult tosa::ConstOp::verify() {
762 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().
getType());
763 auto outputType = llvm::dyn_cast<TensorType>(getOutput().
getType());
765 if (!attrType || !outputType) {
766 emitOpError(
"expected tensors for attr/result type");
770 if (
auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
771 outputType.getElementType())) {
776 if (attrType.getElementType() != outputType.getElementType()) {
777 emitOpError(
"expected same attr/result element types");
787 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
789 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
792 auto accType = op.getAccType();
793 if (inputEType.isInteger(8) && !accType.isInteger(32))
794 return op.emitOpError(
"accumulator type for i8 tensor is not i32, got ")
797 if (inputEType.isInteger(16) && !accType.isInteger(48))
798 return op.emitOpError(
"accumulator type for i16 tensor is not i48, got ")
801 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
802 !(accType.isF16() || accType.isF32()))
803 return op.emitOpError(
"accumulator type for f8 tensor is not f16/f32, got ")
806 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
807 return op.emitOpError(
808 "accumulator type for f16 tensor is not f16/f32, got ")
811 if (inputEType.isBF16() && !accType.isF32())
812 return op.emitOpError(
"accumulator type for bf16 tensor is not f32, got ")
815 if (inputEType.isF32() && !accType.isF32())
816 return op.emitOpError(
"accumulator type for f32 tensor is not f32, got ")
820 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
822 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
836 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
837 return op.emitOpError(
"expect all padding values to be >= 0, got ")
841 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
842 return op.emitOpError(
"expect all stride values to be >= 1, got ")
846 if (llvm::any_of(dilations, [](
int64_t d) {
return d < 1; }))
847 return op.emitOpError(
"expect all dilation values to be >= 1, got ")
850 const RankedTensorType outputType =
851 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
856 const RankedTensorType inputType =
857 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
858 const RankedTensorType weightType =
859 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
861 if (inputType && weightType) {
863 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
865 op, inputType.getDimSize(1), weightType.getDimSize(1),
866 outputType.getDimSize(1), padding[0], padding[1], strides[0],
867 dilations[0],
"height",
"y",
"top",
"bottom")))
871 op, inputType.getDimSize(2), weightType.getDimSize(2),
872 outputType.getDimSize(2), padding[2], padding[3], strides[1],
873 dilations[1],
"width",
"x",
"left",
"right")))
878 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
880 op, inputType.getDimSize(1), weightType.getDimSize(0),
881 outputType.getDimSize(1), padding[0], padding[1], strides[0],
882 dilations[0],
"height",
"y",
"top",
"bottom")))
886 op, inputType.getDimSize(2), weightType.getDimSize(1),
887 outputType.getDimSize(2), padding[2], padding[3], strides[1],
888 dilations[1],
"width",
"x",
"left",
"right")))
893 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
895 op, inputType.getDimSize(1), weightType.getDimSize(1),
896 outputType.getDimSize(1), padding[0], padding[1], strides[0],
897 dilations[0],
"depth",
"d",
"front",
"back")))
901 op, inputType.getDimSize(2), weightType.getDimSize(2),
902 outputType.getDimSize(2), padding[2], padding[3], strides[1],
903 dilations[1],
"height",
"y",
"top",
"bottom")))
907 op, inputType.getDimSize(3), weightType.getDimSize(3),
908 outputType.getDimSize(3), padding[4], padding[5], strides[2],
909 dilations[2],
"width",
"x",
"left",
"right")))
914 const RankedTensorType biasType =
915 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
920 const int64_t biasChannels = biasType.getDimSize(0);
922 outputType.getDimSize(outputType.getRank() - 1);
923 if (biasChannels == ShapedType::kDynamic ||
924 outputChannels == ShapedType::kDynamic)
928 if (biasChannels != outputChannels && biasChannels != 1)
929 return op.emitOpError(
930 "bias channels expected to be equal to output channels (")
931 << outputChannels <<
") or 1, got " << biasChannels;
938 StringRef name1,
Type type2,
940 auto shapeType1 = dyn_cast<ShapedType>(type1);
941 auto shapeType2 = dyn_cast<ShapedType>(type2);
942 if (!shapeType1 || !shapeType2)
945 auto elemType1 = shapeType1.getElementType();
946 auto elemType2 = shapeType2.getElementType();
947 if (elemType1 != elemType2)
949 <<
"require same element type for " << name1 <<
" (" << elemType1
950 <<
") and " << name2 <<
" (" << elemType2 <<
")";
954 <<
"require same shapes for " << name1 <<
" (" << type1 <<
") and "
955 << name2 <<
" (" << type2 <<
")";
965 if (list1.size() != list2.size())
967 <<
"require same number of values in " << name1 <<
" ("
968 << list1.size() <<
") and " << name2 <<
" (" << list2.size() <<
")";
970 for (
auto [type1, type2] :
990 op->template getParentWithTrait<OpTrait::SymbolTable>();
997 const auto varOp = symTable.
lookup<tosa::VariableOp>(op.getName());
1001 return op->emitOpError(
"'")
1002 << op.getName() <<
"' has not been declared by 'tosa.variable'";
1014template <
typename T>
1016 StringRef aName =
"input",
1017 StringRef bName =
"output") {
1018 auto aTType = llvm::dyn_cast<TensorType>(aType);
1019 auto bTType = llvm::dyn_cast<TensorType>(bType);
1021 op.emitOpError(
"expect shaped tensor for") << aName <<
", got " << aType;
1025 op.emitOpError(
"expect shaped tensor for") << bName <<
", got" << bType;
1028 auto aElementType = aTType.getElementType();
1029 auto bElementType = bTType.getElementType();
1031 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1033 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1034 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1035 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1036 aElementType != bElementType) {
1041 op.emitOpError(
"expect ")
1042 << aName <<
" and " << bName <<
" to have same element type, got "
1043 << aElementType <<
" and " << bElementType;
1049LogicalResult tosa::ArgMaxOp::verify() {
1050 const ShapedType resultType = llvm::cast<ShapedType>(
getType());
1053 if (
const auto resultETy = resultType.getElementType();
1054 !resultETy.isIntOrIndex())
1055 return emitOpError(
"result tensor is not of integer type");
1057 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
1058 if (!inputType.hasRank())
1062 const int64_t axis = getAxisAttr().getInt();
1063 if (((axis < 0) || axis >= inputType.getRank()))
1064 return emitOpError(
"specified axis is outside the rank of the tensor");
1066 if (!resultType.hasRank())
1072 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1075 << expectedOutputShape <<
"', got '" << outputShape <<
"'";
1080template <
typename T>
1083 if (llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1084 return op.emitOpError(
"expect all kernel values to be >= 1, got ")
1088 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1089 return op.emitOpError(
"expect all stride values to be >= 1, got ")
1093 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1094 return op.emitOpError(
"expect all padding values to be >= 0, got ")
1098 const int64_t kernelX = kernel[1];
1099 const int64_t padLeft = padding[2];
1100 const int64_t padRight = padding[3];
1101 if (padRight >= kernelX || padLeft >= kernelX)
1102 return op.emitOpError(
"expected left/right padding to be less than the "
1103 "width of the kernel, got pad_left=")
1104 << padLeft <<
", pad_right=" << padRight <<
", kernel_x=" << kernelX;
1106 const int64_t kernelY = kernel[0];
1107 const int64_t padTop = padding[0];
1108 const int64_t padBottom = padding[1];
1109 if (padTop >= kernelY || padBottom >= kernelY)
1110 return op.emitOpError(
"expected top/bottom padding to be less than the "
1111 "height of the kernel, got pad_top=")
1112 << padTop <<
", pad_bottom=" << padBottom
1113 <<
", kernel_y=" << kernelY;
1115 const auto inputType =
1116 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1117 const auto outputType =
1118 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1119 if (!inputType || !outputType)
1122 const auto verifyOutputSize =
1126 const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
1127 const llvm::StringRef padBeforeName,
1128 const llvm::StringRef padAfterName) -> LogicalResult {
1129 if (ShapedType::isDynamic(inputSize))
1132 const std::optional<int64_t> calculatedOutSizeMinusOne =
1133 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1134 if (!calculatedOutSizeMinusOne.has_value())
1135 return op.emitOpError(
"expected input_")
1136 << dimName <<
" + pad_" << padBeforeName <<
" + pad_"
1137 << padAfterName <<
" - kernel_" << dimAxis
1138 <<
" to be wholly divisible by stride_" << dimAxis <<
", got ("
1139 << inputSize <<
" + " << padBefore <<
" + " << padAfter <<
" - "
1140 << kernelSize <<
") / " << strideSize;
1142 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1143 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1144 return op.emitOpError(
"calculated output ")
1145 << dimName <<
" did not match expected: " <<
"calculated="
1146 << calculatedOutSize <<
", expected=" << outputSize;
1151 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1152 kernel[0], strides[0], padding[0], padding[1],
1153 "height",
"y",
"top",
"bottom")))
1156 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1157 kernel[1], strides[1], padding[2], padding[3],
1158 "width",
"x",
"left",
"right")))
1164LogicalResult tosa::AvgPool2dOp::verify() {
1173 auto accType = getAccType();
1174 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1175 return emitOpError(
"accumulator type for integer tensor is not i32");
1177 if (inputETy.
isF16() && !(accType.isF16() || accType.isF32()))
1178 return emitOpError(
"accumulator type for f16 tensor is not f16/f32");
1180 if (inputETy.
isBF16() && !accType.isF32())
1181 return emitOpError(
"accumulator type for bf16 tensor is not f32");
1183 if (inputETy.
isF32() && !accType.isF32())
1184 return emitOpError(
"accumulator type for f32 tensor is not f32");
1186 if (inputETy != inputZpETy)
1187 return emitOpError(
"expect both input and its zero point are the same "
1188 "element type, got ")
1189 << inputETy <<
" and " << inputZpETy;
1191 if (resultETy != outputZpETy)
1192 return emitOpError(
"expect both output and its zero point are the same "
1193 "element type, got ")
1194 << resultETy <<
" and " << outputZpETy;
1196 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1197 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1200 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1201 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1207LogicalResult tosa::ClampOp::verify() {
1209 llvm::cast<ShapedType>(getInput().
getType()).getElementType();
1210 if (
auto quantType =
1211 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1215 llvm::cast<ShapedType>(getOutput().
getType()).getElementType();
1216 if (
auto quantType =
1217 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1220 if (inputETy != outputETy)
1221 return emitOpError(
"input/output element types are incompatible.");
1223 auto maxValAttr = getMaxValAttr();
1224 auto minValAttr = getMinValAttr();
1228 if (inputETy.
isInteger(dataTypeBitWidth)) {
1232 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1233 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1234 if (!intMaxValAttr || !intMinValAttr ||
1235 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1236 (intMaxValAttr.getType() != inputETy))
1237 return emitOpError(
"min/max attributes types are incompatible with "
1238 "input/output element types.");
1241 const bool isBoolean = inputETy.
isInteger(1);
1242 const APInt minVal = intMinValAttr.getValue();
1243 const APInt maxVal = intMaxValAttr.getValue();
1244 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1245 return emitOpError(
"expected min_val <= max_val, got min_val=")
1246 << minValAttr <<
", max_val=" << maxValAttr;
1251 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1252 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1253 if (!floatMaxValAttr || !floatMinValAttr ||
1254 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1255 (floatMaxValAttr.getType() != inputETy))
1256 return emitOpError(
"min/max attributes types are incompatible with "
1257 "input/output element types.");
1259 const APFloat minVal = floatMinValAttr.getValue();
1260 const APFloat maxVal = floatMaxValAttr.getValue();
1261 if (minVal.isNaN() || maxVal.isNaN())
1262 return emitOpError(
"min/max attributes should not be 'NaN', got min_val=")
1263 << minValAttr <<
", max_val=" << maxValAttr;
1265 if (maxVal < minVal)
1266 return emitOpError(
"expected min_val <= max_val, got min_val=")
1267 << minValAttr <<
", max_val=" << maxValAttr;
1287 result.addOperands({input, weight, bias, zps.first, zps.second});
1288 result.addAttribute(
"pad", pad);
1289 result.addAttribute(
"stride", stride);
1290 result.addAttribute(
"dilation", dilation);
1291 result.addAttribute(
"acc_type", accType);
1292 Type finalOutputType = outputType;
1298 result.addTypes(finalOutputType);
1309 result.addOperands({input, weight, bias, zps.first, zps.second});
1310 result.addAttribute(
"out_pad", outpad);
1311 result.addAttribute(
"stride", stride);
1312 result.addAttribute(
"acc_type", accType);
1313 Type finalOutputType = outputType;
1319 result.addTypes(finalOutputType);
1330 result.addOperands({a,
b, zps.first, zps.second});
1332 Type finalOutputType{outputType};
1335 auto inputBits = eType.getIntOrFloatBitWidth();
1337 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1338 assert(outputShapedType &&
"Output must be a shaped type");
1340 IntegerType accElementType;
1341 if (inputBits == 16)
1346 finalOutputType = outputShapedType.clone(accElementType);
1348 result.addTypes(finalOutputType);
1357 DenseArrayAttr kernel, DenseArrayAttr stride,
1358 DenseArrayAttr pad, TypeAttr accType) {
1363 if (
auto quantAttr =
1365 inputZp = quantAttr.getInputZp();
1366 outputZp = quantAttr.getOutputZp();
1368 const std::optional<Value> inputZpOp =
1373 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1375 const std::optional<Value> outputZpOp =
1378 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1379 "quantized AVG_POOL2D op");
1382 if (inputZpOp && outputZpOp) {
1383 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1388 result.addOperands({input});
1390 result.addAttribute(
"kernel", kernel);
1391 result.addAttribute(
"stride", stride);
1392 result.addAttribute(
"pad", pad);
1393 result.addAttribute(
"acc_type", accType);
1394 result.types.push_back(outputType);
1408 input1Zp = quantAttr.getInputZp();
1409 outputZp = quantAttr.getOutputZp();
1411 const std::optional<Value> input1ZpOp =
1415 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1418 const std::optional<Value> outputZpOp =
1422 loc,
"Failed to create output zero point for quantized NEGATE op");
1425 if (input1ZpOp && outputZpOp) {
1426 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1431 result.addOperands({input});
1434 result.types.push_back(outputType);
1447 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1450 result.addOperands({input, paddings, padConstOp});
1451 result.types.push_back(outputType);
1455 StringRef name,
Type variableType,
1460 auto shapedType = dyn_cast<ShapedType>(variableType);
1462 (
void)
emitError(loc,
"variable type must be a shaped type");
1465 if (!shapedType.hasRank()) {
1466 (
void)
emitError(loc,
"variable type must be a ranked type");
1470 auto elementType = shapedType.getElementType();
1471 auto elementTypeAttr = TypeAttr::get(elementType);
1475 result.addAttribute(
"sym_name", nameAttr);
1476 result.addAttribute(
"var_shape", varShapeAttr);
1477 result.addAttribute(
"type", elementTypeAttr);
1478 result.addAttribute(
"initial_value", initialValue);
1491 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1495 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1501 for (
int i = 0, e = operands.size(); i != e; ++i) {
1503 if (!
shape.hasRank()) {
1508 outRank = std::max<int64_t>(outRank,
shape.getRank());
1511 outShape.resize(outRank, 1);
1513 for (
int i = 0, e = operands.size(); i != e; ++i) {
1515 auto rankDiff = outShape.size() -
shape.getRank();
1517 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1518 auto dim1 = outShape[i + rankDiff];
1519 auto dim2 =
shape.getDimSize(i);
1521 const FailureOr<int64_t> maybeResolvedDim =
1523 if (failed(maybeResolvedDim))
1525 const int64_t resolvedDim = *maybeResolvedDim;
1526 outShape[i + rankDiff] = resolvedDim;
1533LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1534 MLIRContext *context, ::std::optional<Location> location,
1535 ArgMaxOp::Adaptor adaptor,
1538 IntegerAttr axis = adaptor.getProperties().axis;
1539 int32_t axisVal = axis.getValue().getSExtValue();
1541 if (!inputShape.hasRank()) {
1547 outShape.reserve(inputShape.getRank() - 1);
1548 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
1551 outShape.push_back(inputShape.getDimSize(i));
1558LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1559 MLIRContext *context, ::std::optional<Location> location,
1560 RFFT2dOp::Adaptor adaptor,
1562 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1564 if (!inputShape.hasRank())
1568 outputShape.resize(3, ShapedType::kDynamic);
1569 outputShape[0] = inputShape.getDimSize(0);
1570 outputShape[1] = inputShape.getDimSize(1);
1571 int64_t inWidth = inputShape.getDimSize(2);
1575 if (inWidth != ShapedType::kDynamic)
1576 outputShape[2] = inWidth / 2 + 1;
1585 const llvm::StringRef dimName) {
1586 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1589 << dimName <<
" to be a power of two, got " << dimSize;
1594LogicalResult tosa::RFFT2dOp::verify() {
1595 const auto outputTypes = getResultTypes();
1597 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
1599 const auto inputType =
1600 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1604 const int64_t height = inputType.getDimSize(1);
1605 if (ShapedType::isStatic(height) &&
1609 const int64_t width = inputType.getDimSize(2);
1610 if (ShapedType::isStatic(width) &&
1614 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1620 outputType.getShape().drop_back())))
1621 return emitOpError(
"expected batch and height dimensions of input/output "
1622 "to match, got input=")
1623 << inputType <<
" output=" << outputType;
1626 const int64_t outputWidth = outputType.getDimSize(2);
1627 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1628 (outputWidth != (width / 2) + 1))
1630 "expected output width to be equal to input_width / 2 + 1, got ")
1636LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1637 MLIRContext *context, ::std::optional<Location> location,
1638 FFT2dOp::Adaptor adaptor,
1640 inferredReturnShapes.push_back(
1642 inferredReturnShapes.push_back(
1647LogicalResult tosa::FFT2dOp::verify() {
1648 const auto inputRealType =
1649 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1650 const auto inputImagType =
1651 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1652 if (!inputRealType || !inputImagType)
1655 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1656 return ShapedType::isDynamic(a) ? a :
b;
1659 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1660 inputImagType.getDimSize(1));
1661 if (ShapedType::isStatic(height) &&
1665 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1666 inputImagType.getDimSize(2));
1667 if (ShapedType::isStatic(width) &&
1674LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1675 MLIRContext *context, ::std::optional<Location> location,
1676 ConcatOp::Adaptor adaptor,
1679 const Properties &prop = adaptor.getProperties();
1680 int32_t axis = prop.axis.getValue().getSExtValue();
1682 bool hasRankedInput =
false;
1683 for (
auto operand : adaptor.getOperands()) {
1685 if (!operandShape.hasRank())
1689 if (!hasRankedInput)
1690 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1693 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
1694 if (i == axis || operandShape.isDynamicDim(i))
1696 if (outputShape[i] == ShapedType::kDynamic)
1697 outputShape[i] = operandShape.getDimSize(i);
1698 if (outputShape[i] != operandShape.getDimSize(i))
1700 "Cannot concat tensors with different sizes"
1701 " on the non-axis dimension ",
1705 hasRankedInput =
true;
1708 if (adaptor.getInput1().empty())
1712 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1713 if (!hasRankedInput) {
1720 for (
auto operand : adaptor.getOperands()) {
1725 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1726 concatDimSize = ShapedType::kDynamic;
1730 concatDimSize += operandShape.getDimSize(axis);
1733 outputShape[axis] = concatDimSize;
1739LogicalResult tosa::ConcatOp::verify() {
1741 auto outType = getOutput().getType();
1745 if (inputList.empty())
1748 if (!llvm::all_of(inputList, [&](
auto input) {
1750 *
this, input.getType(), outType));
1755 const int32_t axis = getAxis();
1757 for (
const auto &input : inputList) {
1758 const Type inputType = input.getType();
1760 if (currShape.hasRank()) {
1761 firstRankedInputShape = currShape;
1763 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1764 return emitOpError(
"expect axis to be within range 0 < axis < "
1765 "rank(input1[firstRankedTensorIdx]), got ")
1771 const auto allOperandsHasRank = [](
const Value input) {
1774 if (llvm::all_of(inputList, allOperandsHasRank)) {
1777 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1779 const int64_t inputRank = inputShape.getRank();
1780 const size_t operandNum =
index + 1;
1783 if (inputRank != firstInputRank)
1785 "expect all operands to have the same rank, but got ")
1786 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1790 for (
int i = 0; i < inputRank; i++) {
1791 const int64_t inputDim = inputShape.getDimSize(i);
1793 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1794 inputShape.isDynamicDim(i))
1796 if (inputDim != firstInputDim)
1797 return emitOpError(
"expect all operand shapes to have the same sizes "
1798 "on non-axis dimensions, but got ")
1799 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1800 <<
" on operands 0 and " << operandNum;
1805 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1806 return emitOpError(
"expect output rank to match inputs rank, got ")
1807 << outputShape.getRank() <<
" vs " << firstInputRank;
1811 for (
const auto &input : inputList) {
1813 if (inputShape.isDynamicDim(axis)) {
1818 axisSum += inputShape.getDimSize(axis);
1821 if (axisSum >= 0 && outputShape.hasRank() &&
1822 !outputShape.isDynamicDim(axis) &&
1823 axisSum != outputShape.getDimSize(axis))
1824 return emitOpError(
"requires sum of axis dimensions of input1 "
1825 "equal to output axis dimension, got ")
1826 << axisSum <<
" and " << outputShape.getDimSize(axis);
// NOTE(review): extraction dropped most of tosa::EqualOp's
// inferReturnTypeComponents (the parameter list and body after the i1
// element-type creation) and of its isCompatibleReturnTypes (only the
// size-equality guard survives). Too little remains to reconstruct safely;
// restore from upstream TosaOps.cpp.
1832LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1833 MLIRContext *context, ::std::optional<Location> location,
// Result element type of tosa.equal is i1 (boolean).
1837 auto elementType = IntegerType::get(context, 1);
// Fragment of isCompatibleReturnTypes: both type ranges must have exactly one
// entry.
1850 if (l.size() != r.size() || l.size() != 1)
1855LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1856 MLIRContext *context, ::std::optional<Location> location,
1857 MatMulOp::Adaptor adaptor,
1864 outShape.resize(3, ShapedType::kDynamic);
1866 if (lhsShape.hasRank()) {
1867 outShape[0] = lhsShape.getDimSize(0);
1868 outShape[1] = lhsShape.getDimSize(1);
1871 if (rhsShape.hasRank()) {
1872 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1874 outShape[2] = rhsShape.getDimSize(2);
1881LogicalResult MatMulOp::verify() {
1882 auto aType = llvm::dyn_cast<ShapedType>(getA().
getType());
1883 auto bType = llvm::dyn_cast<ShapedType>(getB().
getType());
1887 return emitOpError(
"expect a shaped tensor for input a, got ")
1888 << getA().getType();
1891 return emitOpError(
"expect a shaped tensor for input b, got ")
1892 << getB().getType();
1894 auto aElementType = aType.getElementType();
1895 auto bElementType = bType.getElementType();
1897 auto aQuantizedEType =
1898 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1899 auto bQuantizedEType =
1900 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1902 if (aQuantizedEType || bQuantizedEType) {
1903 if (!aQuantizedEType || !bQuantizedEType) {
1904 return emitOpError(
"expect operands to be both quantized or both not "
1906 << aElementType <<
" and " << bElementType;
1909 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1910 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1911 if (aQuantWidth != bQuantWidth) {
1912 return emitOpError(
"expect quantized operands to have same widths, got ")
1913 << aQuantWidth <<
" and " << bQuantWidth;
1920 if (aEType != aZpEType) {
1921 return emitOpError(
"expect input a and a_zp have the same "
1922 "element type, got ")
1923 << aEType <<
" and " << aZpEType;
1928 if (bEType != bZpEType) {
1929 return emitOpError(
"expect input b and b_zp have the same "
1930 "element type, got ")
1931 << bEType <<
" and " << bZpEType;
1934 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1935 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1938 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1939 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1945LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1946 MLIRContext *context, ::std::optional<Location> location,
1947 MatmulTBlockScaledOp::Adaptor adaptor,
1951 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1952 if (aDataShape.hasRank()) {
1953 outShape[0] = aDataShape.getDimSize(0);
1954 outShape[1] = aDataShape.getDimSize(1);
1957 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1958 if (aScaleShape.hasRank()) {
1959 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1961 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1966 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1967 if (bDataShape.hasRank()) {
1968 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1969 if (bDataBatchSize != 1)
1971 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1972 outShape[2] = bDataShape.getDimSize(1);
1975 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1976 if (bScaleShape.hasRank()) {
1977 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1978 if (bScaleBatchSize != 1)
1980 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1981 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
// Verifies the block-scaled matmul-transposed op: the surviving fragments
// check broadcast compatibility of the B batch (D) against N, that C is a
// multiple of the block size, that the scale operands' dimension 2 equals
// C/block_size, and that the output shape is compatible with [N, H, W].
// NOTE(review): extraction dropped the dimension-collection helper calls that
// bind N/D/H/C/multiplesOfC (only their string arguments such as "a_scale" /
// "batch" survive) and the expectedOutputShape computation; restore those
// lines from upstream TosaOps.cpp before relying on this verifier.
1989LogicalResult MatmulTBlockScaledOp::verify() {
1991 const Type aDataType = getAData().getType();
1992 const Type bDataType = getBData().getType();
// Dimension accumulators: kDynamic means "not yet known".
1998 int64_t N = ShapedType::kDynamic;
1999 int64_t D = ShapedType::kDynamic;
2000 int64_t H = ShapedType::kDynamic;
2003 int64_t multiplesOfC = ShapedType::kDynamic;
2015 "a_scale",
"batch")) ||
2017 "a_scale",
"height")))
2025 "b_data",
"batch")) ||
2027 "b_data",
"channels")))
2035 "b_scale",
"batch")) ||
2037 "b_scale",
"width")) ||
// B's batch D must equal N or be the broadcastable 1 when both are static.
2045 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2046 return emitOpError(
"expect B matrix batch size to be broadcast compatible "
2048 << D <<
" vs N=" << N;
// C must divide evenly into blocks.
2051 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2052 if (ShapedType::isStatic(C) && C % blockSize != 0)
2053 return emitOpError(
"expect C to be a multiple of block size, got C=")
2054 <<
C <<
", block_size=" << blockSize;
// Scale operands carry one value per block: dim 2 == C / block_size.
2057 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2058 multiplesOfC != C / blockSize)
2060 "expect scale operands dimension 2 to equal C/block_size (")
2061 <<
C <<
"/" << blockSize <<
")" <<
", got " << multiplesOfC;
// Fold D into N (D fills in an unknown N) before checking the output shape.
2064 N = ShapedType::isDynamic(N) ? D : N;
2066 const auto outputType = cast<ShapedType>(getResult().
getType());
2067 if (outputType.hasRank() &&
// Pretty-print dynamic dims as a placeholder in the diagnostic.
2071 auto stringifyDim = [&](
int64_t d) {
2072 if (ShapedType::isDynamic(d))
2077 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2078 opError <<
" to be compatible with expected output shape ";
2079 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2086LogicalResult tosa::PadOp::inferReturnTypeComponents(
2087 MLIRContext *context, ::std::optional<Location> location,
2088 PadOp::Adaptor adaptor,
2090 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2092 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2097 if (!inputShape.hasRank()) {
2098 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2107 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2112 outputShape.reserve(inputShape.getRank());
2113 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2114 if (inputShape.isDynamicDim(i)) {
2115 outputShape.push_back(ShapedType::kDynamic);
2118 auto padFront = paddingValues[i * 2];
2119 auto padBack = paddingValues[i * 2 + 1];
2120 if (padFront < 0 || padBack < 0) {
2122 outputShape.push_back(ShapedType::kDynamic);
2126 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
// Verifies pad: element types (including the optional pad_const) match, the
// padding tensor holds 2 * rank(input1) values, each value is non-negative or
// the dynamic marker -1, and static output dims equal input + padStart +
// padEnd.
// NOTE(review): extraction dropped several guard blocks here (the element-type
// checks at the top, the early return after the !inputType||!outputType guard,
// and the fetch of `paddingAttr` from the padding operand); restore them from
// upstream TosaOps.cpp before relying on this verifier.
2133LogicalResult tosa::PadOp::verify() {
2140 if (
auto padConst = getPadConst()) {
2148 RankedTensorType inputType =
2149 llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2150 RankedTensorType outputType =
2151 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
2152 if (!inputType || !outputType)
2159 auto inputRank = inputType.getRank();
// The padding tensor must contribute a (front, back) pair per dimension.
2164 auto paddingValues = paddingAttr.getValues<APInt>();
2165 if (paddingValues.size() !=
static_cast<size_t>(inputRank * 2))
2166 return emitOpError() <<
"padding tensor must have " << inputRank
2167 <<
" * 2 = " << inputRank * 2 <<
" elements, but got "
2168 << paddingValues.size();
2170 auto inputShape = inputType.getShape();
2171 auto outputShape = outputType.getShape();
// Per-dimension checks: value validity, then static-size consistency.
2173 for (
int64_t i = 0; i < inputRank; ++i) {
2174 int64_t padStart = paddingValues[i * 2].getSExtValue();
2175 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2177 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2179 <<
"invalid padding values at dimension " << i
2180 <<
": values must be non-negative or -1 for dynamic padding, got ["
2181 << padStart <<
", " << padEnd <<
"]";
// Dynamic dims on either side cannot be checked.
2185 if (inputShape[i] == ShapedType::kDynamic ||
2186 outputShape[i] == ShapedType::kDynamic)
2189 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2190 return emitOpError() <<
"mismatch in output shape at dimension " << i
2191 <<
": expected " << inputShape[i] <<
" + "
2192 << padStart <<
" + " << padEnd <<
" = "
2193 << (inputShape[i] + padStart + padEnd)
2194 <<
", but got " << outputShape[i];
// Infers the slice result shape from constant start/size values: size -1
// means "to the end of the dimension"; dims stay dynamic when the input dim
// is dynamic (with non-positive size) or the start/size pair is out of range.
// NOTE(review): extraction dropped the lines that fetch the constant `start`
// and `size` vectors and initialize `outputShape` (only the dynamic-rank
// fallback using the size shape's rank survives); restore them from upstream
// TosaOps.cpp.
2201LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2202 MLIRContext *context, ::std::optional<Location> location,
2203 SliceOp::Adaptor adaptor,
2212 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2220 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2223 if (inputShape.hasRank()) {
2224 for (
size_t i = 0; i < size.size(); i++) {
// Only dims with a usable (size, start) pair are inferred; others stay
// dynamic.
2225 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2226 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2227 start[i] < inputShape.getDimSize(i))) {
2229 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2232 outputShape[i] = size[i];
// size == -1 slices to the end of the (static) input dimension.
2236 if (size[i] == -1) {
2237 outputShape[i] = inputShape.getDimSize(i) - start[i];
2238 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2240 outputShape[i] = size[i];
// Verifies slice: ranks of input1/output/start/size agree, constant start
// values are non-negative, constant size values are > 0 (or -1), a fully
// static output matches the size values, and start + size stays within each
// static input dimension.
// NOTE(review): extraction dropped the element-type check, the declarations
// of inputShape/outputShape, the constant-value fetches for startValues /
// sizeValues, and several predicate bodies; restore them from upstream
// TosaOps.cpp before relying on this verifier.
2252LogicalResult tosa::SliceOp::verify() {
2253 const Value input = getInput1();
2254 const Value output = getOutput();
2260 const Value start = getStart();
2261 const Value size = getSize();
// Rank consistency checks, only possible when input1 is ranked.
2265 if (inputShape.hasRank()) {
2266 const auto inputRank = inputShape.getRank();
2267 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2269 "expect input1 and output to have the same ranks, got ")
2270 << inputRank <<
" and " << outputShape.getRank();
2272 const auto startShapeRank =
2273 llvm::cast<tosa::shapeType>(start.
getType()).getRank();
2274 if (inputRank != startShapeRank)
2275 return emitOpError(
"length of start is not equal to rank of input shape");
2277 const auto sizeShapeRank =
2278 llvm::cast<tosa::shapeType>(size.
getType()).getRank();
2279 if (inputRank != sizeShapeRank)
2280 return emitOpError(
"length of size is not equal to rank of input shape");
// Value-range checks on the constant start/size vectors.
2285 if (startValues.size()) {
2286 if (llvm::any_of(startValues, [](
const int64_t v) {
2289 return emitOpError(
"start values must be non-negative, got [")
2290 << startValues <<
"]";
2297 if (llvm::any_of(sizeValues, [](
const int64_t v) {
2300 return emitOpError(
"size values must be > 0, got [") << sizeValues <<
"]";
2301 if (outputShape.hasRank()) {
2303 outputShape.getDims(outputDims);
2304 const bool hasNoInferableDims = llvm::all_of(
2306 if (hasNoInferableDims &&
2308 return emitOpError(
"expected output shape to match size values, got ")
2309 << output.
getType() <<
" vs [" << sizeValues <<
"]";
// Bounds check: each slice must fit inside its static input dimension.
2312 if (inputShape.hasRank() && startValues.size()) {
2314 inputShape.getDims(inputDims);
2315 for (
const auto &[
index, vals] :
2316 llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
2317 const auto &[start, size, inputDim] = vals;
2319 ShapedType::isDynamic(inputDim))
2321 if (start + size > inputDim)
2322 return emitOpError(
"start + size must be less than or equal to input "
2323 "dimension size, got start=")
2324 << start <<
", size=" << size
2325 <<
" vs input dim size=" << inputDim <<
" at dimension "
// NOTE(review): only the first two signature lines of
// tosa::MulOp::inferReturnTypeComponents survive; the parameter tail and the
// whole body were dropped by extraction. Restore from upstream TosaOps.cpp.
2333LogicalResult tosa::MulOp::inferReturnTypeComponents(
2334 MLIRContext *context, ::std::optional<Location> location,
2349LogicalResult tosa::MulOp::verify() {
2350 const Value output = getOutput();
2355 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2356 IntegerType lhsIntType =
2358 IntegerType rhsIntType =
2360 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2361 return emitOpError(
"requires the same element type for all operands");
2366 if (lhsIntType.getWidth() > resIntType.getWidth())
2367 return emitOpError(
"invalid data type size for operands or result");
2372 for (
int i = 0; i < 2; ++i) {
2375 "requires the same element type for all operands and results");
2379 ElementsAttr shiftElem;
2381 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2383 return emitOpError() <<
"require shift to be 0 for float type";
2391 TypeRange operandTypes = getOperandTypes();
2392 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2393 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2395 const bool aHasRank = aType.hasRank();
2396 const bool bHasRank = bType.hasRank();
2397 if (aHasRank && bHasRank) {
2398 const int64_t aRank = aType.getRank();
2399 const int64_t bRank = bType.getRank();
2401 return emitOpError(
"a and b operands don't have matching ranks, got ")
2402 << aRank <<
" and " << bRank;
2407 aType.getShape(), bType.getShape(), resultShape))
2408 return emitOpError(
"a and b operands don't have broadcast-compatible "
2410 << aType <<
" and " << bType;
2413 ShapedType resultType = cast<ShapedType>(output.
getType());
2414 if (!resultType.hasRank())
2417 const int64_t resultRank = resultType.getRank();
2418 if (aHasRank && resultRank != aType.getRank())
2419 return emitOpError(
"result type has different rank than a, got ")
2420 << resultRank <<
" vs " << aType.getRank();
2421 if (bHasRank && resultRank != bType.getRank())
2422 return emitOpError(
"result type has different rank than b, got ")
2423 << resultRank <<
" vs " << bType.getRank();
2428LogicalResult tosa::TableOp::inferReturnTypeComponents(
2429 MLIRContext *context, ::std::optional<Location> location,
2430 TableOp::Adaptor adaptor,
2432 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2434 if (!inputShape.hasRank()) {
2439 inferredReturnShapes.resize(1);
2440 inputShape.getDims(inferredReturnShapes[0]);
2444LogicalResult tosa::TableOp::verify() {
2445 const TensorType inputType = getInput1().getType();
2446 const TensorType outputType = getOutput().getType();
2455 auto inputDims = inputType.
getShape();
2456 auto outputDims = outputType.
getShape();
2457 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2459 auto [inputDim, outputDim] = it.value();
2460 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2461 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2462 <<
" doesn't match dim(input, " << dim
2463 <<
") = " << inputDim;
2476 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2477 [](
const APInt &val) { return val.getSExtValue(); });
2481LogicalResult tosa::TileOp::inferReturnTypeComponents(
2482 MLIRContext *context, ::std::optional<Location> location,
2483 TileOp::Adaptor adaptor,
2490 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2497 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2499 if (!inputShape.hasRank()) {
2500 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2501 inferredReturnShapes.push_back(
2505 if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2509 outputShape.reserve(multiples.size());
2510 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2511 if (multiples[i] == ShapedType::kDynamic) {
2512 outputShape.push_back(ShapedType::kDynamic);
2514 int64_t dim = inputShape.getDimSize(i);
2515 if (dim != ShapedType::kDynamic)
2516 dim *= multiples[i];
2517 outputShape.push_back(dim);
2525LogicalResult tosa::TileOp::verify() {
2531 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2532 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2534 shapeType multiplesType =
2535 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2537 auto multiplesRank = multiplesType.getRank();
2539 if (inputType.hasRank()) {
2540 if (inputType.getRank() != multiplesRank)
2541 return emitOpError(
"expect 'multiples' to have rank ")
2542 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2543 if (outputType.hasRank() &&
2547 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2548 return emitOpError(
"expect 'multiples' array to have length ")
2549 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
2552 if (getConstantMultiples(multiples).succeeded() &&
2553 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2555 "expect element of 'multiples' to be positive integer or -1.");
2561 if (l.size() != r.size() || l.size() != 1)
2566LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2567 MLIRContext *context, ::std::optional<Location> location,
2568 ReshapeOp::Adaptor adaptor,
2570 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2575 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2584 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2585 inferredReturnShapes.push_back(
2593 int64_t numElements = inputShape.getNumElements();
2595 for (
auto val : newShapeValue) {
2596 if (ShapedType::isStatic(val)) {
2602 for (
auto &val : newShapeValue) {
2603 if (ShapedType::isDynamic(val))
2604 val = numElements / staticMul;
2607 inferredReturnShapes.push_back(
2612llvm::LogicalResult tosa::ReshapeOp::verify() {
2618 TensorType inputType = getInput1().getType();
2623 return mlir::success();
2627 if (missingDims > 1)
2628 return emitOpError() <<
"expected at most one target dimension to be "
2631 const auto outputType = dyn_cast<RankedTensorType>(
getType());
2635 if ((
int64_t)shapeValues.size() != outputType.getRank())
2636 return emitOpError() <<
"new shape does not match result rank";
2638 for (
auto [newShapeDim, outputShapeDim] :
2639 zip(shapeValues, outputType.getShape())) {
2641 newShapeDim != ShapedType::kDynamic &&
2642 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2643 return emitOpError() <<
"new shape is inconsistent with result shape";
2646 return emitOpError() <<
"new shape has invalid tensor dimension size "
2650 if (inputType.hasStaticShape()) {
2651 int64_t inputElementsNum = inputType.getNumElements();
2652 if (outputType.hasStaticShape()) {
2653 int64_t outputElementsNum = outputType.getNumElements();
2654 if (inputElementsNum != outputElementsNum) {
2655 return emitOpError() <<
"cannot reshape " << inputElementsNum
2656 <<
" elements into " << outputElementsNum;
2662 return (dim > 0) ?
acc * dim :
acc;
2664 bool isStaticNewShape =
2665 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
2666 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2667 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2668 return emitOpError() <<
"cannot reshape " << inputElementsNum
2669 <<
" elements into " << newShapeElementsNum;
2673 return mlir::success();
2680 ElementsAttr zpAttr;
2685 Type zpElemType = zpAttr.getElementType();
2687 if (llvm::isa<FloatType>(zpElemType)) {
2688 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2695 if (llvm::isa<IntegerType>(zpElemType)) {
2697 return zpAttr.getValues<APInt>()[0].getSExtValue();
2698 return zpAttr.getValues<APInt>()[0].getZExtValue();
2705template <
typename T>
2707 const std::string &operand) {
2710 if (!zpElemType.
isInteger(8) && zp != 0) {
2712 std::string lower = operand;
2713 llvm::transform(lower, lower.begin(), ::tolower);
2714 return op.emitOpError()
2715 << lower <<
" zero point must be zero for non-int8 integer types";
2723 const std::string &operand) {
2724 bool isInputZp = (operand ==
"Input");
2726 bool tensorUnsigned =
2727 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2728 StringRef tensorName = isInputZp ?
"input" :
"output";
2734 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
2735 return op.emitOpError()
2736 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
2738 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
2739 return op.emitOpError() <<
"expect " << tensorName
2740 <<
"_zp of 0 or 32768 for unsigned int16 "
2741 << tensorName <<
", got " << zp;
// Generates get<Operand>ZeroPoint / verify<Operand>ZeroPoint member function
// pairs that delegate to the getZeroPoint / verifyZeroPoint helpers above.
// NOTE(review): the macro's closing line and the per-op instantiation list
// (between the definition and the #undef) were dropped by extraction; restore
// them from upstream TosaOps.cpp.
2748#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2749 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2750 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2752 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2753 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2772#undef ZERO_POINT_HELPER
2774LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2775 MLIRContext *context, ::std::optional<Location> location,
2776 TransposeOp::Adaptor adaptor,
2778 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2787 const auto inputRank = inputShape.
getRank();
2791 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
2797 if (inputRank == 0) {
2803 bool allTheSame =
true;
2804 for (
int i = 1, s = inputRank; i < s; i++) {
2814 outputShape.resize(inputRank, inputShape.
getDimSize(0));
2819 outputShape.resize(inputRank, ShapedType::kDynamic);
2822 if (llvm::any_of(adaptor.getPerms(),
2823 [inputRank](
const auto i) { return i >= inputRank; }))
2826 outputShape.reserve(inputRank);
2827 for (
int i = 0, s = inputRank; i < s; i++) {
2828 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
2835LogicalResult tosa::TransposeOp::verify() {
2847 if (inputShape.hasRank() &&
2848 constantPerms.size() !=
static_cast<size_t>(inputShape.getRank()))
2849 return emitOpError() <<
"expected perms attribute to have size "
2850 << inputShape.getRank()
2851 <<
" (input rank) but got size "
2852 << constantPerms.size();
2854 if (inputShape.hasRank() && outputShape.hasRank() &&
2855 inputShape.getRank() != outputShape.getRank())
2857 <<
"expected input tensor rank to equal result tensor rank";
2859 if (outputShape.hasRank() &&
2860 constantPerms.size() !=
static_cast<size_t>(outputShape.getRank()))
2861 return emitOpError() <<
"expected perms attribute to have size "
2862 << outputShape.getRank()
2863 <<
" (output rank) but got size "
2864 << constantPerms.size();
2866 if (!llvm::all_of(constantPerms,
2867 [&constantPerms](int32_t s) {
2869 static_cast<size_t>(s) < constantPerms.size();
2872 constantPerms, [](int32_t v) ->
int64_t {
return v; })))
2873 return emitOpError() <<
"expected valid permutation indices";
2876 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2877 inputShape.getNumElements() != outputShape.getNumElements())
2878 return emitOpError() <<
"expected input1 and output to have same numbers "
2880 << inputShape.getNumElements() <<
" and "
2881 << outputShape.getNumElements();
2885 if (inputShape.hasRank() && outputShape.hasRank()) {
2886 for (
auto i = 0; i < outputShape.getRank(); i++) {
2887 if (inputShape.isDynamicDim(constantPerms[i]) ||
2888 outputShape.isDynamicDim(i))
2891 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2893 <<
"expected output tensor dim " << i <<
" to match "
2894 <<
"input dim " << constantPerms[i] <<
" with value of "
2895 << inputShape.getDimSize(constantPerms[i]);
2902LogicalResult TransposeOp::reifyResultShapes(
2905 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2907 Value input = getInput1();
2908 auto inputType = cast<TensorType>(input.
getType());
2910 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2911 for (
auto dim : transposePerms) {
2912 int32_t dimInInput = transposePerms[dim];
2913 if (inputType.isDynamicDim(dimInInput))
2915 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2919 builder.
getIndexAttr(inputType.getDimSize(dimInInput));
2922 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2926LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2927 MLIRContext *context, ::std::optional<Location> location,
2928 GatherOp::Adaptor adaptor,
2929 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2930 llvm::SmallVector<int64_t> outputShape;
2931 outputShape.resize(3, ShapedType::kDynamic);
2933 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2934 if (valuesShape.hasRank()) {
2935 outputShape[0] = valuesShape.getDimSize(0);
2936 outputShape[2] = valuesShape.getDimSize(2);
2939 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2940 if (indicesShape.hasRank()) {
2941 if (outputShape[0] == ShapedType::kDynamic)
2942 outputShape[0] = indicesShape.getDimSize(0);
2943 if (outputShape[1] == ShapedType::kDynamic)
2944 outputShape[1] = indicesShape.getDimSize(1);
2947 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2951LogicalResult tosa::GatherOp::verify() {
2958 const ShapeAdaptor valuesShape(getValues().
getType());
2960 const ShapeAdaptor outputShape(getOutput().
getType());
2962 int64_t n = ShapedType::kDynamic;
2963 int64_t w = ShapedType::kDynamic;
2964 int64_t c = ShapedType::kDynamic;
2966 if (valuesShape.hasRank()) {
2967 n = valuesShape.getDimSize(0);
2968 c = valuesShape.getDimSize(2);
2970 if (indicesShape.hasRank()) {
2971 const int64_t indicesN = indicesShape.getDimSize(0);
2972 w = indicesShape.getDimSize(1);
2973 if (n == ShapedType::kDynamic)
2975 else if (indicesN != ShapedType::kDynamic && n != indicesN)
2976 return emitOpError() <<
"requires indices dimension 0 to have size " << n
2977 <<
", got " << indicesN;
2979 if (outputShape.hasRank()) {
2980 const int64_t outputN = outputShape.getDimSize(0);
2981 const int64_t outputW = outputShape.getDimSize(1);
2982 const int64_t outputC = outputShape.getDimSize(2);
2983 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2985 return emitOpError() <<
"requires output dimension 0 to have size " << n
2986 <<
", got " << outputN;
2988 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2990 return emitOpError() <<
"requires output dimension 1 to have size " << w
2991 <<
", got " << outputW;
2992 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2994 return emitOpError() <<
"requires output dimension 2 to have size " << c
2995 <<
", got " << outputC;
3000LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3001 MLIRContext *context, ::std::optional<Location> location,
3002 ResizeOp::Adaptor adaptor,
3003 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3004 llvm::SmallVector<int64_t, 4> outputShape;
3005 outputShape.resize(4, ShapedType::kDynamic);
3007 ShapeAdaptor inputShape(adaptor.getInput().getType());
3008 if (!inputShape.hasRank())
3011 outputShape[0] = inputShape.getDimSize(0);
3012 outputShape[3] = inputShape.getDimSize(3);
3013 int64_t inputHeight = inputShape.getDimSize(1);
3014 int64_t inputWidth = inputShape.getDimSize(2);
3016 if ((inputHeight == ShapedType::kDynamic) ||
3017 (inputWidth == ShapedType::kDynamic))
3020 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
3031 const int64_t outputHeight =
3032 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3036 const int64_t outputWidth =
3037 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
3041 if (outputHeight < 0 || outputWidth < 0) {
3044 "calculated output height and width must be non-negative, "
3046 outputHeight,
", width = ", outputWidth);
3049 outputShape[1] = outputHeight;
3050 outputShape[2] = outputWidth;
3051 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3055LogicalResult tosa::ResizeOp::verify() {
3056 const Value input = getInput();
3057 const Value output = getOutput();
3058 const RankedTensorType inputType =
3059 llvm::dyn_cast<RankedTensorType>(input.
getType());
3060 const RankedTensorType outputType =
3061 llvm::dyn_cast<RankedTensorType>(output.
getType());
3063 SmallVector<int64_t> scaleValues;
3064 SmallVector<int64_t> offsetValues;
3065 SmallVector<int64_t> borderValues;
3073 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
3074 return emitOpError(
"expect all scale values to be > 0, got ")
3077 const int64_t scaleYN = scaleValues[0];
3078 const int64_t scaleYD = scaleValues[1];
3079 const int64_t scaleXN = scaleValues[2];
3080 const int64_t scaleXD = scaleValues[3];
3082 const int64_t offsetY = offsetValues[0];
3083 const int64_t offsetX = offsetValues[1];
3085 const int64_t borderY = borderValues[0];
3086 const int64_t borderX = borderValues[1];
3093 const int64_t oh = outputType.getDimSize(1);
3094 const int64_t ow = outputType.getDimSize(2);
3095 const int64_t ih = inputType.getDimSize(1);
3096 const int64_t iw = inputType.getDimSize(2);
3102 if (ih != ShapedType::kDynamic && ih != 1) {
3103 const std::optional<int64_t> calculatedOutHeightMinusOne =
3104 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3105 if (!calculatedOutHeightMinusOne.has_value())
3106 return emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
3108 <<
"to be wholly divisible by scale_y_d, got ((" << ih
3109 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
3110 <<
") / " << scaleYD;
3111 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3112 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3113 return emitOpError(
"calculated output height did not match expected: ")
3114 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
3121 if (iw != ShapedType::kDynamic && iw != 1) {
3122 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3123 const std::optional<int64_t> calculatedOutWidthMinusOne =
3125 if (!calculatedOutWidthMinusOne.has_value())
3126 return emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
3128 <<
"to be wholly divisible by scale_x_d, got ((" << iw
3129 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
3130 <<
") / " << scaleXD;
3131 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3132 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3133 return emitOpError(
"calculated output width did not match expected: ")
3134 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
3140LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3141 MLIRContext *context, ::std::optional<Location> location,
3142 ScatterOp::Adaptor adaptor,
3143 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3144 llvm::SmallVector<int64_t> outputShape;
3145 outputShape.resize(3, ShapedType::kDynamic);
3147 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3148 if (valuesInShape.hasRank()) {
3149 outputShape[0] = valuesInShape.getDimSize(0);
3150 outputShape[1] = valuesInShape.getDimSize(1);
3151 outputShape[2] = valuesInShape.getDimSize(2);
3154 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3155 if (indicesShape.hasRank()) {
3156 if (outputShape[0] == ShapedType::kDynamic)
3157 outputShape[0] = indicesShape.getDimSize(0);
3160 ShapeAdaptor inputShape(adaptor.getInput().getType());
3161 if (inputShape.hasRank()) {
3162 if (outputShape[0] == ShapedType::kDynamic)
3163 outputShape[0] = inputShape.getDimSize(0);
3164 if (outputShape[2] == ShapedType::kDynamic)
3165 outputShape[2] = inputShape.getDimSize(2);
3168 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3172LogicalResult tosa::ScatterOp::verify() {
3182 const ShapeAdaptor valuesInShape(getValuesIn().
getType());
3184 const ShapeAdaptor inputShape(getInput().
getType());
3185 const ShapeAdaptor outputShape(getValuesOut().
getType());
3187 int64_t n = ShapedType::kDynamic;
3188 int64_t k = ShapedType::kDynamic;
3189 int64_t w = ShapedType::kDynamic;
3190 int64_t c = ShapedType::kDynamic;
3191 if (valuesInShape.hasRank()) {
3192 n = valuesInShape.getDimSize(0);
3193 k = valuesInShape.getDimSize(1);
3194 c = valuesInShape.getDimSize(2);
3196 if (indicesShape.hasRank()) {
3197 const int64_t indicesN = indicesShape.getDimSize(0);
3198 w = indicesShape.getDimSize(1);
3199 if (n == ShapedType::kDynamic)
3201 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3202 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3203 <<
", got " << indicesN;
3205 if (inputShape.hasRank()) {
3206 const int64_t inputN = inputShape.getDimSize(0);
3207 const int64_t inputW = inputShape.getDimSize(1);
3208 const int64_t inputC = inputShape.getDimSize(2);
3209 if (n == ShapedType::kDynamic)
3211 else if (inputN != ShapedType::kDynamic && n != inputN)
3212 return emitOpError() <<
"requires input dimension 0 to have size " << n
3213 <<
", got " << inputN;
3214 if (w == ShapedType::kDynamic)
3216 else if (inputW != ShapedType::kDynamic && w != inputW)
3217 return emitOpError() <<
"requires input dimension 1 to have size " << w
3218 <<
", got " << inputW;
3220 if (c == ShapedType::kDynamic)
3222 else if (inputC != ShapedType::kDynamic && c != inputC)
3223 return emitOpError() <<
"requires input dimension 2 to have size " << c
3224 <<
", got " << inputC;
3226 if (outputShape.hasRank()) {
3227 const int64_t outputN = outputShape.getDimSize(0);
3228 const int64_t outputK = outputShape.getDimSize(1);
3229 const int64_t outputC = outputShape.getDimSize(2);
3230 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3232 return emitOpError() <<
"requires values_out dimension 0 to have size "
3233 << n <<
", got " << outputN;
3234 if (k == ShapedType::kDynamic)
3236 else if (outputK != ShapedType::kDynamic && k != outputK)
3237 return emitOpError() <<
"requires values_out dimension 1 to have size "
3238 << k <<
", got " << outputK;
3239 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3241 return emitOpError() <<
"requires values_out dimension 2 to have size "
3242 << c <<
", got " << outputC;
3244 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3245 return emitOpError() <<
"requires dimensions K >= W, got K=" << k
3254 int64_t axisVal = axis.getValue().getSExtValue();
3255 if (!operandShape.
hasRank() || operandShape.
getRank() <= axisVal) {
3261 operandShape.
getDims(outputShape);
3262 outputShape[axisVal] = 1;
3267#define COMPATIBLE_RETURN_TYPES(OP) \
3268 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3269 if (l.size() != r.size() || l.size() != 1) \
3271 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3273 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3276#define REDUCE_SHAPE_INFER(OP) \
3277 LogicalResult OP::inferReturnTypeComponents( \
3278 MLIRContext *context, ::std::optional<Location> location, \
3279 OP::Adaptor adaptor, \
3280 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3282 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3283 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3284 const Properties &prop = adaptor.getProperties(); \
3285 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3286 inferredReturnShapes); \
3288 COMPATIBLE_RETURN_TYPES(OP)
3296#undef REDUCE_SHAPE_INFER
3298#undef COMPATIBLE_RETURN_TYPES
3300template <
typename T>
3303 TensorType inputType = op.getInput().getType();
3304 TensorType outputType = op.getOutput().getType();
3305 int32_t reduceAxis = op.getAxis();
3307 if (reduceAxis < 0) {
3308 op.emitOpError(
"reduce axis must not be negative");
3312 int64_t inputRank = inputType.getRank();
3315 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3316 op.emitOpError(
"expect input tensor rank (")
3317 << inputRank <<
") to be larger than reduce axis (" << reduceAxis
3323 int64_t outputRank = outputType.getRank();
3324 if (inputType.
hasRank() && outputRank != inputType.getRank()) {
3326 "expect output tensor rank to be equal to input tensor rank");
3329 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3330 op.emitOpError(
"expect output tensor rank (")
3331 << outputRank <<
") to be larger than reduce axis (" << reduceAxis
3337 if (outputRank != 0) {
3338 auto outputShape = outputType.
getShape();
3339 if (!outputType.isDynamicDim(reduceAxis) &&
3340 outputShape[reduceAxis] != 1) {
3341 op.emitOpError(
"expect reduced dimension size to be 1, got ")
3342 << outputShape[reduceAxis];
3350LogicalResult tosa::ReduceAllOp::verify() {
return verifyReduceOp(*
this); }
3351LogicalResult tosa::ReduceAnyOp::verify() {
return verifyReduceOp(*
this); }
3352LogicalResult tosa::ReduceMaxOp::verify() {
return verifyReduceOp(*
this); }
3353LogicalResult tosa::ReduceMinOp::verify() {
return verifyReduceOp(*
this); }
3354LogicalResult tosa::ReduceProductOp::verify() {
return verifyReduceOp(*
this); }
3355LogicalResult tosa::ReduceSumOp::verify() {
return verifyReduceOp(*
this); }
3369#define NARY_SHAPE_INFER(OP) \
3370 LogicalResult OP::inferReturnTypeComponents( \
3371 MLIRContext *context, ::std::optional<Location> location, \
3372 ValueShapeRange operands, DictionaryAttr attributes, \
3373 PropertyRef properties, RegionRange regions, \
3374 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3375 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3415#undef PRED_SHAPE_INFER
3417LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3418 MLIRContext *context, ::std::optional<Location> location,
3419 NegateOp::Adaptor adaptor,
3421 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3426LogicalResult tosa::NegateOp::verify() {
3428 const Type input1Type = getInput1().getType();
3429 const Type outputType = getOutput().getType();
3434 const SmallVector<Type, 2> types = {input1Type, outputType};
3436 return emitOpError() <<
"requires the same shape for input1 and output";
3439 const Type input1ZpEType =
3441 if (input1EType != input1ZpEType) {
3442 return emitOpError(
"expect both input1 and its zero point are the same "
3443 "element type, got ")
3444 << input1EType <<
" and " << input1ZpEType;
3447 const Type outputZpEType =
3449 if (outputEType != outputZpEType) {
3450 return emitOpError(
"expect both output and its zero point are the same "
3451 "element type, got ")
3452 << outputEType <<
" and " << outputZpEType;
3455 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3456 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3459 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3460 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3471 outputShape.resize(4, ShapedType::kDynamic);
3486 if (ShapedType::isStatic(height)) {
3487 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3488 outputShape[1] = padded / stride[0] + 1;
3491 if (ShapedType::isStatic(width)) {
3492 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3493 outputShape[2] = padded / stride[1] + 1;
3500template <
typename AdaptorT>
3506 if (ShapedType::isDynamic(current))
3507 current = candidate;
3516 : adaptor(adaptor) {}
3520 const ShapeAdaptor inputShape(adaptor.getInput().getType());
3528 outputShape[0] = outputBatch;
3529 inputSpatial[0] = inputHeight;
3530 inputSpatial[1] = inputWidth;
3535 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
3543 outputShape[3] = outputChannels;
3544 weightSpatial[0] = kernelHeight;
3545 weightSpatial[1] = kernelWidth;
3554 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
3555 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
3556 dilationValues.assign(adaptor.getDilation().begin(),
3557 adaptor.getDilation().end());
3562 Conv2DOp::Adaptor adaptor;
3570 : adaptor(adaptor) {}
3574 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
3575 if (inputDataShape.
hasRank()) {
3580 outputShape[0] = outputBatch;
3581 inputSpatial[0] = inputHeight;
3582 inputSpatial[1] = inputWidth;
3585 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
3586 if (!inputScaleShape.
hasRank())
3600 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
3601 if (weightDataShape.
hasRank()) {
3606 outputShape[3] = outputChannels;
3607 weightSpatial[0] = kernelHeight;
3608 weightSpatial[1] = kernelWidth;
3611 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
3612 if (!weightScaleShape.
hasRank())
3641 Conv2DBlockScaledOp::Adaptor adaptor;
3649 : adaptor(adaptor) {}
3653 const ShapeAdaptor inputShape(adaptor.getInput().getType());
3662 outputShape[0] = outputBatch;
3663 inputSpatial[0] = inputDepth;
3664 inputSpatial[1] = inputHeight;
3665 inputSpatial[2] = inputWidth;
3670 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
3679 outputShape[4] = outputChannels;
3680 weightSpatial[0] = kernelDepth;
3681 weightSpatial[1] = kernelHeight;
3682 weightSpatial[2] = kernelWidth;
3691 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
3692 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
3693 dilationValues.assign(adaptor.getDilation().begin(),
3694 adaptor.getDilation().end());
3699 Conv3DOp::Adaptor adaptor;
3702template <
typename AdaptorT>
3708 ShapedType::kDynamic);
3710 ShapedType::kDynamic);
3712 ShapedType::kDynamic);
3714 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
3715 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
3717 const ShapeAdaptor biasShape = adaptor.getBias().getType();
3720 if (biasSize != 1) {
3721 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
3722 outputShape[outputChannelDim] =
3723 ShapedType::isDynamic(outputShape[outputChannelDim])
3725 : outputShape[outputChannelDim];
3732 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
3738 for (
int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
3739 if (!ShapedType::isStatic(inputSpatial[dim]) ||
3740 !ShapedType::isStatic(weightSpatial[dim]))
3743 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
3745 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
3746 const int64_t unstridedResult = inputSize - filterSize + 1;
3747 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
3754LogicalResult Conv2DOp::inferReturnTypeComponents(
3755 MLIRContext *context, ::std::optional<Location> location,
3756 Conv2DOp::Adaptor adaptor,
3757 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3761LogicalResult Conv2DOp::verify() {
3768LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
3769 MLIRContext *context, ::std::optional<Location> location,
3770 Conv2DBlockScaledOp::Adaptor adaptor,
3771 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3775LogicalResult Conv2DBlockScaledOp::verify() {
3777 getWeightData().
getType(),
"input_data",
3780 getWeightScale().
getType(),
"input_scale",
3783 getOutput().
getType(),
"bias",
"output")))
3787 int64_t N = ShapedType::kDynamic;
3788 int64_t IH = ShapedType::kDynamic;
3789 int64_t IW = ShapedType::kDynamic;
3790 int64_t IC = ShapedType::kDynamic;
3791 int64_t multiplesOfIC = ShapedType::kDynamic;
3792 int64_t OC = ShapedType::kDynamic;
3793 int64_t KH = ShapedType::kDynamic;
3794 int64_t KW = ShapedType::kDynamic;
3796 const ShapeAdaptor inputDataShape(getInputData().
getType());
3797 if (inputDataShape.hasRank()) {
3798 N = inputDataShape.getDimSize(0);
3799 IH = inputDataShape.getDimSize(1);
3800 IW = inputDataShape.getDimSize(2);
3801 IC = inputDataShape.getDimSize(3);
3804 const ShapeAdaptor inputScaleShape(getInputScale().
getType());
3805 if (inputScaleShape.hasRank()) {
3807 "input_scale",
"batch size")) ||
3809 "input_scale",
"input height")) ||
3811 "input_scale",
"input width")))
3813 multiplesOfIC = inputScaleShape.getDimSize(3);
3816 const ShapeAdaptor weightDataShape(getWeightData().
getType());
3817 if (weightDataShape.hasRank()) {
3818 OC = weightDataShape.getDimSize(0);
3819 KH = weightDataShape.getDimSize(1);
3820 KW = weightDataShape.getDimSize(2);
3822 "weight_data",
"input channels")))
3826 const ShapeAdaptor weightScaleShape(getWeightScale().
getType());
3827 if (weightScaleShape.hasRank()) {
3829 "weight_scale",
"output channels")) ||
3831 "weight_scale",
"kernel height")) ||
3833 "weight_scale",
"kernel width")) ||
3835 weightScaleShape.getDimSize(3),
3836 "weight_scale",
"input channel blocks")))
3841 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
3842 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
3843 return emitOpError(
"expect IC to be a multiple of block size, got IC=")
3844 << IC <<
", block_size=" << blockSize;
3847 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
3848 multiplesOfIC != IC / blockSize)
3850 "expect scale operands dimension 2 to equal IC/block_size (")
3851 << IC <<
"/" << blockSize <<
")"
3852 <<
", got " << multiplesOfIC;
3855 SmallVector<int64_t> padValues;
3857 if (llvm::any_of(padValues, [](int64_t p) {
return p < 0; }))
3858 return emitOpError(
"expect all padding values to be >= 0, got ")
3862 SmallVector<int64_t> strideValues;
3864 if (llvm::any_of(strideValues, [](int64_t s) {
return s < 1; }))
3865 return emitOpError(
"expect all stride values to be >= 1, got ")
3869 SmallVector<int64_t> dilationValues;
3872 if (llvm::any_of(dilationValues, [](int64_t d) {
return d < 1; }))
3873 return emitOpError(
"expect all dilation values to be >= 1, got ")
3878 const ShapeAdaptor outputShape(getOutput().
getType());
3879 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
3880 outputShape.hasRank()) {
3882 padValues[0], padValues[1], strideValues[0],
3883 dilationValues[0],
"height",
"y",
"top",
3886 padValues[2], padValues[3], strideValues[1],
3887 dilationValues[1],
"width",
"x",
"left",
3893 const ShapeAdaptor biasShape(getBias().
getType());
3894 if (biasShape.hasRank() && outputShape.hasRank()) {
3895 const int64_t biasChannels = biasShape.getDimSize(0);
3896 const int64_t outputChannels =
3897 outputShape.getDimSize(outputShape.getRank() - 1);
3898 if (biasChannels == ShapedType::kDynamic ||
3899 outputChannels == ShapedType::kDynamic)
3903 if (biasChannels != outputChannels && biasChannels != 1)
3905 "bias channels expected to be equal to output channels (")
3906 << outputChannels <<
") or 1, got " << biasChannels;
3912LogicalResult Conv3DOp::inferReturnTypeComponents(
3913 MLIRContext *context, ::std::optional<Location> location,
3914 Conv3DOp::Adaptor adaptor,
3915 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3919LogicalResult Conv3DOp::verify() {
3926LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3927 MLIRContext *context, ::std::optional<Location> location,
3928 AvgPool2dOp::Adaptor adaptor,
3929 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3930 ShapeAdaptor inputShape(adaptor.getInput().getType());
3931 const Properties &prop = adaptor.getProperties();
3933 inferredReturnShapes);
3936LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3937 MLIRContext *context, ::std::optional<Location> location,
3938 MaxPool2dOp::Adaptor adaptor,
3939 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3940 ShapeAdaptor inputShape(adaptor.getInput().getType());
3941 const Properties &prop = adaptor.getProperties();
3943 inferredReturnShapes);
3946LogicalResult MaxPool2dOp::verify() {
3957LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3958 MLIRContext *context, ::std::optional<Location> location,
3959 DepthwiseConv2DOp::Adaptor adaptor,
3960 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3961 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3963 int64_t inputWidth = ShapedType::kDynamic;
3964 int64_t inputHeight = ShapedType::kDynamic;
3965 int64_t inputChannels = ShapedType::kDynamic;
3967 int64_t weightWidth = ShapedType::kDynamic;
3968 int64_t weightHeight = ShapedType::kDynamic;
3969 int64_t depthChannels = ShapedType::kDynamic;
3972 ShapeAdaptor inputShape(adaptor.getInput().getType());
3973 if (inputShape.hasRank()) {
3974 outputShape[0] = inputShape.getDimSize(0);
3975 inputHeight = inputShape.getDimSize(1);
3976 inputWidth = inputShape.getDimSize(2);
3977 inputChannels = inputShape.getDimSize(3);
3981 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3982 if (weightShape.hasRank()) {
3983 weightHeight = weightShape.getDimSize(0);
3984 weightWidth = weightShape.getDimSize(1);
3985 inputChannels = ShapedType::isDynamic(inputChannels)
3986 ? weightShape.getDimSize(2)
3988 depthChannels = weightShape.getDimSize(3);
3993 if (ShapedType::isStatic(inputChannels) &&
3994 ShapedType::isStatic(depthChannels)) {
3995 outputShape[3] = inputChannels * depthChannels;
3999 ShapeAdaptor biasShape(adaptor.getBias().getType());
4000 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4001 int64_t bc = biasShape.getDimSize(0);
4002 if (bc != ShapedType::kDynamic && bc != 1)
4003 outputShape[3] = bc;
4006 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4007 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4008 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4010 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4011 int64_t inputSize = inputHeight + padding[0] + padding[1];
4012 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4013 int64_t unstridedResult = inputSize - filterSize + 1;
4014 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4017 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4018 int64_t inputSize = inputWidth + padding[2] + padding[3];
4019 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4020 int64_t unstridedResult = inputSize - filterSize + 1;
4021 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4024 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4028LogicalResult DepthwiseConv2DOp::verify() {
4035LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4036 MLIRContext *context, ::std::optional<Location> location,
4037 TransposeConv2DOp::Adaptor adaptor,
4038 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4039 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4041 int64_t inputWidth = ShapedType::kDynamic;
4042 int64_t inputHeight = ShapedType::kDynamic;
4043 int64_t weightWidth = ShapedType::kDynamic;
4044 int64_t weightHeight = ShapedType::kDynamic;
4047 ShapeAdaptor inputShape(adaptor.getInput().getType());
4048 if (inputShape.hasRank()) {
4049 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4050 ? inputShape.getDimSize(0)
4052 inputHeight = inputShape.getDimSize(1);
4053 inputWidth = inputShape.getDimSize(2);
4057 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4058 if (weightShape.hasRank()) {
4059 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4060 ? weightShape.getDimSize(0)
4062 weightHeight = weightShape.getDimSize(1);
4063 weightWidth = weightShape.getDimSize(2);
4067 ShapeAdaptor biasShape(adaptor.getBias().getType());
4068 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4069 int64_t bc = biasShape.getDimSize(0);
4070 if (bc != ShapedType::kDynamic && bc != 1)
4071 outputShape[3] = bc;
4074 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4075 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4077 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4078 int64_t calculateSize =
4079 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4081 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4084 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4085 int64_t calculateSize =
4086 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4088 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4091 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4095LogicalResult TransposeConv2DOp::verify() {
4099 const llvm::ArrayRef<int64_t> strides = getStride();
4100 const int64_t strideY = strides[0];
4101 const int64_t strideX = strides[1];
4103 if (strideY < 1 || strideX < 1)
4104 return emitOpError(
"expect all stride values to be >= 1, got [")
4107 const auto checkPadAgainstKernelDim =
4108 [
this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4109 llvm::StringRef kernelDimName) -> LogicalResult {
4110 if (padValue <= -kernelDimSize)
4112 << padName <<
" > -" << kernelDimName <<
", but got: " << padName
4113 <<
"=" << padValue <<
" and " << kernelDimName <<
"="
4118 const llvm::ArrayRef<int64_t> padding = getOutPad();
4119 const int64_t outPadTop = padding[0];
4120 const int64_t outPadBottom = padding[1];
4121 const int64_t outPadLeft = padding[2];
4122 const int64_t outPadRight = padding[3];
4124 const auto weightType =
4125 llvm::dyn_cast<RankedTensorType>(getWeight().
getType());
4128 const int64_t kernelHeight = weightType.getDimSize(1);
4129 if (ShapedType::isStatic(kernelHeight)) {
4130 if (
failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4131 "out_pad_top",
"KH")))
4134 if (
failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4135 "out_pad_bottom",
"KH")))
4139 const int64_t kernelWidth = weightType.getDimSize(2);
4140 if (ShapedType::isStatic(kernelWidth)) {
4141 if (
failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4142 "out_pad_left",
"KW")))
4145 if (
failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4146 "out_pad_right",
"KW")))
4152 const auto outputType =
4153 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
4157 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
4158 if (inputType && weightType) {
4159 const int64_t inputHeight = inputType.getDimSize(1);
4160 const int64_t kernelHeight = weightType.getDimSize(1);
4161 const int64_t outputHeight = outputType.getDimSize(1);
4163 if (ShapedType::isStatic(inputHeight) &&
4164 ShapedType::isStatic(outputHeight)) {
4166 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4168 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4169 "+ out_pad_top + out_pad_bottom + KH, but got ")
4170 << outputHeight <<
" != (" << inputHeight <<
" - 1) * "
4171 << strideY <<
" + " << outPadTop <<
" + " << outPadBottom
4172 <<
" + " << kernelHeight;
4175 const int64_t inputWidth = inputType.getDimSize(2);
4176 const int64_t kernelWidth = weightType.getDimSize(2);
4177 const int64_t outputWidth = outputType.getDimSize(2);
4179 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4181 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4183 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4184 "+ out_pad_left + out_pad_right + KW, but got ")
4185 << outputWidth <<
" != (" << inputWidth <<
" - 1) * " << strideX
4186 <<
" + " << outPadLeft <<
" + " << outPadRight <<
" + "
4191 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().
getType());
4196 const int64_t biasChannels = biasType.getDimSize(0);
4199 if (biasChannels == ShapedType::kDynamic)
4202 const int64_t outputChannels = outputType.getDimSize(3);
4203 if (!ShapedType::isDynamic(outputChannels) &&
4204 biasChannels != outputChannels && biasChannels != 1)
4206 "bias channels expected to be equal to output channels (")
4207 << outputChannels <<
") or 1, got " << biasChannels;
4212LogicalResult RescaleOp::verify() {
4213 auto inputType = llvm::dyn_cast<ShapedType>(getInput().
getType());
4215 emitOpError(
"expect shaped tensor for input, got ") << getInput().getType();
4219 auto inputElementType =
4221 if (!mlir::isa<IntegerType>(inputElementType)) {
4222 emitOpError(
"expect input to have integer element type, got ")
4223 << inputElementType;
4227 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().
getType());
4229 emitOpError(
"expect shaped tensor for output, got ")
4230 << getOutput().getType();
4234 auto outputElementType =
4236 if (!mlir::isa<IntegerType>(outputElementType)) {
4237 emitOpError(
"expect output to have integer element type, got ")
4238 << outputElementType;
4250 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4251 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4254 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4255 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4258 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().
getType());
4259 if (!multiplierType) {
4260 emitOpError(
"expect shaped tensor for multiplier, got ")
4261 << getMultiplier().getType();
4265 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().
getType());
4267 emitOpError(
"expect shaped tensor for shift, got ") << getShift().getType();
4272 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4273 emitOpError(
"expect i32 element type for multiplier for scale32=true, got ")
4274 << multiplierType.getElementType();
4279 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4281 "expect i16 element type for multiplier for scale32=false, got ")
4282 << multiplierType.getElementType();
4286 if (!inputType.hasRank())
4292 int64_t numChannels = 1;
4293 if (getPerChannel()) {
4294 if (inputType.getRank() < 1) {
4295 emitOpError(
"requires input to be at least rank 1 when per_channel is "
4296 "true, but got rank ")
4297 << inputType.getRank();
4300 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4303 if (!multiplierType.hasRank())
4306 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4308 if (multiplierShape[0] != ShapedType::kDynamic &&
4309 multiplierShape[0] != numChannels) {
4311 << numChannels <<
" } for multiplier input, got { "
4312 << multiplierShape[0] <<
" }";
4316 if (!shiftType.hasRank())
4319 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4321 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4323 << numChannels <<
" } for shift input, got { " << shiftShape[0] <<
" }";
4330LogicalResult RescaleOp::inferReturnTypeComponents(
4331 MLIRContext *context, ::std::optional<Location> location,
4332 RescaleOp::Adaptor adaptor,
4333 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4334 ShapeAdaptor inputShape(adaptor.getInput().getType());
4335 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4339LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4340 MLIRContext *context, ::std::optional<Location> location,
4341 CastFromBlockScaledOp::Adaptor adaptor,
4342 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4343 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4344 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4348LogicalResult CastFromBlockScaledOp::verify() {
4349 const Type inputDataType = getInputData().getType();
4350 const Type outputDataType = getResult().getType();
4352 return emitOpError() <<
"require compatible shapes for input_data ("
4353 << inputDataType <<
") and " <<
"output_data ("
4354 << outputDataType <<
")";
4356 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4358 if (inputDataShape.
hasRank()) {
4359 const unsigned int blockSize =
4361 const int64_t inputDataLastDim =
4363 if (inputDataLastDim % blockSize != 0)
4364 return emitOpError() <<
"expect last dimension of input_data ("
4366 <<
") to be divisible by block_size (" << blockSize
4369 const Type inputScaleType = getInputScale().getType();
4370 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4372 if (inputScaleShape.
hasRank()) {
4373 SmallVector<int64_t> inputDataDims, inputScaleDims;
4374 inputDataShape.
getDims(inputDataDims);
4375 inputScaleShape.
getDims(inputScaleDims);
4377 if (inputDataDims.size() != inputScaleDims.size() ||
4379 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4380 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4382 <<
"require compatible shapes for input_data (" << inputDataType
4383 <<
") and " <<
"input_scale (" << inputScaleType
4384 <<
") except for the last dimension";
4386 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4387 inputScaleDims.back()};
4388 if (ShapedType::isStatic(inputDataLastDim) &&
4391 <<
"expect last dimension of input_scale ("
4392 << inputScaleDims.back()
4393 <<
") to be equal to last dimension of input_data / block_size ("
4394 << inputDataDims.back() / blockSize <<
")";
4401LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4402 MLIRContext *context, ::std::optional<Location> location,
4403 CastToBlockScaledOp::Adaptor adaptor,
4404 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4405 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4406 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4407 if (!inputShape.hasRank())
4411 SmallVector<int64_t> outputScaleShape;
4412 inputShape.getDims(outputScaleShape);
4413 const int64_t lastDimLoc = inputShape.getRank() - 1;
4414 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4415 if (ShapedType::isStatic(lastDimSize)) {
4416 const unsigned int blockSize =
4417 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4418 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4420 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4424LogicalResult CastToBlockScaledOp::verify() {
4425 const Type inputDataType = getInputData().getType();
4426 const Type outputDataType = getResult(0).getType();
4428 return emitOpError() <<
"require compatible shapes for input_data ("
4429 << inputDataType <<
") and " <<
"output_data ("
4430 << outputDataType <<
")";
4432 const unsigned int blockSize =
4434 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4435 if (inputDataShape.
hasRank()) {
4436 const int64_t inputDataLastDim =
4438 if (ShapedType::isStatic(inputDataLastDim) &&
4439 inputDataLastDim % blockSize != 0)
4440 return emitOpError() <<
"expect last dimension of input_data ("
4442 <<
") to be divisible by block_size (" << blockSize
4446 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4447 const Type outputScaleType = getResult(1).getType();
4448 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4450 SmallVector<int64_t> outputDataDims, outputScaleDims;
4451 outputDataShape.
getDims(outputDataDims);
4452 outputScaleShape.
getDims(outputScaleDims);
4454 if (outputDataDims.size() != outputScaleDims.size() ||
4456 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4457 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4458 return emitOpError() <<
"require compatible shapes for output_data ("
4459 << outputDataType <<
") and " <<
"output_scale ("
4461 <<
") except for the last dimension";
4463 const int64_t outputDataLastDim = outputDataDims.back();
4464 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4465 outputScaleDims.back()};
4466 if (ShapedType::isStatic(outputDataLastDim) &&
4469 <<
"expect last dimension of output_scale ("
4470 << outputScaleDims.back()
4471 <<
") to be equal to last dimension of output_data / block_size ("
4472 << outputDataDims.back() / blockSize <<
")";
4478LogicalResult IfOp::inferReturnTypeComponents(
4479 MLIRContext *context, ::std::optional<Location> location,
4480 IfOp::Adaptor adaptor,
4481 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4482 llvm::SmallVector<tosa::YieldOp> yieldOps;
4483 for (Region *region : adaptor.getRegions()) {
4484 for (
auto &block : *region)
4485 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4486 yieldOps.push_back(returnOp);
4489 if (yieldOps.empty())
4493 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4494 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4495 for (
auto operand : yieldOps.front().getOperands()) {
4496 resultKnowledge.push_back(
4500 for (
auto yieldOp : yieldOps) {
4501 if (resultKnowledge.size() != yieldOp.getNumOperands())
4504 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4505 int32_t index = it.index();
4507 resultKnowledge[index],
4511 resultKnowledge[index] = meet;
4515 for (
const ValueKnowledge &
result : resultKnowledge) {
4516 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
4522LogicalResult WhileOp::inferReturnTypeComponents(
4523 MLIRContext *context, ::std::optional<Location> location,
4524 WhileOp::Adaptor adaptor,
4525 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4526 llvm::SmallVector<tosa::YieldOp> yieldOps;
4527 for (
auto &block : adaptor.getBodyGraph())
4528 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4529 yieldOps.push_back(returnOp);
4533 if (yieldOps.empty())
4537 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4538 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4539 for (
auto operand : yieldOps.front().getOperands()) {
4540 resultKnowledge.push_back(
4544 for (
auto yieldOp : yieldOps) {
4545 if (resultKnowledge.size() != yieldOp.getNumOperands())
4548 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4549 int32_t index = it.index();
4551 resultKnowledge[index],
4553 resultKnowledge[index] = meet;
4558 for (
const ValueKnowledge &
result : resultKnowledge) {
4559 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
4565std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4566 if (
auto vt = llvm::dyn_cast<VectorType>(
getType()))
4567 return llvm::to_vector<4>(vt.getShape());
4568 return std::nullopt;
4574 StringRef prefix =
"") {
4575 assert(blocksArgs.size() == initializers.size() &&
4576 "expected same length of arguments and initializers");
4577 if (initializers.empty())
4580 parser << prefix <<
'(';
4581 llvm::interleaveComma(
4582 llvm::zip(blocksArgs, initializers), parser,
4583 [&](
auto it) { parser << std::get<0>(it) <<
" = " << std::get<1>(it); });
4588ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
4590 result.regions.reserve(2);
4591 Region *thenRegion =
result.addRegion();
4592 Region *elseRegion =
result.addRegion();
4594 OpAsmParser::UnresolvedOperand cond;
4599 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4600 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4603 OptionalParseResult listResult =
4611 "expected type for condition operand");
4617 "expected type for condition operand");
4625 FunctionType functionType;
4629 <<
"expected list of types for block arguments "
4630 <<
"followed by arrow type and list of return types";
4632 result.addTypes(functionType.getResults());
4634 if (functionType.getNumInputs() != operands.size()) {
4636 <<
"expected as many input types as operands " <<
"(expected "
4637 << operands.size() <<
" got " << functionType.getNumInputs()
4668void IfOp::print(OpAsmPrinter &p) {
4669 p <<
" " << getCondition();
4672 getInputList(),
" ");
4674 p << getCondition().getType();
4676 if (!getInputList().empty()) {
4678 llvm::interleaveComma(getInputList().getTypes(), p);
4687 auto &elseRegion = getElseGraph();
4688 if (!elseRegion.
empty()) {
4696LogicalResult IfOp::verify() {
4698 "'then_graph' arguments", getInputList(),
4704 "'else_graph' arguments", getInputList(),
4710 if (getThenGraph().front().mightHaveTerminator()) {
4712 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4714 *
this, thenYield.getInputs(),
"'then_graph' results",
4715 getOutputList(),
"'output_list'")
4721 if (getElseGraph().front().mightHaveTerminator()) {
4723 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4725 *
this, elseYield.getInputs(),
"'else_graph' results",
4726 getOutputList(),
"'output_list'")
4731 auto condType = getCondition().getType();
4733 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
4739LogicalResult WhileOp::verify() {
4741 getOutputList(),
"'output_list'")
4746 "'cond_graph' arguments", getInputList(),
4752 "'body_graph' arguments", getInputList(),
4757 if (getBodyGraph().front().mightHaveTerminator()) {
4759 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4761 "'body_graph' results",
4762 getInputList(),
"'input_list'")
4769 if (!getCondGraph().front().mightHaveTerminator())
4773 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4777 if (condYield.getInputs().size() != 1)
4778 return emitOpError() <<
"require 'cond_graph' only have one result";
4780 auto condOutType = condYield.getInputs()[0].getType();
4782 return emitOpError() <<
"'cond_graph' result must be a size 1 tensor, got "
4786 return emitOpError() <<
"'cond_graph' result must be a boolean tensor, got "
4792LogicalResult ReverseOp::verify() {
4797 TensorType inputType = getInput1().getType();
4798 TensorType outputType = getOutput().getType();
4799 int32_t reverseAxis = getAxis();
4801 if (reverseAxis < 0)
4802 return emitOpError(
"expected non-negative reverse axis");
4804 int64_t inputRank = inputType.getRank();
4807 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4809 << inputRank <<
") to be larger than reverse axis (" << reverseAxis
4813 int64_t outputRank = outputType.getRank();
4814 if (inputType.
hasRank() && outputRank != inputType.getRank())
4816 "expect output tensor rank to be equal to input tensor rank");
4817 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4819 << outputRank <<
") to be larger than reverse axis ("
4820 << reverseAxis <<
")";
4825LogicalResult tosa::SelectOp::verify() {
4836 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().
getType());
4837 if (!predicateType) {
4838 return emitOpError(
"expect shaped tensor for input1, got ")
4839 << getInput1().getType();
4841 auto predicateElementType = predicateType.getElementType();
4842 if (!predicateElementType.isInteger(1)) {
4843 return emitOpError(
"expect element type of bool for input1, got ")
4844 << predicateElementType;
4850LogicalResult tosa::VariableReadOp::verify() {
4858LogicalResult tosa::VariableWriteOp::verify() {
4867ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
4868 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4869 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4870 Region *cond =
result.addRegion();
4871 Region *body =
result.addRegion();
4873 OptionalParseResult listResult =
4878 FunctionType functionType;
4883 result.addTypes(functionType.getResults());
4885 if (functionType.getNumInputs() != operands.size()) {
4887 <<
"expected as many input types as operands " <<
"(expected "
4888 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
4898 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
4899 regionArgs[i].type = functionType.getInput(i);
4901 return failure(parser.
parseRegion(*cond, regionArgs) ||
4906void WhileOp::print(OpAsmPrinter &parser) {
4908 getInputList(),
" ");
4911 getResults().getTypes());
4925 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4926 if (llvm::isa<FloatType>(srcElemType)) {
4928 zpType, builder.
getFloatAttr(srcElemType,
static_cast<double>(zp)));
4929 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4931 if (llvm::isa<IntegerType>(srcElemType)) {
4934 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4936 llvm::errs() <<
"zero point is not allowed for unsupported data types\n";
4937 return std::nullopt;
4945 return mlir::isa<tosa::shapeType>(t);
4952 return emitError() <<
"invalid rank (must be >= 0): " << rank;
4958 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4959 Operation *definingOp = v.getDefiningOp();
4961 return op->
emitOpError(
"shape operand is not compile time resolvable");
4974 auto getRank = [](
const Type type) {
4975 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4981 for (
auto type : operandTypes) {
4982 if (getRank(type) != rank) {
4983 return op->
emitOpError(
"operands don't have matching ranks");
4986 for (
auto type : resultTypes) {
4987 if (getRank(type) != rank) {
4988 return op->
emitOpError(
"result shape has different rank than operands");
4998LogicalResult tosa::ConstShapeOp::verify() {
5000 auto valuesRank = getValues().getType().getRank();
5001 if (valuesRank != 1)
5002 return emitOpError(
"expect elements in attribute values with rank 1");
5004 auto count = getValues().getNumElements();
5005 auto rank = (cast<tosa::shapeType>(getResult().
getType())).getRank();
5006 if (count != rank && (count != 1 || rank != 0)) {
5007 return emitOpError(
"expect number of elements in attribute values (")
5008 << count <<
") to be equal to the rank (" << rank
5009 <<
") for the result shape type";
5014LogicalResult tosa::DimOp::verify() {
5015 const tosa::shapeType outShapeType =
5016 cast<tosa::shapeType>(getResult().
getType());
5017 if (outShapeType.getRank() != 1)
5018 return emitOpError(
"expect output shape type to contain one element, got ")
5023 const int64_t inputRank = inputType.getRank();
5024 const int64_t axis = getAxisAttr().getInt();
5025 if (axis < 0 || axis >= inputRank)
5026 return emitOpError(
"expect axis to be in the range [0, ")
5027 << inputRank <<
"), got " << axis;
5032LogicalResult tosa::ConcatShapeOp::verify() {
5033 const tosa::shapeType outShapeType =
5034 cast<tosa::shapeType>(getResult().
getType());
5035 const int64_t outputRank = outShapeType.getRank();
5038 if (inputList.size() == 0)
5039 return emitOpError(
"requires at least one input shape");
5041 if (llvm::any_of(inputList, [](Value v) {
5042 return cast<tosa::shapeType>(v.
getType()).getRank() == 0;
5044 return emitOpError(
"requires all inputs shapes have a rank greater than 0");
5046 const int64_t inputsRank =
5047 llvm::accumulate(inputList, 0, [](int64_t acc,
const Value &input) {
5048 const tosa::shapeType inShapeType =
5049 cast<tosa::shapeType>(input.
getType());
5050 return acc + inShapeType.getRank();
5052 if (outputRank != inputsRank)
5053 return emitOpError(
"requires output shape rank to be equal to the sum of "
5054 "the input shape ranks (")
5055 << inputsRank <<
"), got " << outputRank;
5060LogicalResult tosa::SliceShapeOp::verify() {
5061 std::optional<int32_t> start;
5062 DenseIntElementsAttr startAttr;
5064 start = startAttr.getValues<int32_t>()[0];
5065 if (start && start.value() < 0)
5066 return emitOpError(
"expected non-negative start index, got ")
5069 std::optional<int32_t> size;
5070 DenseIntElementsAttr sizeAttr;
5072 size = sizeAttr.getValues<int32_t>()[0];
5073 if (size && size.value() <= 0)
5074 return emitOpError(
"expected positive size, got ") << size.value();
5079 const tosa::shapeType outShapeType =
5080 cast<tosa::shapeType>(getResult().
getType());
5081 const int64_t outputRank = outShapeType.getRank();
5082 if (outputRank != size)
5084 "expected output type size to be equal to size attribute, got ")
5085 << outputRank <<
" vs " << size.value();
5090 const tosa::shapeType inShapeType =
5091 cast<tosa::shapeType>(getInput().
getType());
5092 const int64_t inputRank = inShapeType.getRank();
5093 const int64_t sliceSize = start.value() + size.value();
5094 if (sliceSize > inputRank)
5095 return emitOpError(
"expected start + size to be less than or equal to "
5096 "input shape rank (")
5097 << inputRank <<
"), got " << sliceSize;
5106#define GET_ATTRDEF_CLASSES
5107#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5112#define GET_TYPEDEF_CLASSES
5113#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5119#define GET_OP_CLASSES
5120#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
true
Given two iterators into the same block, return "true" if a is before `b.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static void updateIfDynamic(int64_t ¤t, int64_t candidate)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
int64_t getOutputRank() const
int64_t getNumSpatialDims() const
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
int64_t getNumSpatialDims() const
int64_t getOutputRank() const
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
int64_t getNumSpatialDims() const
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
int64_t getOutputRank() const
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)