29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
39#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
49#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
52#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
57struct TosaInlinerInterface :
public DialectInlinerInterface {
58 using DialectInlinerInterface::DialectInlinerInterface;
66 IRMapping &map)
const final {
72 IRMapping &map)
const final {
73 return (isa<tosa::IfOp>(dest->getParentOp()) ||
74 isa<tosa::WhileOp>(dest->getParentOp()));
79struct TosaDialectBytecodeInterface :
public BytecodeDialectInterface {
80 TosaDialectBytecodeInterface(Dialect *dialect)
81 : BytecodeDialectInterface(dialect) {}
86 Attribute readAttribute(DialectBytecodeReader &reader)
const override {
90 LogicalResult writeAttribute(Attribute attr,
91 DialectBytecodeWriter &writer)
const override {
92 return ::writeAttribute(attr, writer);
98 Type readType(DialectBytecodeReader &reader)
const override {
102 LogicalResult writeType(Type type,
103 DialectBytecodeWriter &writer)
const override {
104 return ::writeType(type, writer);
107 void writeVersion(DialectBytecodeWriter &writer)
const final {
111 std::unique_ptr<DialectVersion>
112 readVersion(DialectBytecodeReader &reader)
const final {
114 reader.
emitError(
"Dialect does not support versioning");
118 LogicalResult upgradeFromVersion(Operation *topLevelOp,
119 const DialectVersion &version)
const final {
132 return {&getBodyGraph()};
141 return dim == -1 ? ShapedType::kDynamic : dim;
147 Type elementType = variableOp.getType();
150 return RankedTensorType::get(
shape, elementType);
157void TosaDialect::initialize() {
159#define GET_TYPEDEF_LIST
160#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
164#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
167#define GET_ATTRDEF_LIST
168#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
170 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
171 declarePromisedInterfaces<
172 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
173 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
174 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
175 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
176 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
177 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
178 GreaterEqualOp, MatMulOp>();
185 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
186 return tosa::ConstShapeOp::create(builder, loc, type,
187 llvm::cast<DenseIntElementsAttr>(value));
189 if (llvm::isa<ElementsAttr>(value))
190 return tosa::ConstOp::create(builder, loc, type,
191 llvm::cast<ElementsAttr>(value));
201ParseResult getShapeAndElementType(
OpAsmParser &parser,
Type parsedType,
203 TypeAttr &typeAttr) {
204 if (
auto shapedType = dyn_cast<ShapedType>(parsedType)) {
205 if (!shapedType.hasRank())
207 <<
"expected ranked type";
209 auto elementType = shapedType.getElementType();
210 typeAttr = TypeAttr::get(elementType);
217 <<
"expected shaped type";
234 <<
"expected attribute";
236 if (
auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
237 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
241 <<
"expected Typed attr";
244 initialValueAttr =
nullptr;
248 <<
"expected type after colon";
250 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
255 TypeAttr typeAttr,
Attribute initialValueAttr) {
256 bool needsSpace =
false;
257 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
260 Type elementType = typeAttr.getValue();
261 RankedTensorType tensorType =
263 auto tensorTypeAttr = TypeAttr::get(tensorType);
268 if (initialValueAttr) {
279template <
typename EnumType>
280ParseResult parseAttrEntryWithEnumHandling(
OpAsmParser &parser,
282 llvm::StringRef name;
289 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
290 if (name ==
"rounding_mode" &&
292 auto sym = symbolizeRoundingMode(kw);
295 <<
"invalid rounding_mode value: " << kw;
296 auto attr = RoundingModeAttr::get(parser.
getContext(), sym.value());
302 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
304 auto sym = symbolizeResizeMode(kw);
307 <<
"invalid resize mode value: " << kw;
308 auto attr = ResizeModeAttr::get(parser.
getContext(), sym.value());
315 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
317 auto sym = symbolizeNanPropagationMode(kw);
320 <<
"invalid nan_mode value: " << kw;
321 auto attr = NanPropagationModeAttr::get(parser.
getContext(), sym.value());
328 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
330 auto sym = symbolizeBlockSize(kw);
333 <<
"invalid block_size value: " << kw;
334 auto attr = BlockSizeAttr::get(parser.
getContext(), sym.value());
346template <
typename EnumType>
351 [&]() { return parser.parseOperand(operands.emplace_back()); }))
359 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
376 result.addTypes(fnTy.getResults());
377 result.addAttributes(attrs);
383 parser << namedAttr.
getName().strref() <<
" = ";
385 if (
auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
386 parser << roundingModeAttr.getValue();
387 }
else if (
auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
388 parser << resizeModeAttr.getValue();
389 }
else if (
auto nanPropagationModeAttr =
390 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
391 parser << nanPropagationModeAttr.getValue();
392 }
else if (
auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
393 parser << blockSizeAttr.getValue();
406 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
408 if (
auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
409 if (nanAttr.getValue() == kDefaultNanValue) {
411 toPrint.erase(attr.getName());
417 if (!toPrint.empty()) {
419 llvm::interleaveComma(toPrint, parser, [&](
const NamedAttribute namedAttr) {
420 printNamedAttr(parser, namedAttr);
436 llvm::interleaveComma(op->
getAttrs(), parser,
438 printNamedAttr(parser, namedAttr);
450 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
454 printWithEnumHandling(parser, *
this);
458 return parseWithEnumHandling<tosa::RoundingMode>(parser,
result);
462 printWithEnumHandling(parser, *
this);
466 return parseWithEnumHandling<tosa::ResizeMode>(parser,
result);
470 printWithEnumHandling(parser, *
this);
474 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
478 printWithNanPropagationHandling(parser, *
this);
482 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
486 printWithNanPropagationHandling(parser, *
this);
489ParseResult MaxPool2dAdaptiveOp::parse(
OpAsmParser &parser,
491 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
495 printWithNanPropagationHandling(parser, *
this);
499 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
503 printWithNanPropagationHandling(parser, *
this);
507 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
511 printWithNanPropagationHandling(parser, *
this);
515 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
519 printWithNanPropagationHandling(parser, *
this);
523 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
527 printWithNanPropagationHandling(parser, *
this);
531 return parseWithEnumHandling<tosa::NanPropagationMode>(parser,
result);
535 printWithNanPropagationHandling(parser, *
this);
538ParseResult MatmulTBlockScaledOp::parse(
OpAsmParser &parser,
540 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
544 printWithEnumHandling(parser, *
this);
547ParseResult CastFromBlockScaledOp::parse(
OpAsmParser &parser,
549 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
552void CastFromBlockScaledOp::print(
OpAsmPrinter &parser) {
553 printWithEnumHandling(parser, *
this);
556ParseResult CastToBlockScaledOp::parse(
OpAsmParser &parser,
558 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
562 printWithEnumHandling(parser, *
this);
565ParseResult Conv2DBlockScaledOp::parse(
OpAsmParser &parser,
567 return parseWithEnumHandling<tosa::BlockSize>(parser,
result);
571 printWithEnumHandling(parser, *
this);
586 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
596 Value valZp, StringRef name) {
601 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
605 if (!bothInts || !sameBitWidth) {
607 <<
"expected " << name <<
" and " << name
608 <<
"_zp to both be integer of the same bitwidth, but got " << eType
609 <<
" vs. " << eZpType;
616 Value src, int32_t val) {
619 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
620 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
621 const auto padConstAttr{
622 llvm::isa<FloatType>(srcElemType)
627 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
631 if (dyn_cast<tosa::mxint8Type>(type))
640 const StringRef operandName,
641 const StringRef dimName) {
642 if (ShapedType::isDynamic(currDim)) {
645 }
else if (ShapedType::isStatic(newDim) && currDim != newDim) {
647 << dimName <<
" of " << operandName <<
" to match size " << currDim
648 <<
", got " << newDim;
655 auto printDim = [&](
int64_t dim) {
656 if (ShapedType::isDynamic(dim))
662 llvm::interleaveComma(
shape,
diag, printDim);
668 StringRef outputName =
"output") {
669 assert(outputType.hasRank() &&
"expected output type to be ranked");
675 diag << outputName <<
" shape ";
677 diag <<
" to be compatible with inferred shape ";
685 const int64_t stride,
const int64_t dilation,
const llvm::StringRef dimName,
686 const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
687 const llvm::StringRef padAfterName) {
688 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
693 const std::optional<int64_t> calculatedOutSizeMinusOne =
idivCheck(
694 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
696 if (!calculatedOutSizeMinusOne.has_value())
698 << dimName <<
" - 1 + pad_" << padBeforeName <<
" + pad_"
699 << padAfterName <<
" - (kernel_" << dimName <<
" - 1) * dilation_"
700 << dimAxis <<
" to be wholly divisible by stride_" << dimAxis
701 <<
", got (" << inputSize <<
" - 1 + " << padBefore <<
" + "
702 << padAfter <<
" - (" << kernelSize <<
" - 1) * " << dilation
705 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
706 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
708 << dimName <<
" did not match expected: "
709 <<
"calculated=" << calculatedOutSize <<
", expected=" << outputSize;
717size_t mlir::tosa::mxint8Type::getDenseElementBitSize()
const {
return 8; }
720mlir::tosa::mxint8Type::convertToAttribute(
ArrayRef<char> rawData)
const {
721 assert(rawData.size() == 1 &&
"expected 1 byte for tosa.mxint8 element");
722 const auto intType = IntegerType::get(
getContext(), 8);
723 return intType.convertToAttribute(rawData);
726LogicalResult mlir::tosa::mxint8Type::convertFromAttribute(
728 const auto intAttr = dyn_cast<IntegerAttr>(attr);
731 const Type attrType = intAttr.getType();
734 return cast<IntegerType>(attrType).convertFromAttribute(attr,
result);
743 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
744 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
746 auto inputEType = inputType.getElementType();
747 auto weightEType = weightType.getElementType();
749 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
751 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
752 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
753 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
755 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
758 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
761 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
764 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
767 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
771 "expect both bias and result to have same element type, got ")
772 << biasEType <<
" and " << resultEType;
776 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
777 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
778 if (inputEType != weightEType) {
780 "expect both input and weight to have same element type, got ")
781 << inputEType <<
" and " << weightEType;
786 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
787 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
790 if (inputIsFloat != weightIsFloat) {
792 "expect both input and weight to be float or not together, got ")
793 << inputEType <<
" and " << weightEType;
798 if (inputEType != inputZpEType) {
799 return op.emitOpError(
"expect both input and its zero point are the same "
800 "element type, got ")
801 << inputEType <<
" and " << inputZpEType;
805 if (weightEType != weightZpEType) {
806 return op.emitOpError(
"expect both weight and its zero point are the same "
807 "element type, got ")
808 << weightEType <<
" and " << weightZpEType;
811 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
812 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
815 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
816 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
822LogicalResult tosa::ConstOp::verify() {
824 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().
getType());
825 auto outputType = llvm::dyn_cast<TensorType>(getOutput().
getType());
827 if (!attrType || !outputType) {
828 emitOpError(
"expected tensors for attr/result type");
832 if (
auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
833 outputType.getElementType())) {
838 if (attrType.getElementType() != outputType.getElementType()) {
839 emitOpError(
"expected same attr/result element types");
849 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
851 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
854 auto accType = op.getAccType();
855 if (inputEType.isInteger(8) && !accType.isInteger(32))
856 return op.emitOpError(
"accumulator type for i8 tensor is not i32, got ")
859 if (inputEType.isInteger(16) && !accType.isInteger(48))
860 return op.emitOpError(
"accumulator type for i16 tensor is not i48, got ")
863 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
864 !(accType.isF16() || accType.isF32()))
865 return op.emitOpError(
"accumulator type for f8 tensor is not f16/f32, got ")
868 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
869 return op.emitOpError(
870 "accumulator type for f16 tensor is not f16/f32, got ")
873 if (inputEType.isBF16() && !accType.isF32())
874 return op.emitOpError(
"accumulator type for bf16 tensor is not f32, got ")
877 if (inputEType.isF32() && !accType.isF32())
878 return op.emitOpError(
"accumulator type for f32 tensor is not f32, got ")
882 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
884 if (
auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
898 if (llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
899 return op.emitOpError(
"expect all padding values to be >= 0, got ")
903 if (llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
904 return op.emitOpError(
"expect all stride values to be >= 1, got ")
908 if (llvm::any_of(dilations, [](
int64_t d) {
return d < 1; }))
909 return op.emitOpError(
"expect all dilation values to be >= 1, got ")
912 const RankedTensorType outputType =
913 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
918 const RankedTensorType inputType =
919 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
920 const RankedTensorType weightType =
921 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
923 if (inputType && weightType) {
925 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
927 op, inputType.getDimSize(1), weightType.getDimSize(1),
928 outputType.getDimSize(1), padding[0], padding[1], strides[0],
929 dilations[0],
"height",
"y",
"top",
"bottom")))
933 op, inputType.getDimSize(2), weightType.getDimSize(2),
934 outputType.getDimSize(2), padding[2], padding[3], strides[1],
935 dilations[1],
"width",
"x",
"left",
"right")))
940 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
942 op, inputType.getDimSize(1), weightType.getDimSize(0),
943 outputType.getDimSize(1), padding[0], padding[1], strides[0],
944 dilations[0],
"height",
"y",
"top",
"bottom")))
948 op, inputType.getDimSize(2), weightType.getDimSize(1),
949 outputType.getDimSize(2), padding[2], padding[3], strides[1],
950 dilations[1],
"width",
"x",
"left",
"right")))
955 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
957 op, inputType.getDimSize(1), weightType.getDimSize(1),
958 outputType.getDimSize(1), padding[0], padding[1], strides[0],
959 dilations[0],
"depth",
"d",
"front",
"back")))
963 op, inputType.getDimSize(2), weightType.getDimSize(2),
964 outputType.getDimSize(2), padding[2], padding[3], strides[1],
965 dilations[1],
"height",
"y",
"top",
"bottom")))
969 op, inputType.getDimSize(3), weightType.getDimSize(3),
970 outputType.getDimSize(3), padding[4], padding[5], strides[2],
971 dilations[2],
"width",
"x",
"left",
"right")))
976 const RankedTensorType biasType =
977 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
982 const int64_t biasChannels = biasType.getDimSize(0);
984 outputType.getDimSize(outputType.getRank() - 1);
985 if (biasChannels == ShapedType::kDynamic ||
986 outputChannels == ShapedType::kDynamic)
990 if (biasChannels != outputChannels && biasChannels != 1)
991 return op.emitOpError(
992 "bias channels expected to be equal to output channels (")
993 << outputChannels <<
") or 1, got " << biasChannels;
1000 StringRef name1,
Type type2,
1002 auto shapeType1 = dyn_cast<ShapedType>(type1);
1003 auto shapeType2 = dyn_cast<ShapedType>(type2);
1004 if (!shapeType1 || !shapeType2)
1007 auto elemType1 = shapeType1.getElementType();
1008 auto elemType2 = shapeType2.getElementType();
1009 if (elemType1 != elemType2)
1011 <<
"require same element type for " << name1 <<
" (" << elemType1
1012 <<
") and " << name2 <<
" (" << elemType2 <<
")";
1016 <<
"require same shapes for " << name1 <<
" (" << type1 <<
") and "
1017 << name2 <<
" (" << type2 <<
")";
1027 if (list1.size() != list2.size())
1029 <<
"require same number of values in " << name1 <<
" ("
1030 << list1.size() <<
") and " << name2 <<
" (" << list2.size() <<
")";
1032 for (
auto [type1, type2] :
1049template <
typename T>
1052 op->template getParentWithTrait<OpTrait::SymbolTable>();
1059 const auto varOp = symTable.
lookup<tosa::VariableOp>(op.getName());
1063 return op->emitOpError(
"'")
1064 << op.getName() <<
"' has not been declared by 'tosa.variable'";
1076template <
typename T>
1078 StringRef aName =
"input",
1079 StringRef bName =
"output") {
1080 auto aTType = llvm::dyn_cast<TensorType>(aType);
1081 auto bTType = llvm::dyn_cast<TensorType>(bType);
1083 op.emitOpError(
"expect shaped tensor for") << aName <<
", got " << aType;
1087 op.emitOpError(
"expect shaped tensor for") << bName <<
", got" << bType;
1090 auto aElementType = aTType.getElementType();
1091 auto bElementType = bTType.getElementType();
1093 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1095 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1096 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1097 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1098 aElementType != bElementType) {
1103 op.emitOpError(
"expect ")
1104 << aName <<
" and " << bName <<
" to have same element type, got "
1105 << aElementType <<
" and " << bElementType;
1111LogicalResult tosa::ArgMaxOp::verify() {
1112 const ShapedType resultType = llvm::cast<ShapedType>(
getType());
1115 if (
const auto resultETy = resultType.getElementType();
1116 !resultETy.isIntOrIndex())
1117 return emitOpError(
"result tensor is not of integer type");
1119 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
1120 if (!inputType.hasRank())
1124 const int64_t axis = getAxisAttr().getInt();
1125 if (((axis < 0) || axis >= inputType.getRank()))
1126 return emitOpError(
"specified axis is outside the rank of the tensor");
1128 if (!resultType.hasRank())
1134 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1137 << expectedOutputShape <<
"', got '" << outputShape <<
"'";
1147 const bool hasKernel = kernel.size() > 0;
1148 const bool hasStrides = strides.size() > 0;
1149 const bool hasPad = padding.size() > 0;
1151 if (hasKernel && llvm::any_of(kernel, [](
int64_t s) {
return s < 1; }))
1152 return op->
emitOpError(
"expect all kernel values to be >= 1, got ")
1155 if (hasStrides && llvm::any_of(strides, [](
int64_t s) {
return s < 1; }))
1156 return op->
emitOpError(
"expect all stride values to be >= 1, got ")
1159 if (hasPad && llvm::any_of(padding, [](
int64_t p) {
return p < 0; }))
1160 return op->
emitOpError(
"expect all padding values to be >= 0, got ")
1163 if (hasKernel && hasPad) {
1165 const int64_t kernelX = kernel[1];
1166 const int64_t padLeft = padding[2];
1167 const int64_t padRight = padding[3];
1168 if (padRight >= kernelX || padLeft >= kernelX)
1169 return op->
emitOpError(
"expected left/right padding to be less than the "
1170 "width of the kernel, got pad_left=")
1171 << padLeft <<
", pad_right=" << padRight
1172 <<
", kernel_x=" << kernelX;
1174 const int64_t kernelY = kernel[0];
1175 const int64_t padTop = padding[0];
1176 const int64_t padBottom = padding[1];
1177 if (padTop >= kernelY || padBottom >= kernelY)
1178 return op->
emitOpError(
"expected top/bottom padding to be less than the "
1179 "height of the kernel, got pad_top=")
1180 << padTop <<
", pad_bottom=" << padBottom
1181 <<
", kernel_y=" << kernelY;
1184 const auto inputType = llvm::dyn_cast<RankedTensorType>(input.
getType());
1185 const auto outputType = llvm::dyn_cast<RankedTensorType>(output.
getType());
1186 if (!inputType || !outputType)
1189 if (hasKernel && hasStrides && hasPad) {
1190 const auto verifyOutputSize =
1194 const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
1195 const llvm::StringRef padBeforeName,
1196 const llvm::StringRef padAfterName) -> LogicalResult {
1197 if (ShapedType::isDynamic(inputSize))
1200 const std::optional<int64_t> calculatedOutSizeMinusOne =
1201 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1202 if (!calculatedOutSizeMinusOne.has_value())
1204 << dimName <<
" + pad_" << padBeforeName <<
" + pad_"
1205 << padAfterName <<
" - kernel_" << dimAxis
1206 <<
" to be wholly divisible by stride_" << dimAxis <<
", got ("
1207 << inputSize <<
" + " << padBefore <<
" + " << padAfter <<
" - "
1208 << kernelSize <<
") / " << strideSize;
1210 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1211 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1213 << dimName <<
" did not match expected: " <<
"calculated="
1214 << calculatedOutSize <<
", expected=" << outputSize;
1219 if (failed(verifyOutputSize(inputType.getDimSize(1),
1220 outputType.getDimSize(1), kernel[0], strides[0],
1221 padding[0], padding[1],
"height",
"y",
"top",
1225 if (failed(verifyOutputSize(
1226 inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
1227 strides[1], padding[2], padding[3],
"width",
"x",
"left",
"right")))
1233template <
typename T>
1236 op.getPad(), op.getInput(), op.getOutput());
1239template <
typename T>
1243 const Type inputZpETy =
1245 const Type outputZpETy =
1248 auto accType = op.getAccType();
1249 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1250 return op.emitOpError(
"accumulator type for integer tensor is not i32");
1252 if (inputETy.
isF16() && !(accType.isF16() || accType.isF32()))
1253 return op.emitOpError(
"accumulator type for f16 tensor is not f16/f32");
1255 if (inputETy.
isBF16() && !accType.isF32())
1256 return op.emitOpError(
"accumulator type for bf16 tensor is not f32");
1258 if (inputETy.
isF32() && !accType.isF32())
1259 return op.emitOpError(
"accumulator type for f32 tensor is not f32");
1261 if (inputETy != inputZpETy)
1262 return op.emitOpError(
"expect both input and its zero point are the same "
1263 "element type, got ")
1264 << inputETy <<
" and " << inputZpETy;
1266 if (resultETy != outputZpETy)
1267 return op.emitOpError(
"expect both output and its zero point are the same "
1268 "element type, got ")
1269 << resultETy <<
" and " << outputZpETy;
1271 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1272 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
1275 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1276 if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
1283struct AdaptivePoolingConstShapeValues {
1284 llvm::SmallVector<int64_t> kernel;
1285 llvm::SmallVector<int64_t> stride;
1286 llvm::SmallVector<int64_t> pad;
1290template <
typename T>
1292 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
1293 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
1295template <
typename T,
1296 typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
1299 T op, AdaptivePoolingConstShapeValues &values) {
1305LogicalResult tosa::AvgPool2dOp::verify() {
1313LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
1314 AdaptivePoolingConstShapeValues values;
1323 values.pad, getInput(), getOutput())))
1332LogicalResult tosa::ClampOp::verify() {
1334 llvm::cast<ShapedType>(getInput().
getType()).getElementType();
1335 if (
auto quantType =
1336 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1340 llvm::cast<ShapedType>(getOutput().
getType()).getElementType();
1341 if (
auto quantType =
1342 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1345 if (inputETy != outputETy)
1346 return emitOpError(
"input/output element types are incompatible.");
1348 auto maxValAttr = getMaxValAttr();
1349 auto minValAttr = getMinValAttr();
1353 if (inputETy.
isInteger(dataTypeBitWidth)) {
1357 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1358 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1359 if (!intMaxValAttr || !intMinValAttr ||
1360 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1361 (intMaxValAttr.getType() != inputETy))
1362 return emitOpError(
"min/max attributes types are incompatible with "
1363 "input/output element types.");
1366 const bool isBoolean = inputETy.
isInteger(1);
1367 const APInt minVal = intMinValAttr.getValue();
1368 const APInt maxVal = intMaxValAttr.getValue();
1369 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1370 return emitOpError(
"expected min_val <= max_val, got min_val=")
1371 << minValAttr <<
", max_val=" << maxValAttr;
1376 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1377 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1378 if (!floatMaxValAttr || !floatMinValAttr ||
1379 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1380 (floatMaxValAttr.getType() != inputETy))
1381 return emitOpError(
"min/max attributes types are incompatible with "
1382 "input/output element types.");
1384 const APFloat minVal = floatMinValAttr.getValue();
1385 const APFloat maxVal = floatMaxValAttr.getValue();
1386 if (minVal.isNaN() || maxVal.isNaN())
1387 return emitOpError(
"min/max attributes should not be 'NaN', got min_val=")
1388 << minValAttr <<
", max_val=" << maxValAttr;
1390 if (maxVal < minVal)
1391 return emitOpError(
"expected min_val <= max_val, got min_val=")
1392 << minValAttr <<
", max_val=" << maxValAttr;
1412 result.addOperands({input, weight, bias, zps.first, zps.second});
1413 result.addAttribute(
"pad", pad);
1414 result.addAttribute(
"stride", stride);
1415 result.addAttribute(
"dilation", dilation);
1416 result.addAttribute(
"acc_type", accType);
1417 Type finalOutputType = outputType;
1423 result.addTypes(finalOutputType);
1434 result.addOperands({input, weight, bias, zps.first, zps.second});
1435 result.addAttribute(
"out_pad", outpad);
1436 result.addAttribute(
"stride", stride);
1437 result.addAttribute(
"acc_type", accType);
1438 Type finalOutputType = outputType;
1444 result.addTypes(finalOutputType);
1455 result.addOperands({a,
b, zps.first, zps.second});
1457 Type finalOutputType{outputType};
1460 auto inputBits = eType.getIntOrFloatBitWidth();
1462 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1463 assert(outputShapedType &&
"Output must be a shaped type");
1465 IntegerType accElementType;
1466 if (inputBits == 16)
1471 finalOutputType = outputShapedType.clone(accElementType);
1473 result.addTypes(finalOutputType);
1482 DenseArrayAttr kernel, DenseArrayAttr stride,
1483 DenseArrayAttr pad, TypeAttr accType) {
1488 if (
auto quantAttr =
1490 inputZp = quantAttr.getInputZp();
1491 outputZp = quantAttr.getOutputZp();
1493 const std::optional<Value> inputZpOp =
1498 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1500 const std::optional<Value> outputZpOp =
1503 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1504 "quantized AVG_POOL2D op");
1507 if (inputZpOp && outputZpOp) {
1508 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1513 result.addOperands({input});
1515 result.addAttribute(
"kernel", kernel);
1516 result.addAttribute(
"stride", stride);
1517 result.addAttribute(
"pad", pad);
1518 result.addAttribute(
"acc_type", accType);
1519 result.types.push_back(outputType);
1532 if (
auto quantAttr =
1534 inputZp = quantAttr.getInputZp();
1535 outputZp = quantAttr.getOutputZp();
1537 const std::optional<Value> inputZpOp =
1541 "Failed to create input zero point tensor for quantized "
1542 "AVG_POOL2D_ADAPTIVE op");
1544 const std::optional<Value> outputZpOp =
1547 (
void)
emitError(loc,
"Failed to create output zero point tensor for "
1548 "quantized AVG_POOL2D_ADAPTIVE op");
1551 if (inputZpOp && outputZpOp) {
1556 result.addOperands({input, inputZpOp.value(), outputZpOp.value(),
1557 kernelShape, strideShape, padShape});
1562 result.addOperands({input});
1564 result.addAttribute(
"acc_type", accType);
1565 result.types.push_back(outputType);
1579 input1Zp = quantAttr.getInputZp();
1580 outputZp = quantAttr.getOutputZp();
1582 const std::optional<Value> input1ZpOp =
1586 loc,
"Failed to create input1 zero point for quantized NEGATE op");
1589 const std::optional<Value> outputZpOp =
1593 loc,
"Failed to create output zero point for quantized NEGATE op");
1596 if (input1ZpOp && outputZpOp) {
1597 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1602 result.addOperands({input});
1605 result.types.push_back(outputType);
1618 zp =
static_cast<int32_t
>(quantAttr.getInputZp());
1621 result.addOperands({input, paddings, padConstOp});
1622 result.types.push_back(outputType);
1626 StringRef name,
Type variableType,
1631 auto shapedType = dyn_cast<ShapedType>(variableType);
1633 (
void)
emitError(loc,
"variable type must be a shaped type");
1636 if (!shapedType.hasRank()) {
1637 (
void)
emitError(loc,
"variable type must be a ranked type");
1641 auto elementType = shapedType.getElementType();
1642 auto elementTypeAttr = TypeAttr::get(elementType);
1646 result.addAttribute(
"sym_name", nameAttr);
1647 result.addAttribute(
"var_shape", varShapeAttr);
1648 result.addAttribute(
"type", elementTypeAttr);
1649 result.addAttribute(
"initial_value", initialValue);
1662 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1666 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1672 for (
int i = 0, e = operands.size(); i != e; ++i) {
1674 if (!
shape.hasRank()) {
1679 outRank = std::max<int64_t>(outRank,
shape.getRank());
1682 outShape.resize(outRank, 1);
1684 for (
int i = 0, e = operands.size(); i != e; ++i) {
1686 auto rankDiff = outShape.size() -
shape.getRank();
1688 for (
size_t i = 0, e =
shape.getRank(); i < e; ++i) {
1689 auto dim1 = outShape[i + rankDiff];
1690 auto dim2 =
shape.getDimSize(i);
1692 const FailureOr<int64_t> maybeResolvedDim =
1694 if (failed(maybeResolvedDim))
1696 const int64_t resolvedDim = *maybeResolvedDim;
1697 outShape[i + rankDiff] = resolvedDim;
// Shape inference for tosa.argmax.
// From the visible statements: the `axis` property is read as a signed
// integer, an unranked input is handled by an early branch, and otherwise the
// output shape is assembled from the input's dimension sizes (capacity is
// reserved for rank - 1 entries).
// NOTE(review): several statements of this function are not visible in this
// listing (e.g. how `axisVal` is used inside the loop) - confirm against the
// complete source before relying on this summary.
1704LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1705 MLIRContext *context, ::std::optional<Location> location,
1706 ArgMaxOp::Adaptor adaptor,
1709 IntegerAttr axis = adaptor.getProperties().axis;
1710 int32_t axisVal = axis.getValue().getSExtValue();
1712 if (!inputShape.hasRank()) {
// One fewer dimension than the input is expected in the result.
1718 outShape.reserve(inputShape.getRank() - 1);
1719 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
1722 outShape.push_back(inputShape.getDimSize(i));
// Shape inference for tosa.rfft2d, driven by the shape of `input_real`.
// The visible statements build a three-entry output shape: entries 0 and 1
// are copied from the input, entry 2 is derived from the input's dimension 2.
// NOTE(review): the unranked-input branch body and the code that publishes
// `outputShape` are not visible in this listing.
1729LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1730 MLIRContext *context, ::std::optional<Location> location,
1731 RFFT2dOp::Adaptor adaptor,
1733 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1735 if (!inputShape.hasRank())
// Start from three dynamic entries, then fill in what is known.
1739 outputShape.resize(3, ShapedType::kDynamic);
1740 outputShape[0] = inputShape.getDimSize(0);
1741 outputShape[1] = inputShape.getDimSize(1);
1742 int64_t inWidth = inputShape.getDimSize(2);
// Only a static input width yields a static output width:
// inWidth / 2 + 1, using integer division.
1746 if (inWidth != ShapedType::kDynamic)
1747 outputShape[2] = inWidth / 2 + 1;
1756 const llvm::StringRef dimName) {
1757 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1760 << dimName <<
" to be a power of two, got " << dimSize;
// Verifier for tosa.rfft2d.
// Visible checks, in order:
//   - the result types must agree with each other ("expected output shapes
//     to match");
//   - dimensions 1 (height) and 2 (width) of `input_real` are examined when
//     static (the remainder of those conditions is not visible here);
//   - all but the last dimension of the first result are compared against the
//     input ("batch and height dimensions ... to match");
//   - when both widths are static, output dim 2 must equal width / 2 + 1.
// NOTE(review): the conditions guarding several of these diagnostics are
// missing from this listing - confirm against the complete source.
1765LogicalResult tosa::RFFT2dOp::verify() {
1766 const auto outputTypes = getResultTypes();
1768 return emitOpError(
"expected output shapes to match, got ") << outputTypes;
// dyn_cast yields a null type when input_real is not a ranked tensor.
1770 const auto inputType =
1771 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1775 const int64_t height = inputType.getDimSize(1);
1776 if (ShapedType::isStatic(height) &&
1780 const int64_t width = inputType.getDimSize(2);
1781 if (ShapedType::isStatic(width) &&
1785 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1791 outputType.getShape().drop_back())))
1792 return emitOpError(
"expected batch and height dimensions of input/output "
1793 "to match, got input=")
1794 << inputType <<
" output=" << outputType;
// The output width is only checked when both widths are static.
1797 const int64_t outputWidth = outputType.getDimSize(2);
1798 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1799 (outputWidth != (width / 2) + 1))
1801 "expected output width to be equal to input_width / 2 + 1, got ")
// Shape inference for tosa.fft2d.
// Two entries are appended to `inferredReturnShapes`; the arguments of both
// push_back calls are not visible in this listing, so the inferred shapes
// themselves cannot be described from here.
1807LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1808 MLIRContext *context, ::std::optional<Location> location,
1809 FFT2dOp::Adaptor adaptor,
1811 inferredReturnShapes.push_back(
1813 inferredReturnShapes.push_back(
// Verifier for tosa.fft2d.
// Both `input_real` and `input_imag` are cast to RankedTensorType; if either
// cast fails an early branch is taken (its body is not visible here). For
// dimensions 1 (height) and 2 (width) one value is picked from the pair of
// inputs and, when static, checked further (the rest of those conditions is
// not visible in this listing).
1818LogicalResult tosa::FFT2dOp::verify() {
1819 const auto inputRealType =
1820 llvm::dyn_cast<RankedTensorType>(getInputReal().
getType());
1821 const auto inputImagType =
1822 llvm::dyn_cast<RankedTensorType>(getInputImag().
getType());
1823 if (!inputRealType || !inputImagType)
// NOTE(review): as written this lambda returns `a` when `a` is dynamic and
// `b` otherwise, so a static `b` is never preferred over a dynamic `a`, and
// a static `a` is discarded in favour of `b`. Given the name
// "trySelectStaticDim" the two ternary arms look swapped - confirm intent.
1826 const auto trySelectStaticDim = [](
const int64_t a,
const int64_t b) {
1827 return ShapedType::isDynamic(a) ? a :
b;
1830 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1831 inputImagType.getDimSize(1));
1832 if (ShapedType::isStatic(height) &&
1836 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1837 inputImagType.getDimSize(2));
1838 if (ShapedType::isStatic(width) &&
// Shape inference for tosa.concat.
// Visible structure:
//   1. A first pass over the operands sizes `outputShape` from the first
//      ranked operand (all entries dynamic), then fills each non-axis entry
//      from static operand dimensions and reports a conflict when two
//      operands disagree on a non-axis dimension.
//   2. The element type is taken from the first entry of `input1`.
//   3. A second pass accumulates the size along `axis`; an unranked operand
//      or a dynamic axis dimension makes the concatenated size dynamic.
// NOTE(review): the bodies of several branches (continue/return statements,
// the empty-input case, the no-ranked-input case) are not visible here.
1845LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1846 MLIRContext *context, ::std::optional<Location> location,
1847 ConcatOp::Adaptor adaptor,
1850 const Properties &prop = adaptor.getProperties();
1851 int32_t axis = prop.axis.getValue().getSExtValue();
1853 bool hasRankedInput =
false;
1854 for (
auto operand : adaptor.getOperands()) {
1856 if (!operandShape.hasRank())
// The first ranked operand determines the output rank.
1860 if (!hasRankedInput)
1861 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1864 for (
int i = 0, s = operandShape.getRank(); i < s; i++) {
// The concat axis and dynamic operand dimensions carry no constraint here.
1865 if (i == axis || operandShape.isDynamicDim(i))
1867 if (outputShape[i] == ShapedType::kDynamic)
1868 outputShape[i] = operandShape.getDimSize(i);
1869 if (outputShape[i] != operandShape.getDimSize(i))
1871 "Cannot concat tensors with different sizes"
1872 " on the non-axis dimension ",
1876 hasRankedInput =
true;
1879 if (adaptor.getInput1().empty())
1883 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1884 if (!hasRankedInput) {
1891 for (
auto operand : adaptor.getOperands()) {
// Any operand without a static size along `axis` makes the total dynamic.
1896 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1897 concatDimSize = ShapedType::kDynamic;
1901 concatDimSize += operandShape.getDimSize(axis);
1904 outputShape[axis] = concatDimSize;
// Verifier for tosa.concat.
// Visible checks, in order:
//   - an empty input list is handled by an early branch;
//   - every input type is checked against the output type via a helper whose
//     name is not visible in this listing;
//   - `axis` must be non-negative and smaller than the rank of the first
//     ranked input;
//   - when every operand is ranked: all ranks must match, static non-axis
//     dimensions must match those of the first input, the output rank (if
//     any) must match, and the sum of static axis dimensions must equal a
//     static output axis dimension.
// NOTE(review): many branch bodies and declarations (e.g. `inputList`,
// `firstInputRank`, `axisSum`) are not visible in this listing.
1910LogicalResult tosa::ConcatOp::verify() {
1912 auto outType = getOutput().getType();
1916 if (inputList.empty())
1919 if (!llvm::all_of(inputList, [&](
auto input) {
1921 *
this, input.getType(), outType));
1926 const int32_t axis = getAxis();
1928 for (
const auto &input : inputList) {
1929 const Type inputType = input.getType();
1931 if (currShape.hasRank()) {
1932 firstRankedInputShape = currShape;
// NOTE(review): the condition accepts axis == 0, but the message below reads
// "0 < axis"; the text presumably should say "0 <= axis" - confirm.
1934 if (axis < 0 || axis >= firstRankedInputShape.
getRank())
1935 return emitOpError(
"expect axis to be within range 0 < axis < "
1936 "rank(input1[firstRankedTensorIdx]), got ")
1942 const auto allOperandsHasRank = [](
const Value input) {
// The pairwise rank/shape comparison only runs when every operand is ranked.
1945 if (llvm::all_of(inputList, allOperandsHasRank)) {
// Operand 0 is the reference, so enumeration starts at the second operand
// and `operandNum` is the enumerate index shifted by one.
1948 for (
const auto &[
index, input] : llvm::enumerate(inputList.drop_front())) {
1950 const int64_t inputRank = inputShape.getRank();
1951 const size_t operandNum =
index + 1;
1954 if (inputRank != firstInputRank)
1956 "expect all operands to have the same rank, but got ")
1957 << firstInputRank <<
" vs " << inputRank <<
" on operands 0 and "
1961 for (
int i = 0; i < inputRank; i++) {
1962 const int64_t inputDim = inputShape.getDimSize(i);
// The axis dimension and any dynamic dimension are exempt from the match.
1964 if (i == axis || firstRankedInputShape.
isDynamicDim(i) ||
1965 inputShape.isDynamicDim(i))
1967 if (inputDim != firstInputDim)
1968 return emitOpError(
"expect all operand shapes to have the same sizes "
1969 "on non-axis dimensions, but got ")
1970 << inputDim <<
" vs " << firstInputDim <<
" at index " << i
1971 <<
" on operands 0 and " << operandNum;
1976 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1977 return emitOpError(
"expect output rank to match inputs rank, got ")
1978 << outputShape.getRank() <<
" vs " << firstInputRank;
1982 for (
const auto &input : inputList) {
1984 if (inputShape.isDynamicDim(axis)) {
1989 axisSum += inputShape.getDimSize(axis);
// A negative `axisSum` skips this check.
// NOTE(review): presumably a negative value marks "some axis dimension was
// dynamic"; the statement that sets it is not visible here.
1992 if (axisSum >= 0 && outputShape.hasRank() &&
1993 !outputShape.isDynamicDim(axis) &&
1994 axisSum != outputShape.getDimSize(axis))
1995 return emitOpError(
"requires sum of axis dimensions of input1 "
1996 "equal to output axis dimension, got ")
1997 << axisSum <<
" and " << outputShape.getDimSize(axis);
// Shape inference for tosa.equal.
// The only visible statement creates a 1-bit integer type for the result
// element type; how the result shape is derived is not visible in this
// listing.
2003LogicalResult tosa::EqualOp::inferReturnTypeComponents(
2004 MLIRContext *context, ::std::optional<Location> location,
2008 auto elementType = IntegerType::get(context, 1);
2021 if (l.size() != r.size() || l.size() != 1)
// Shape inference for tosa.matmul.
// The output shape has three entries, all dynamic to begin with. A ranked
// lhs supplies entries 0 and 1; a ranked rhs supplies entry 2 and, if entry 0
// is still dynamic, entry 0 as well.
// NOTE(review): the declarations of `lhsShape`/`rhsShape` and the else-arm of
// the ternary on entry 0 are not visible in this listing.
2026LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
2027 MLIRContext *context, ::std::optional<Location> location,
2028 MatMulOp::Adaptor adaptor,
2035 outShape.resize(3, ShapedType::kDynamic);
2037 if (lhsShape.hasRank()) {
2038 outShape[0] = lhsShape.getDimSize(0);
2039 outShape[1] = lhsShape.getDimSize(1);
2042 if (rhsShape.hasRank()) {
// Fall back to the rhs for entry 0 only when the lhs did not provide it.
2043 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
2045 outShape[2] = rhsShape.getDimSize(2);
// Verifier for tosa.matmul.
// Visible checks, in order:
//   - operands `a` and `b` must be both quantized or both not, and quantized
//     operands must share the same storage width;
//   - the element type of each operand must equal that of its zero-point
//     operand (a vs a_zp, b vs b_zp);
//   - when a zero point can be obtained, it is validated through
//     verifyAZeroPoint / verifyBZeroPoint;
//   - N, H, C are read from a ranked `a`, W from a ranked `b`, and a ranked
//     result is compared against the expected output shape.
// NOTE(review): declarations such as `aShape`, `bShape`, `aEType`, `C`, `W`
// and the final shape comparison are not visible in this listing.
2052LogicalResult MatMulOp::verify() {
2055 const Type aElementType = aShape.getElementType();
2056 const Type bElementType = bShape.getElementType();
// dyn_cast yields a null type for non-quantized element types.
2058 const auto aQuantizedEType =
2059 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
2060 const auto bQuantizedEType =
2061 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
2063 if (aQuantizedEType || bQuantizedEType) {
// Mixed quantized / non-quantized operands are rejected.
2064 if (!aQuantizedEType || !bQuantizedEType) {
2065 return emitOpError(
"expect operands to be both quantized or both not "
2067 << aElementType <<
" and " << bElementType;
2070 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
2071 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
2072 if (aQuantWidth != bQuantWidth) {
2073 return emitOpError(
"expect quantized operands to have same widths, got ")
2074 << aQuantWidth <<
" and " << bQuantWidth;
2081 if (aEType != aZpEType)
2082 return emitOpError(
"expect input a and a_zp have the same "
2083 "element type, got ")
2084 << aEType <<
" and " << aZpEType;
2088 if (bEType != bZpEType)
2089 return emitOpError(
"expect input b and b_zp have the same "
2090 "element type, got ")
2091 << bEType <<
" and " << bZpEType;
// A zero point that cannot be obtained (failure) is not validated.
2093 FailureOr<int64_t> maybeAZp = getAZeroPoint();
2094 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
2097 FailureOr<int64_t> maybeBZp = getBZeroPoint();
2098 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
// Expected dimensions start out dynamic and are refined from ranked operands.
2102 int64_t N = ShapedType::kDynamic;
2103 int64_t H = ShapedType::kDynamic;
2107 if (aShape.hasRank()) {
2108 N = aShape.getDimSize(0);
2109 H = aShape.getDimSize(1);
2110 C = aShape.getDimSize(2);
2113 if (bShape.hasRank()) {
2119 W = bShape.getDimSize(2);
2123 const auto outputType = cast<ShapedType>(getResult().
getType());
2124 if (outputType.hasRank() &&
2129 opError <<
" to be compatible with expected output shape ";
// Shape inference for tosa.matmul_t_block_scaled.
// Each of the four operands (a_data, a_scale, b_data, b_scale) contributes to
// `outShape` when it is ranked:
//   - a_data sets entries 0 and 1;
//   - a_scale fills entries 0 and 1 only where they are still dynamic;
//   - b_data / b_scale: a batch size (dim 0) different from 1 may fill
//     entry 0 when it is still dynamic, and dim 1 provides entry 2.
// NOTE(review): the declaration of `outShape` and the else-arms of the
// ternaries are not visible in this listing.
2137LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
2138 MLIRContext *context, ::std::optional<Location> location,
2139 MatmulTBlockScaledOp::Adaptor adaptor,
2143 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
2144 if (aDataShape.hasRank()) {
2145 outShape[0] = aDataShape.getDimSize(0);
2146 outShape[1] = aDataShape.getDimSize(1);
2149 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
2150 if (aScaleShape.hasRank()) {
2151 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
2153 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
2158 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
2159 if (bDataShape.hasRank()) {
2160 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
// A batch size of 1 does not constrain the output batch.
// NOTE(review): presumably because a batch of 1 is broadcast - confirm.
2161 if (bDataBatchSize != 1)
2163 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
2164 outShape[2] = bDataShape.getDimSize(1);
2167 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
2168 if (bScaleShape.hasRank()) {
2169 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
2170 if (bScaleBatchSize != 1)
2172 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
2173 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
// Verifier for tosa.matmul_t_block_scaled.
// Visible checks, in order:
//   - per-operand dimension checks labelled a_scale (batch, height), b_data
//     (batch, channels) and b_scale (batch, width); the helper being called
//     is not visible in this listing;
//   - static N and D must be equal unless D == 1;
//   - the block size must be the value of BLOCK_SIZE_32;
//   - a static C must be a multiple of the block size, and a static
//     `multiplesOfC` must equal C / blockSize;
//   - a ranked result is compared against the expected output shape.
// NOTE(review): the declarations of `C`, `W` and most condition heads are
// not visible here - confirm against the complete source.
2181LogicalResult MatmulTBlockScaledOp::verify() {
2183 const Type aDataType = getAData().getType();
2184 const Type bDataType = getBData().getType();
// Expected dimensions start out dynamic until an operand pins them down.
2190 int64_t N = ShapedType::kDynamic;
2191 int64_t D = ShapedType::kDynamic;
2192 int64_t H = ShapedType::kDynamic;
2195 int64_t multiplesOfC = ShapedType::kDynamic;
2207 "a_scale",
"batch")) ||
2209 "a_scale",
"height")))
2217 "b_data",
"batch")) ||
2219 "b_data",
"channels")))
2227 "b_scale",
"batch")) ||
2229 "b_scale",
"width")) ||
// D == 1 is accepted regardless of N.
2237 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2238 return emitOpError(
"expect B matrix batch size to be broadcast compatible "
2240 << D <<
" vs N=" << N;
2243 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
2244 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2245 return emitOpError(
"expect block size to be 32, got ") << blockSize;
2246 if (ShapedType::isStatic(C) && C % blockSize != 0)
2247 return emitOpError(
"expect C to be a multiple of block size, got C=")
2248 <<
C <<
", block_size=" << blockSize;
2251 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2252 multiplesOfC != C / blockSize)
2254 "expect scale operands dimension 2 to equal C/block_size (")
2255 <<
C <<
"/" << blockSize <<
")" <<
", got " << multiplesOfC;
// If the A side left N dynamic, fall back to the B-side batch size D.
2258 N = ShapedType::isDynamic(N) ? D : N;
2260 const auto outputType = cast<ShapedType>(getResult().
getType());
2261 if (outputType.hasRank() &&
2266 opError <<
" to be compatible with expected output shape ";
// Shape inference for tosa.pad.
// Visible behaviour:
//   - unranked input: the output gets paddingRank / 2 dynamic entries;
//   - another path (its condition is not visible) produces an all-dynamic
//     shape with the input's rank;
//   - otherwise each output dimension is input dim + padFront + padBack,
//     where padFront/padBack are entries 2*i and 2*i+1 of `paddingValues`.
//     A dynamic input dimension or a negative pad value yields a dynamic
//     output dimension.
// NOTE(review): how `paddingValues` is obtained is not visible in this
// listing.
2274LogicalResult tosa::PadOp::inferReturnTypeComponents(
2275 MLIRContext *context, ::std::optional<Location> location,
2276 PadOp::Adaptor adaptor,
2278 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2280 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2285 if (!inputShape.hasRank()) {
// Two padding entries (front, back) per dimension, hence rank / 2.
2286 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2295 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2300 outputShape.reserve(inputShape.getRank());
2301 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2302 if (inputShape.isDynamicDim(i)) {
2303 outputShape.push_back(ShapedType::kDynamic);
2306 auto padFront = paddingValues[i * 2];
2307 auto padBack = paddingValues[i * 2 + 1];
// Negative pad values make this dimension dynamic rather than being added.
2308 if (padFront < 0 || padBack < 0) {
2310 outputShape.push_back(ShapedType::kDynamic);
2314 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
// Verifier for tosa.pad.
// Visible checks, in order:
//   - a check involving `pad_const` when present (body not visible);
//   - input and output are cast to RankedTensorType; an early branch is taken
//     if either cast fails;
//   - the padding constant must hold exactly inputRank * 2 values;
//   - each pad value must be non-negative or exactly -1;
//   - for dimensions where both input and output sizes are static, the output
//     size must equal input + padStart + padEnd.
// NOTE(review): how `paddingAttr` is obtained, and what happens when padding
// is not constant, is not visible in this listing.
2321LogicalResult tosa::PadOp::verify() {
2328 if (
auto padConst = getPadConst()) {
2336 RankedTensorType inputType =
2337 llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2338 RankedTensorType outputType =
2339 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
2340 if (!inputType || !outputType)
2347 auto inputRank = inputType.getRank();
2352 auto paddingValues = paddingAttr.getValues<APInt>();
2353 if (paddingValues.size() !=
static_cast<size_t>(inputRank * 2))
2354 return emitOpError() <<
"padding tensor must have " << inputRank
2355 <<
" * 2 = " << inputRank * 2 <<
" elements, but got "
2356 << paddingValues.size();
2358 auto inputShape = inputType.getShape();
2359 auto outputShape = outputType.getShape();
2361 for (
int64_t i = 0; i < inputRank; ++i) {
// Entries 2*i and 2*i+1 are the start and end padding of dimension i.
2362 int64_t padStart = paddingValues[i * 2].getSExtValue();
2363 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
// -1 is the only negative value accepted (per the message: dynamic padding).
2365 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2367 <<
"invalid padding values at dimension " << i
2368 <<
": values must be non-negative or -1 for dynamic padding, got ["
2369 << padStart <<
", " << padEnd <<
"]";
2373 if (inputShape[i] == ShapedType::kDynamic ||
2374 outputShape[i] == ShapedType::kDynamic)
2377 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2378 return emitOpError() <<
"mismatch in output shape at dimension " << i
2379 <<
": expected " << inputShape[i] <<
" + "
2380 << padStart <<
" + " << padEnd <<
" = "
2381 << (inputShape[i] + padStart + padEnd)
2382 <<
", but got " << outputShape[i];
// Shape inference for tosa.slice.
// For a ranked input, each dimension i is refined only when size[i] is
// non-zero and >= -1, start[i] is non-negative, and start[i] is within a
// static input dimension (or the input dimension is dynamic). Then:
//   - dynamic input dimension: the visible assignment uses size[i];
//   - size[i] == -1: output is input dim - start[i];
//   - otherwise size[i] is used when start[i] + size[i] fits in the input.
// NOTE(review): how `start`, `size` and `outputShape` are initialised, and
// any conditions between the visible lines, are not shown in this listing.
2389LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2390 MLIRContext *context, ::std::optional<Location> location,
2391 SliceOp::Adaptor adaptor,
2400 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2408 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2411 if (inputShape.hasRank()) {
2412 for (
size_t i = 0; i < size.size(); i++) {
2413 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2414 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2415 start[i] < inputShape.getDimSize(i))) {
2417 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2420 outputShape[i] = size[i];
// -1 selects everything from start[i] to the end of the dimension.
2424 if (size[i] == -1) {
2425 outputShape[i] = inputShape.getDimSize(i) - start[i];
2426 }
else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2428 outputShape[i] = size[i];
// Verifier for tosa.slice.
// Visible checks, in order:
//   - for a ranked input: output rank (if any), the rank of `start` and the
//     rank of `size` must all equal the input rank;
//   - constant start values are checked ("must be non-negative");
//   - constant size values are checked ("must be > 0"), and a ranked output
//     is compared against them when the predicate `hasNoInferableDims` holds;
//   - for a ranked input with constant start values, start + size must not
//     exceed a static input dimension.
// NOTE(review): the lambda bodies and the code that extracts `startValues` /
// `sizeValues` are not visible in this listing.
2440LogicalResult tosa::SliceOp::verify() {
2441 const Value input = getInput1();
2442 const Value output = getOutput();
2448 const Value start = getStart();
2449 const Value size = getSize();
2453 if (inputShape.hasRank()) {
2454 const auto inputRank = inputShape.getRank();
2455 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2457 "expect input1 and output to have the same ranks, got ")
2458 << inputRank <<
" and " << outputShape.getRank();
2460 const auto startShapeRank =
2461 llvm::cast<tosa::shapeType>(start.
getType()).getRank();
2462 if (inputRank != startShapeRank)
2463 return emitOpError(
"length of start is not equal to rank of input shape");
2465 const auto sizeShapeRank =
2466 llvm::cast<tosa::shapeType>(size.
getType()).getRank();
2467 if (inputRank != sizeShapeRank)
2468 return emitOpError(
"length of size is not equal to rank of input shape");
// An empty `startValues` skips the value checks.
// NOTE(review): presumably empty means `start` is not constant - confirm.
2473 if (startValues.size()) {
2474 if (llvm::any_of(startValues, [](
const int64_t v) {
2477 return emitOpError(
"start values must be non-negative, got [")
2478 << startValues <<
"]";
2485 if (llvm::any_of(sizeValues, [](
const int64_t v) {
2488 return emitOpError(
"size values must be > 0, got [") << sizeValues <<
"]";
2489 if (outputShape.hasRank()) {
2491 outputShape.getDims(outputDims);
2492 const bool hasNoInferableDims = llvm::all_of(
2494 if (hasNoInferableDims &&
2496 return emitOpError(
"expected output shape to match size values, got ")
2497 << output.
getType() <<
" vs [" << sizeValues <<
"]";
2500 if (inputShape.hasRank() && startValues.size()) {
2502 inputShape.getDims(inputDims);
2503 for (
const auto &[
index, vals] :
2504 llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
// These bindings shadow the outer `start` and `size` Values inside the loop.
2505 const auto &[start, size, inputDim] = vals;
2507 ShapedType::isDynamic(inputDim))
2509 if (start + size > inputDim)
2510 return emitOpError(
"start + size must be less than or equal to input "
2511 "dimension size, got start=")
2512 << start <<
", size=" << size
2513 <<
" vs input dim size=" << inputDim <<
" at dimension "
// Shape inference entry point for tosa.mul.
// NOTE(review): only the first two signature lines are visible in this
// listing; the body is not shown, so its behaviour is not described here.
2521LogicalResult tosa::MulOp::inferReturnTypeComponents(
2522 MLIRContext *context, ::std::optional<Location> location,
// Verifier for tosa.mul.
// Visible checks, in order:
//   - integer result: both operands must have the same integer element type,
//     and the operand width must not exceed the result width;
//   - otherwise: operand and result element types must match, and the shift
//     value must be 0 ("require shift to be 0 for float type");
//   - when both operands are ranked, their ranks must match and their shapes
//     must be broadcast-compatible;
//   - a ranked result must have the same rank as each ranked operand and be
//     consistent with the expected broadcast shape.
// NOTE(review): several condition heads and declarations (`resElemType`,
// `expectedOutputShape`, the rank comparison) are not visible here.
2537LogicalResult tosa::MulOp::verify() {
2538 const Value output = getOutput();
2543 if (
auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2544 IntegerType lhsIntType =
2546 IntegerType rhsIntType =
2548 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2549 return emitOpError(
"requires the same element type for all operands");
// The result may be wider than the operands, but never narrower.
2554 if (lhsIntType.getWidth() > resIntType.getWidth())
2555 return emitOpError(
"invalid data type size for operands or result");
2560 for (
int i = 0; i < 2; ++i) {
2563 "requires the same element type for all operands and results");
2567 ElementsAttr shiftElem;
2569 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2571 return emitOpError() <<
"require shift to be 0 for float type";
// Operands 0 and 1 are the two multiplicands `a` and `b`.
2579 TypeRange operandTypes = getOperandTypes();
2580 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2581 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2583 const bool aHasRank = aType.hasRank();
2584 const bool bHasRank = bType.hasRank();
// Set once a broadcast shape has been computed from two ranked operands.
2586 bool hasExpectedOutputShape =
false;
2589 if (aHasRank && bHasRank) {
2590 const int64_t aRank = aType.getRank();
2591 const int64_t bRank = bType.getRank();
2593 return emitOpError(
"a and b operands don't have matching ranks, got ")
2594 << aRank <<
" and " << bRank;
2598 aType.getShape(), bType.getShape(), expectedOutputShape))
2599 return emitOpError(
"a and b operands don't have broadcast-compatible "
2601 << aType <<
" and " << bType;
2602 hasExpectedOutputShape =
true;
2605 ShapedType resultType = cast<ShapedType>(output.
getType());
2606 if (!resultType.hasRank())
2609 const int64_t resultRank = resultType.getRank();
2610 if (aHasRank && resultRank != aType.getRank())
2611 return emitOpError(
"result type has different rank than a, got ")
2612 << resultRank <<
" vs " << aType.getRank();
2613 if (bHasRank && resultRank != bType.getRank())
2614 return emitOpError(
"result type has different rank than b, got ")
2615 << resultRank <<
" vs " << bType.getRank();
2617 if (hasExpectedOutputShape &&
2619 expectedOutputShape)))
// Shape inference for tosa.table.
// For a ranked input, a single result entry is created and filled with the
// input's dimensions, i.e. the result shape mirrors the input shape.
// NOTE(review): the unranked-input branch body is not visible in this
// listing.
2625LogicalResult tosa::TableOp::inferReturnTypeComponents(
2626 MLIRContext *context, ::std::optional<Location> location,
2627 TableOp::Adaptor adaptor,
2629 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2631 if (!inputShape.hasRank()) {
2636 inferredReturnShapes.resize(1);
2637 inputShape.getDims(inferredReturnShapes[0]);
// Verifier for tosa.table.
// Walks input and output dimensions pairwise and rejects a static output
// dimension that differs from the corresponding input dimension.
// NOTE(review): checks between the type declarations and the loop (e.g. rank
// checks), and the declaration of `dim`, are not visible in this listing.
2641LogicalResult tosa::TableOp::verify() {
2642 const TensorType inputType = getInput1().getType();
2643 const TensorType outputType = getOutput().getType();
2652 auto inputDims = inputType.
getShape();
2653 auto outputDims = outputType.
getShape();
2654 for (
auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2656 auto [inputDim, outputDim] = it.value();
// A dynamic output dimension is always accepted.
2657 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2658 return emitOpError() <<
"dim(result, " << dim <<
") = " << outputDim
2659 <<
" doesn't match dim(input, " << dim
2660 <<
") = " << inputDim;
2673 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2674 [](
const APInt &val) { return val.getSExtValue(); });
// Shape inference for tosa.tile.
// Visible behaviour:
//   - unranked input: the output has multiples.size() dynamic entries;
//   - the input rank must equal multiples.size() (the mismatch branch body is
//     not visible);
//   - per dimension: a dynamic multiple yields a dynamic output dimension;
//     otherwise a static input dimension is multiplied by the multiple and a
//     dynamic input dimension is passed through unchanged.
// NOTE(review): how `multiples` is populated is not visible in this listing.
2678LogicalResult tosa::TileOp::inferReturnTypeComponents(
2679 MLIRContext *context, ::std::optional<Location> location,
2680 TileOp::Adaptor adaptor,
2687 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2694 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2696 if (!inputShape.hasRank()) {
2697 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2698 inferredReturnShapes.push_back(
2702 if (
static_cast<size_t>(inputShape.getRank()) != multiples.size())
2706 outputShape.reserve(multiples.size());
2707 for (
int i = 0, s = inputShape.getRank(); i < s; i++) {
2708 if (multiples[i] == ShapedType::kDynamic) {
2709 outputShape.push_back(ShapedType::kDynamic);
2711 int64_t dim = inputShape.getDimSize(i);
// Scale only static dimensions; a dynamic one stays dynamic.
2712 if (dim != ShapedType::kDynamic)
2713 dim *= multiples[i];
2714 outputShape.push_back(dim);
// Verifier for tosa.tile.
// Visible checks, in order:
//   - ranked input: its rank must equal the rank of `multiples`; a further
//     check on a ranked output follows (its condition is not visible);
//   - unranked input with ranked output: the output rank must equal the rank
//     of `multiples`;
//   - constant multiples must each be positive or exactly -1.
2722LogicalResult tosa::TileOp::verify() {
2728 ShapedType inputType = llvm::cast<ShapedType>(getInput1().
getType());
2729 ShapedType outputType = llvm::cast<ShapedType>(
getType());
2731 shapeType multiplesType =
2732 llvm::cast<tosa::shapeType>(getMultiples().
getType());
2734 auto multiplesRank = multiplesType.getRank();
2736 if (inputType.hasRank()) {
2737 if (inputType.getRank() != multiplesRank)
2738 return emitOpError(
"expect 'multiples' to have rank ")
2739 << inputType.getRank() <<
" but got " << multiplesRank <<
".";
2740 if (outputType.hasRank() &&
2744 }
else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2745 return emitOpError(
"expect 'multiples' array to have length ")
2746 << outputType.getRank() <<
" but got " << multiplesRank <<
".";
// Values are only checked when the multiples can be read as constants;
// zero and negatives other than -1 are rejected.
2749 if (getConstantMultiples(multiples).succeeded() &&
2750 llvm::any_of(multiples, [](
int64_t v) {
return v <= 0 && v != -1; }))
2752 "expect element of 'multiples' to be positive integer or -1.");
2758 if (l.size() != r.size() || l.size() != 1)
// Shape inference for tosa.reshape.
// Visible behaviour:
//   - an input that is unranked or not fully static takes an early path that
//     pushes a result (its argument is not visible);
//   - otherwise the input's element count is computed, the static entries of
//     the requested shape are visited, and every dynamic entry is replaced by
//     numElements / staticMul.
// NOTE(review): how `newShapeValue` is obtained and how `staticMul` is
// accumulated are not visible in this listing.
2763LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2764 MLIRContext *context, ::std::optional<Location> location,
2765 ReshapeOp::Adaptor adaptor,
2767 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2772 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2781 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2782 inferredReturnShapes.push_back(
2790 int64_t numElements = inputShape.getNumElements();
2792 for (
auto val : newShapeValue) {
2793 if (ShapedType::isStatic(val)) {
// Taken by reference so dynamic entries can be overwritten in place.
2799 for (
auto &val : newShapeValue) {
2800 if (ShapedType::isDynamic(val))
2801 val = numElements / staticMul;
2804 inferredReturnShapes.push_back(
// Verifier for tosa.reshape.
// Visible checks, in order:
//   - at most one target dimension may be missing (`missingDims > 1` is an
//     error);
//   - for a ranked result: the new shape must have as many entries as the
//     result rank, and static entries must agree with static result
//     dimensions;
//   - for a static input: the element count must match a static output's
//     count; otherwise it is compared with the product of the positive
//     entries of the new shape (exact match when every entry is positive,
//     upper bound when some entry is not).
// NOTE(review): the early-success condition and the code computing
// `shapeValues`, `missingDims` and `newShapeElementsNum` are only partially
// visible in this listing.
2809llvm::LogicalResult tosa::ReshapeOp::verify() {
2815 TensorType inputType = getInput1().getType();
2820 return mlir::success();
2824 if (missingDims > 1)
2825 return emitOpError() <<
"expected at most one target dimension to be "
2828 const auto outputType = dyn_cast<RankedTensorType>(
getType());
2832 if ((
int64_t)shapeValues.size() != outputType.getRank())
2833 return emitOpError() <<
"new shape does not match result rank";
2835 for (
auto [newShapeDim, outputShapeDim] :
2836 zip(shapeValues, outputType.getShape())) {
2838 newShapeDim != ShapedType::kDynamic &&
2839 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2840 return emitOpError() <<
"new shape is inconsistent with result shape";
2843 return emitOpError() <<
"new shape has invalid tensor dimension size "
2847 if (inputType.hasStaticShape()) {
2848 int64_t inputElementsNum = inputType.getNumElements();
2849 if (outputType.hasStaticShape()) {
2850 int64_t outputElementsNum = outputType.getNumElements();
2851 if (inputElementsNum != outputElementsNum) {
2852 return emitOpError() <<
"cannot reshape " << inputElementsNum
2853 <<
" elements into " << outputElementsNum;
// Non-positive entries leave the accumulator unchanged.
2859 return (dim > 0) ?
acc * dim :
acc;
2861 bool isStaticNewShape =
2862 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
// All-positive shape: exact match required. Otherwise the partial product
// must not exceed the input element count.
2863 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2864 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2865 return emitOpError() <<
"cannot reshape " << inputElementsNum
2866 <<
" elements into " << newShapeElementsNum;
2870 return mlir::success();
// Return-type compatibility check for tosa.reshape_block_scaled.
// The visible guard requires both type ranges to have the same length, and
// that length to be 1 or 2.
// NOTE(review): the outcome of the guard and the per-type comparison are not
// visible in this listing.
2873bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(
TypeRange l,
2875 if (l.size() != r.size() || l.size() < 1 || l.size() > 2)
// Shape inference for tosa.reshape_block_scaled.
// The op takes one or two inputs. With two inputs a second ("scale") shape is
// derived from the new value shape: it is a copy of that shape whose last
// entry, when static, is divided by the block size. A later pass fills any
// still-dynamic scale entry from the value shape, again dividing the last
// entry by the block size.
// NOTE(review): how `newShapeValue` is resolved between the two visible
// passes, and the arguments of the push_back call, are not visible here.
2883LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
2884 MLIRContext *context, ::std::optional<Location> location,
2885 ReshapeBlockScaledOp::Adaptor adaptor,
2888 const auto numInputs = adaptor.getInput().size();
2889 ShapeAdaptor inputShape(adaptor.getInput()[0].getType());
2892 const auto newShape = adaptor.getNewValueShape();
2894 auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
2903 const uint32_t blockSize =
2904 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
2907 if (numInputs == 2) {
2908 newScaleShapeValue.assign(newShapeValue.begin(), newShapeValue.end());
// Only the last dimension is scaled down by the block size.
2909 if (ShapedType::isStatic(newScaleShapeValue.back()))
2910 newScaleShapeValue.back() /= blockSize;
2913 inferredReturnShapes.push_back(
2915 if (numInputs == 2) {
2917 for (
size_t idx = 0; idx < newShapeValue.size(); idx++) {
2918 if (ShapedType::isDynamic(newScaleShapeValue[idx])) {
2919 newScaleShapeValue[idx] = newShapeValue[idx];
2920 if (idx == (newShapeValue.size() - 1))
2921 newScaleShapeValue[idx] /= blockSize;
// Verifier for tosa.reshape_block_scaled.
// Visible checks, in order:
//   - between one and two inputs, with as many results as inputs;
//   - two inputs (data + scale): the block size must be the value of
//     BLOCK_SIZE_32, no input or result may have rank 0, data and scale must
//     have the same rank and agree on all but the last dimension, the last
//     data dimension must be divisible by the block size, and the last scale
//     dimension must equal last data dimension / block size;
//   - otherwise the block size must be the value of BLOCK_SIZE_1;
//   - two inputs: the new shape must be non-empty and its static last entry
//     divisible by the block size;
//   - a ranked result must match the new shape in rank and static dimensions;
//     the second result is checked the same way against the scale shape;
//   - element-count consistency between a static input and the output / new
//     shape, mirroring the visible logic of ReshapeOp::verify.
// NOTE(review): declarations of `inputList`, `outputList`, `shapeValues`,
// `scaleShapeValues` and several condition heads are not visible here.
2932llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
2936 if (inputList.size() == 0)
2937 return emitOpError(
"requires at least one input");
2939 if (inputList.size() > 2)
2940 return emitOpError(
"requires at most two inputs");
2942 if (inputList.size() != outputList.size())
2943 return emitOpError(
"requires number of results to match inputs");
2951 const auto inputType = llvm::cast<ShapedType>(inputList[0].
getType());
2952 if (!inputType.hasRank())
2954 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
// Two inputs: the second input is handled below as the scale tensor.
2956 if (inputList.size() == 2) {
2957 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2958 return emitOpError(
"expect block size to be 32, got ") << blockSize;
2959 if (llvm::any_of(inputList, [](
Value v) {
2960 const auto input = cast<ShapedType>(v.
getType());
2961 return input.hasRank() && input.getRank() == 0;
2964 "requires all input shapes have a rank greater than 0");
2965 if (llvm::any_of(outputList, [](
Value v) {
2966 const auto output = cast<ShapedType>(v.
getType());
2967 return output.hasRank() && output.getRank() == 0;
2970 "requires all result shapes have a rank greater than 0");
2978 const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].
getType());
2979 if (inputScaleType.hasRank()) {
2980 if (inputType.getRank() != inputScaleType.getRank())
2981 return emitOpError(
"input shapes do not have same rank");
// All dimensions except the last must agree between data and scale.
2984 for (
auto dimIdx = 0; dimIdx < inputType.getRank() - 1; dimIdx++) {
2985 const int64_t inputValueDim = inputType.getDimSize(dimIdx);
2986 const int64_t inputScaleDim = inputScaleType.getShape()[dimIdx];
2987 if (ShapedType::isStatic(inputValueDim) &&
2988 ShapedType::isStatic(inputScaleDim) &&
2989 inputValueDim != inputScaleDim)
2990 return emitOpError(
"input shapes for data and scale do not match on "
2997 inputType.getDimSize(inputType.getRank() - 1);
2998 if (ShapedType::isStatic(lastValueDim)) {
2999 if (lastValueDim % blockSize != 0)
3000 return emitOpError(
"expect last dimension of input_data (")
3001 << lastValueDim <<
") to be divisible by block_size ("
3002 << blockSize <<
")";
3005 inputScaleType.getDimSize(inputScaleType.getRank() - 1);
3007 if (ShapedType::isStatic(lastScaleDim) &&
3008 lastScaleDim != lastValueDim / blockSize)
3009 return emitOpError(
"expect last dimension of scale_data (")
3010 << lastScaleDim <<
") to be " << lastValueDim <<
"/"
3015 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
3016 return emitOpError(
"expect block size to be 1, got ") << blockSize;
3024 return mlir::success();
3027 if (inputList.size() == 2) {
3028 if (
static_cast<int64_t>(shapeValues.size()) == 0)
3029 return emitOpError(
"requires new shape to have a rank greater than 0");
3031 const int64_t lastShapeDim = shapeValues.back();
3032 if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
3033 return emitOpError(
"expect last dimension of new shape (")
3034 << lastShapeDim <<
") to be divisible by block_size (" << blockSize
3038 const auto outputType = llvm::cast<ShapedType>(outputList[0].
getType());
3039 if (!outputType.hasRank())
3042 if (
static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
3043 return emitOpError() <<
"result does not match new shape rank";
3045 for (
auto [newShapeDim, outputShapeDim] :
3046 zip(shapeValues, outputType.getShape())) {
3047 if (ShapedType::isStatic(newShapeDim) &&
3048 ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
3049 return emitOpError() <<
"result shape is inconsistent with new shape";
// The second result is checked against the scale shape.
3052 if (outputList.size() == 2) {
3056 scaleShapeValues.back() /= blockSize;
3058 const auto outputScaleType =
3059 llvm::cast<ShapedType>(outputList[1].
getType());
3060 if (outputScaleType.hasRank()) {
3061 if ((
int64_t)scaleShapeValues.size() != outputScaleType.getRank())
3062 return emitOpError() <<
"result scale does not match new shape rank";
3064 for (
auto [newScaleShapeDim, outputScaleShapeDim] :
3065 zip(scaleShapeValues, outputScaleType.getShape())) {
3066 if (ShapedType::isStatic(newScaleShapeDim) &&
3067 ShapedType::isStatic(outputScaleShapeDim) &&
3068 newScaleShapeDim != outputScaleShapeDim)
3070 <<
"result scale shape is inconsistent with new shape";
3075 if (inputType.hasStaticShape()) {
3076 int64_t inputElementsNum = inputType.getNumElements();
3077 if (outputType.hasStaticShape()) {
3078 int64_t outputElementsNum = outputType.getNumElements();
3079 if (inputElementsNum != outputElementsNum) {
3080 return emitOpError() <<
"cannot reshape " << inputElementsNum
3081 <<
" elements into " << outputElementsNum;
// Non-positive entries leave the accumulator unchanged.
3087 return (dim > 0) ?
acc * dim :
acc;
3089 bool isStaticNewShape =
3090 llvm::all_of(shapeValues, [](
int64_t s) {
return s > 0; });
// All-positive shape: exact match required. Otherwise the partial product
// must not exceed the input element count.
3091 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
3092 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
3093 return emitOpError() <<
"cannot reshape " << inputElementsNum
3094 <<
" elements into " << newShapeElementsNum;
3098 return mlir::success();
3105 ElementsAttr zpAttr;
3110 Type zpElemType = zpAttr.getElementType();
3112 if (llvm::isa<FloatType>(zpElemType)) {
3113 if (zpAttr.getValues<APFloat>()[0].isZero()) {
3120 if (llvm::isa<IntegerType>(zpElemType)) {
3122 return zpAttr.getValues<APInt>()[0].getSExtValue();
3123 return zpAttr.getValues<APInt>()[0].getZExtValue();
3130template <
typename T>
3132 const std::string &operand) {
3135 if (!zpElemType.
isInteger(8) && zp != 0) {
3137 std::string lower = operand;
3138 llvm::transform(lower, lower.begin(), ::tolower);
3139 return op.emitOpError()
3140 << lower <<
" zero point must be zero for non-int8 integer types";
3148 const std::string &operand) {
3149 bool isInputZp = (operand ==
"Input");
3151 bool tensorUnsigned =
3152 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
3153 StringRef tensorName = isInputZp ?
"input" :
"output";
3159 !(zpElemType.
isInteger(16) && tensorUnsigned)) {
3160 return op.emitOpError()
3161 <<
"expect " << tensorName <<
"_zp of 0, got " << zp;
3163 if (zpElemType.
isInteger(16) && tensorUnsigned && zp != 32768) {
3164 return op.emitOpError() <<
"expect " << tensorName
3165 <<
"_zp of 0 or 32768 for unsigned int16 "
3166 << tensorName <<
", got " << zp;
// ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) defines two members of
// tosa::OP:
//   - get<OPERAND_NAME>ZeroPoint(): forwards the <OPERAND_NAME>Zp operand and
//     SIGN_EXTEND to getZeroPoint() and returns a FailureOr<int64_t>;
//   - verify<OPERAND_NAME>ZeroPoint(zp): forwards to verifyZeroPoint() with
//     the op, the zero-point operand, the value, and the stringified operand
//     name.
// The macro is undefined again once its instantiations are done.
// NOTE(review): the instantiation lines themselves are not visible in this
// listing.
3173#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
3174 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
3175 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
3177 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
3178 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
3199#undef ZERO_POINT_HELPER
// Shape inference for tosa.transpose.
// Visible behaviour:
//   - the number of permutation entries is compared with the input rank;
//   - rank 0 is special-cased;
//   - a scan starting at dimension 1 tracks `allTheSame`; one visible path
//     fills every output entry with the size of input dimension 0;
//   - otherwise the output starts all-dynamic, permutation entries that are
//     >= inputRank are rejected, and output[i] takes the size of input
//     dimension perms[i].
// NOTE(review): the visible range check only tests `i >= inputRank`; whether
// negative permutation entries are rejected elsewhere cannot be seen from
// this listing - confirm.
3201LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
3202 MLIRContext *context, ::std::optional<Location> location,
3203 TransposeOp::Adaptor adaptor,
3205 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3214 const auto inputRank = inputShape.
getRank();
3218 if (adaptor.getPerms().size() !=
static_cast<size_t>(inputRank)) {
3224 if (inputRank == 0) {
3230 bool allTheSame =
true;
3231 for (
int i = 1, s = inputRank; i < s; i++) {
// Every output entry gets the size of input dimension 0 on this path.
3241 outputShape.resize(inputRank, inputShape.
getDimSize(0));
3246 outputShape.resize(inputRank, ShapedType::kDynamic);
3249 if (llvm::any_of(adaptor.getPerms(),
3250 [inputRank](
const auto i) { return i >= inputRank; }))
3253 outputShape.reserve(inputRank);
3254 for (
int i = 0, s = inputRank; i < s; i++) {
// Output dimension i is input dimension perms[i].
3255 outputShape[i] = inputShape.
getDimSize(adaptor.getPerms()[i]);
// Verifies tosa.transpose: the size of `perms` against the input and output
// ranks, rank agreement between input and output, validity of the permutation
// indices, element-count agreement, and per-dimension agreement between the
// permuted input shape and the output shape.
// NOTE(review): the definitions of `inputShape`, `outputShape` and
// `constantPerms` are on lines not visible in this excerpt.
3262LogicalResult tosa::TransposeOp::verify() {
// perms must have exactly one entry per input dimension.
3274 if (inputShape.hasRank() &&
3275 constantPerms.size() !=
static_cast<size_t>(inputShape.getRank()))
3276 return emitOpError() <<
"expected perms attribute to have size "
3277 << inputShape.getRank()
3278 <<
" (input rank) but got size "
3279 << constantPerms.size();
// When both ranks are known they must be equal.
3281 if (inputShape.hasRank() && outputShape.hasRank() &&
3282 inputShape.getRank() != outputShape.getRank())
3284 <<
"expected input tensor rank to equal result tensor rank";
// perms must also have exactly one entry per output dimension.
3286 if (outputShape.hasRank() &&
3287 constantPerms.size() !=
static_cast<size_t>(outputShape.getRank()))
3288 return emitOpError() <<
"expected perms attribute to have size "
3289 << outputShape.getRank()
3290 <<
" (output rank) but got size "
3291 << constantPerms.size();
// Each index must be below perms.size(); the values are widened to int64_t
// for a further check whose callee is on a line not visible here.
3293 if (!llvm::all_of(constantPerms,
3294 [&constantPerms](int32_t s) {
3296 static_cast<size_t>(s) < constantPerms.size();
3299 constantPerms, [](int32_t v) ->
int64_t {
return v; })))
3300 return emitOpError() <<
"expected valid permutation indices";
// With fully static shapes the element counts must agree.
3303 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
3304 inputShape.getNumElements() != outputShape.getNumElements())
3305 return emitOpError() <<
"expected input1 and output to have same numbers "
3307 << inputShape.getNumElements() <<
" and "
3308 << outputShape.getNumElements();
// Output dim i is compared with input dim perms[i]; the visible condition
// singles out pairs where either side is dynamic.
3312 if (inputShape.hasRank() && outputShape.hasRank()) {
3313 for (
auto i = 0; i < outputShape.getRank(); i++) {
3314 if (inputShape.isDynamicDim(constantPerms[i]) ||
3315 outputShape.isDynamicDim(i))
3318 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
3320 <<
"expected output tensor dim " << i <<
" to match "
3321 <<
"input dim " << constantPerms[i] <<
" with value of "
3322 << inputShape.getDimSize(constantPerms[i]);
// Builds the reified result shape of tosa.transpose as a list of
// OpFoldResult, one per input dimension, and appends it to
// `reifiedReturnShapes`.
// NOTE(review): the parameter list and the assignment targets inside the loop
// are on lines not visible in this excerpt.
3329LogicalResult TransposeOp::reifyResultShapes(
3332 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
3334 Value input = getInput1();
3335 auto inputType = cast<TensorType>(input.
getType());
3337 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
// Note that `dim` iterates over the permutation *values*, and each value is
// then used as an index into the permutation again.
3338 for (
auto dim : transposePerms) {
3339 int32_t dimInInput = transposePerms[dim];
// Dynamic input dimension: create a tensor.dim op on `input`.
3340 if (inputType.isDynamicDim(dimInInput))
3342 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
// Static input dimension: an index attribute holding the size.
3346 builder.
getIndexAttr(inputType.getDimSize(dimInInput));
3349 reifiedReturnShapes.emplace_back(std::move(returnedDims));
// Infers the rank-3 result shape of tosa.gather. Dimensions 0 and 2 come from
// `values` when it is ranked; dimensions 0 and 1 are filled from `indices`
// when still dynamic.
// NOTE(review): closing braces and the final return are not visible in this
// excerpt.
3353LogicalResult tosa::GatherOp::inferReturnTypeComponents(
3354 MLIRContext *context, ::std::optional<Location> location,
3355 GatherOp::Adaptor adaptor,
3356 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Start with every output dimension dynamic.
3357 llvm::SmallVector<int64_t> outputShape;
3358 outputShape.resize(3, ShapedType::kDynamic);
3360 ShapeAdaptor valuesShape(adaptor.getValues().getType());
3361 if (valuesShape.hasRank()) {
3362 outputShape[0] = valuesShape.getDimSize(0);
3363 outputShape[2] = valuesShape.getDimSize(2);
// `indices` only fills dimensions that are still dynamic.
3366 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3367 if (indicesShape.hasRank()) {
3368 if (outputShape[0] == ShapedType::kDynamic)
3369 outputShape[0] = indicesShape.getDimSize(0);
3370 if (outputShape[1] == ShapedType::kDynamic)
3371 outputShape[1] = indicesShape.getDimSize(1);
3374 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Infers the rank-3 result shape of tosa RowGatherOp. Dimensions 0 and 2 come
// from `values`; dimension 0 falls back to `indices`; dimension 1 is
// indices dim 1 multiplied by the row count when both are known.
// NOTE(review): the expression that produces `maybeRowCount`, the closing
// braces and the final return are not visible in this excerpt.
3378LogicalResult tosa::RowGatherOp::inferReturnTypeComponents(
3379 MLIRContext *context, ::std::optional<Location> location,
3380 RowGatherOp::Adaptor adaptor,
3381 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Start with every output dimension dynamic.
3382 llvm::SmallVector<int64_t> outputShape;
3383 outputShape.resize(3, ShapedType::kDynamic);
3385 const ShapeAdaptor valuesShape(adaptor.getValues().getType());
3386 if (valuesShape.hasRank()) {
3387 outputShape[0] = valuesShape.getDimSize(0);
3388 outputShape[2] = valuesShape.getDimSize(2);
3391 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3392 if (indicesShape.hasRank()) {
3393 if (outputShape[0] == ShapedType::kDynamic)
3394 outputShape[0] = indicesShape.getDimSize(0);
// Dimension 1 is only computed when the row count was obtained successfully
// and indices dimension 1 is static.
3396 const FailureOr<int32_t> maybeRowCount =
3398 if (succeeded(maybeRowCount)) {
3399 const int64_t indicesW = indicesShape.getDimSize(1);
3400 if (ShapedType::isStatic(indicesW))
3401 outputShape[1] = indicesW * maybeRowCount.value();
3405 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Infers result shapes for tosa RowGatherBlockScaledOp. One shape is always
// pushed for the data result; a second one, with the channel dimension divided
// by the block size, is pushed for the scale result.
// NOTE(review): the row-count lookup, the statement executed when
// values.size() == 1, closing braces and the final return are not visible in
// this excerpt.
3409LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
3410 MLIRContext *context, ::std::optional<Location> location,
3411 RowGatherBlockScaledOp::Adaptor adaptor,
3412 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3413 const auto values = adaptor.getValues();
// Data result: rank 3, all dynamic until refined from the first `values`
// tensor and from `indices`.
3417 SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
3418 const ShapeAdaptor valuesShape(values.front().getType());
3419 if (valuesShape.hasRank()) {
3420 dataShape[0] = valuesShape.getDimSize(0);
3421 dataShape[2] = valuesShape.getDimSize(2);
3424 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3425 if (indicesShape.hasRank()) {
3426 if (dataShape[0] == ShapedType::kDynamic)
3427 dataShape[0] = indicesShape.getDimSize(0);
3431 succeeded(rowCount) && rowCount.value() > 0) {
3432 const int64_t indicesW = indicesShape.getDimSize(1);
3433 if (ShapedType::isStatic(indicesW))
3434 dataShape[1] = indicesW * rowCount.value();
3438 inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
3439 if (values.size() == 1)
// Scale result: same as the data shape except that a static channel
// dimension is divided by the block size.
3442 SmallVector<int64_t> scaleShape = dataShape;
3443 const uint32_t blockSize =
3444 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
3445 if (ShapedType::isStatic(dataShape[2]))
3446 scaleShape[2] = dataShape[2] / blockSize;
3448 inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
// Verifies shape agreement between `values`, `indices` and `output` of
// tosa.gather. Sizes n, w and c are collected from the operands and compared
// with the output dimensions 0, 1 and 2; dynamic sizes are not compared.
// NOTE(review): the definition of `indicesShape`, parts of several conditions
// and the closing braces are not visible in this excerpt.
3452LogicalResult tosa::GatherOp::verify() {
3459 const ShapeAdaptor valuesShape(getValues().
getType());
3461 const ShapeAdaptor outputShape(getOutput().
getType());
// Expected sizes; kDynamic means "not known yet".
3463 int64_t n = ShapedType::kDynamic;
3464 int64_t w = ShapedType::kDynamic;
3465 int64_t c = ShapedType::kDynamic;
3467 if (valuesShape.hasRank()) {
3468 n = valuesShape.getDimSize(0);
3469 c = valuesShape.getDimSize(2);
// indices provides w and must agree with n on dimension 0 when both are
// static.
3471 if (indicesShape.hasRank()) {
3472 const int64_t indicesN = indicesShape.getDimSize(0);
3473 w = indicesShape.getDimSize(1);
3474 if (n == ShapedType::kDynamic)
3476 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3477 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3478 <<
", got " << indicesN;
// Each static output dimension is checked against the collected size.
3480 if (outputShape.hasRank()) {
3481 const int64_t outputN = outputShape.getDimSize(0);
3482 const int64_t outputW = outputShape.getDimSize(1);
3483 const int64_t outputC = outputShape.getDimSize(2);
3484 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3486 return emitOpError() <<
"requires output dimension 0 to have size " << n
3487 <<
", got " << outputN;
3489 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
3491 return emitOpError() <<
"requires output dimension 1 to have size " << w
3492 <<
", got " << outputW;
3493 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3495 return emitOpError() <<
"requires output dimension 2 to have size " << c
3496 <<
", got " << outputC;
// Verifies tosa RowGatherOp: a known row count must be positive, and the
// `indices` and `output` shapes are checked against sizes taken from `values`.
// NOTE(review): the row-count lookup, the definition of `indicesShape`, and the
// helper calls that take the ("indices"/"output", "batch"/"channels") labels
// are on lines not visible in this excerpt.
3501LogicalResult tosa::RowGatherOp::verify() {
3506 const FailureOr<int32_t> maybeRowCount =
3508 if (succeeded(maybeRowCount) && maybeRowCount.value() <= 0)
3509 return emitOpError() <<
"requires row_count to be > 0, got "
3510 << maybeRowCount.value();
// Expected sizes; kDynamic means "not known yet".
3512 int64_t n = ShapedType::kDynamic;
3513 int64_t c = ShapedType::kDynamic;
3514 int64_t w = ShapedType::kDynamic;
3516 const ShapeAdaptor valuesShape(getValues().
getType());
3517 if (valuesShape.hasRank()) {
3518 n = valuesShape.getDimSize(0);
3519 c = valuesShape.getDimSize(2);
3523 if (indicesShape.hasRank()) {
3525 "indices",
"batch")))
3527 w = indicesShape.getDimSize(1);
3530 const ShapeAdaptor outputShape(getOutput().
getType());
3531 if (outputShape.hasRank()) {
3533 "output",
"batch")) ||
3535 "output",
"channels")))
// With a positive row count and static w, a static output dimension 1 must
// equal w * row_count.
3538 if (succeeded(maybeRowCount) && maybeRowCount.value() > 0 &&
3539 ShapedType::isStatic(w)) {
3540 const int64_t expectedOutputRows = w * maybeRowCount.value();
3541 if (ShapedType::isStatic(outputShape.getDimSize(1)) &&
3542 outputShape.getDimSize(1) != expectedOutputRows)
3544 <<
"requires output dimension to be equal to "
3545 "indices[1]*row_count ("
3546 << expectedOutputRows <<
"), got " << outputShape.getDimSize(1);
// Verifies tosa RowGatherBlockScaledOp: the lengths of the `values` and
// `output` tensor lists, the block size allowed for each list length, a
// positive row count, and shape agreement between the data tensors, the
// optional scale tensors and `indices`.
// NOTE(review): many lines (the `return emitOpError()` heads of several
// diagnostics, helper calls taking the name/dimension labels, the row-count
// lookup, the definition of `indicesShape`, closing braces) are not visible in
// this excerpt; comments describe the visible statements only.
3553LogicalResult tosa::RowGatherBlockScaledOp::verify() {
3554 const OperandRange values = getValues();
3555 const ResultRange output = getOutput();
// The values list must hold one or two tensors, and output must match it.
3556 if (values.empty() || values.size() > 2)
3558 <<
"expects values tensor list length to be 1 or 2, got "
3560 if (output.size() != values.size())
3562 <<
"expects output tensor list length to match values tensor list "
3564 << output.size() <<
" results for " << values.size()
3565 <<
" input tensors";
// A single tensor requires block size 1; two tensors require any other
// block size.
3567 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
3568 if (values.size() == 1 && blockSize != 1)
3570 <<
"requires block_size to be BLOCK_SIZE_1 when values tensor list "
3572 if (values.size() == 2 && blockSize == 1)
3574 <<
"requires block_size to not be BLOCK_SIZE_1 when values tensor "
3578 output[0].
getType(),
"values[0]",
3583 "values[1]",
"output[1]")))
3587 succeeded(rowCount) && rowCount.value() <= 0)
3588 return emitOpError() <<
"requires row_count to be > 0, got "
3589 << rowCount.value();
// Expected sizes; kDynamic means "not known yet".
3591 int64_t n = ShapedType::kDynamic;
3592 int64_t k = ShapedType::kDynamic;
3593 int64_t c = ShapedType::kDynamic;
3594 int64_t w = ShapedType::kDynamic;
3595 int64_t multiplesOfC = ShapedType::kDynamic;
3597 const ShapeAdaptor valuesDataShape(values[0].
getType());
3598 if (valuesDataShape.hasRank()) {
3599 n = valuesDataShape.getDimSize(0);
3600 k = valuesDataShape.getDimSize(1);
3601 c = valuesDataShape.getDimSize(2);
// A static channel count must be a multiple of the block size.
3604 if (ShapedType::isStatic(c) && c % blockSize != 0)
3605 return emitOpError() <<
"expects channels of values[0] (" << c
3606 <<
") to be divisible by block_size (" << blockSize
3610 if (indicesShape.hasRank()) {
3612 "indices",
"batch")))
3614 w = indicesShape.getDimSize(1);
3617 const ShapeAdaptor outputDataShape(output[0].
getType());
3618 if (outputDataShape.hasRank()) {
3620 "output[0]",
"batch")) ||
3622 "output[0]",
"channels")))
3626 succeeded(rowCount) && rowCount.value() > 0 &&
3627 ShapedType::isStatic(w)) {
3628 const int64_t expectedOutputRows = w * rowCount.value();
3629 if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
3630 outputDataShape.getDimSize(1) != expectedOutputRows)
3631 return emitOpError() <<
"requires output[0] dimension 1 to have size "
3632 << expectedOutputRows <<
", got "
3633 << outputDataShape.getDimSize(1);
// Scale tensors are only present when the list holds two tensors.
3637 if (values.size() == 2) {
3638 const ShapeAdaptor valuesScaleShape(values[1].
getType());
3639 if (valuesScaleShape.hasRank()) {
3641 "values[1]",
"batch")) ||
3643 "values[1]",
"rows")))
3645 multiplesOfC = valuesScaleShape.getDimSize(2);
3648 const ShapeAdaptor outputScaleShape(output[1].
getType());
3649 if (outputScaleShape.hasRank()) {
3651 "output[1]",
"batch")))
3655 succeeded(rowCount) && rowCount.value() > 0 &&
3656 ShapedType::isStatic(w)) {
3657 const int64_t expectedOutputRows = w * rowCount.value();
3658 if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
3659 outputScaleShape.getDimSize(1) != expectedOutputRows)
3660 return emitOpError() <<
"requires output[1] dimension 1 to have size "
3661 << expectedOutputRows <<
", got "
3662 << outputScaleShape.getDimSize(1);
// The scale channel count is taken from output[1] when values[1] did not
// provide it; otherwise the two must agree when both are static.
3665 if (ShapedType::isDynamic(multiplesOfC))
3666 multiplesOfC = outputScaleShape.getDimSize(2);
3667 else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
3668 multiplesOfC != outputScaleShape.getDimSize(2))
3670 <<
"expected channels of output[1] to match size "
3671 << multiplesOfC <<
", got " << outputScaleShape.getDimSize(2);
// The scale channel count must equal c / blockSize when both are static.
3674 if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
3675 multiplesOfC != c / blockSize)
3677 <<
"expects channels of scale tensors to equal C/block_size (" << c
3678 <<
"/" << blockSize <<
"), got " << multiplesOfC;
// Infers the rank-4 result shape of tosa.resize. Dimensions 0 and 3 are copied
// from the input; dimensions 1 and 2 are computed from the input height and
// width together with the scale, offset and border values.
// NOTE(review): the code that fills scaleInt/offsetInt/borderInt, the divisors
// of the two output-size expressions, the early returns and the closing braces
// are not visible in this excerpt.
3684LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3685 MLIRContext *context, ::std::optional<Location> location,
3686 ResizeOp::Adaptor adaptor,
3687 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Start with every output dimension dynamic.
3688 llvm::SmallVector<int64_t, 4> outputShape;
3689 outputShape.resize(4, ShapedType::kDynamic);
3691 ShapeAdaptor inputShape(adaptor.getInput().getType());
3692 if (!inputShape.hasRank())
3695 outputShape[0] = inputShape.getDimSize(0);
3696 outputShape[3] = inputShape.getDimSize(3);
3697 int64_t inputHeight = inputShape.getDimSize(1);
3698 int64_t inputWidth = inputShape.getDimSize(2);
// A dynamic input height or width is detected here.
3700 if ((inputHeight == ShapedType::kDynamic) ||
3701 (inputWidth == ShapedType::kDynamic))
3704 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
// Height uses scaleInt[0], offsetInt[0], borderInt[0]; width uses
// scaleInt[2], offsetInt[1], borderInt[1].
3715 const int64_t outputHeight =
3716 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3720 const int64_t outputWidth =
3721 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
// Negative computed sizes are reported as an error.
3725 if (outputHeight < 0 || outputWidth < 0) {
3728 "calculated output height and width must be non-negative, "
3730 outputHeight,
", width = ", outputWidth);
3733 outputShape[1] = outputHeight;
3734 outputShape[2] = outputWidth;
3735 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifies tosa.resize: all scale values must be positive, and for static
// input height/width other than 1 the computed output size must be an exact
// division and must match a static output dimension.
// NOTE(review): the code that fills scaleValues/offsetValues/borderValues,
// null checks on the dyn_cast results, the call that produces
// calculatedOutWidthMinusOne, and the closing braces are not visible in this
// excerpt.
3739LogicalResult tosa::ResizeOp::verify() {
3740 const Value input = getInput();
3741 const Value output = getOutput();
3742 const RankedTensorType inputType =
3743 llvm::dyn_cast<RankedTensorType>(input.
getType());
3744 const RankedTensorType outputType =
3745 llvm::dyn_cast<RankedTensorType>(output.
getType());
3747 SmallVector<int64_t> scaleValues;
3748 SmallVector<int64_t> offsetValues;
3749 SmallVector<int64_t> borderValues;
3757 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
3758 return emitOpError(
"expect all scale values to be > 0, got ")
// scale is laid out as [y numerator, y denominator, x numerator,
// x denominator]; offset and border are [y, x].
3761 const int64_t scaleYN = scaleValues[0];
3762 const int64_t scaleYD = scaleValues[1];
3763 const int64_t scaleXN = scaleValues[2];
3764 const int64_t scaleXD = scaleValues[3];
3766 const int64_t offsetY = offsetValues[0];
3767 const int64_t offsetX = offsetValues[1];
3769 const int64_t borderY = borderValues[0];
3770 const int64_t borderX = borderValues[1];
// Height is dimension 1 and width is dimension 2 of both tensors.
3777 const int64_t oh = outputType.getDimSize(1);
3778 const int64_t ow = outputType.getDimSize(2);
3779 const int64_t ih = inputType.getDimSize(1);
3780 const int64_t iw = inputType.getDimSize(2);
// Height check; skipped when the input height is dynamic or equal to 1.
3786 if (ih != ShapedType::kDynamic && ih != 1) {
3787 const std::optional<int64_t> calculatedOutHeightMinusOne =
3788 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3789 if (!calculatedOutHeightMinusOne.has_value())
3790 return emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
3792 <<
"to be wholly divisible by scale_y_d, got ((" << ih
3793 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
3794 <<
") / " << scaleYD;
3795 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3796 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3797 return emitOpError(
"calculated output height did not match expected: ")
3798 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
// Width check; skipped when the input width is dynamic or equal to 1.
3805 if (iw != ShapedType::kDynamic && iw != 1) {
3806 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3807 const std::optional<int64_t> calculatedOutWidthMinusOne =
3809 if (!calculatedOutWidthMinusOne.has_value())
3810 return emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
3812 <<
"to be wholly divisible by scale_x_d, got ((" << iw
3813 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
3814 <<
") / " << scaleXD;
3815 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3816 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3817 return emitOpError(
"calculated output width did not match expected: ")
3818 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
// Infers the rank-3 result shape of tosa.scatter. All three dimensions come
// from `values_in` when it is ranked; dimension 0 may be filled from `indices`
// and dimensions 0 and 2 from `input` when still dynamic.
// NOTE(review): closing braces and the final return are not visible in this
// excerpt.
3824LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3825 MLIRContext *context, ::std::optional<Location> location,
3826 ScatterOp::Adaptor adaptor,
3827 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Start with every output dimension dynamic.
3828 llvm::SmallVector<int64_t> outputShape;
3829 outputShape.resize(3, ShapedType::kDynamic);
3831 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3832 if (valuesInShape.hasRank()) {
3833 outputShape[0] = valuesInShape.getDimSize(0);
3834 outputShape[1] = valuesInShape.getDimSize(1);
3835 outputShape[2] = valuesInShape.getDimSize(2);
// The remaining operands only fill dimensions that are still dynamic.
3838 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3839 if (indicesShape.hasRank()) {
3840 if (outputShape[0] == ShapedType::kDynamic)
3841 outputShape[0] = indicesShape.getDimSize(0);
3844 ShapeAdaptor inputShape(adaptor.getInput().getType());
3845 if (inputShape.hasRank()) {
3846 if (outputShape[0] == ShapedType::kDynamic)
3847 outputShape[0] = inputShape.getDimSize(0);
3848 if (outputShape[2] == ShapedType::kDynamic)
3849 outputShape[2] = inputShape.getDimSize(2);
3852 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
// Verifies shape agreement between `values_in`, `indices`, `input` and
// `values_out` of tosa.scatter. Sizes n, k, w and c are collected from the
// operands in turn and each later operand is checked against them; dynamic
// sizes are not compared. Finally K >= W is required when both are static.
// NOTE(review): the definition of `indicesShape`, the assignments taken when a
// size is still dynamic, parts of several conditions and the closing braces
// are not visible in this excerpt.
3856LogicalResult tosa::ScatterOp::verify() {
3866 const ShapeAdaptor valuesInShape(getValuesIn().
getType());
3868 const ShapeAdaptor inputShape(getInput().
getType());
3869 const ShapeAdaptor outputShape(getValuesOut().
getType());
// Expected sizes; kDynamic means "not known yet".
3871 int64_t n = ShapedType::kDynamic;
3872 int64_t k = ShapedType::kDynamic;
3873 int64_t w = ShapedType::kDynamic;
3874 int64_t c = ShapedType::kDynamic;
3875 if (valuesInShape.hasRank()) {
3876 n = valuesInShape.getDimSize(0);
3877 k = valuesInShape.getDimSize(1);
3878 c = valuesInShape.getDimSize(2);
// indices provides w and is checked against n on dimension 0.
3880 if (indicesShape.hasRank()) {
3881 const int64_t indicesN = indicesShape.getDimSize(0);
3882 w = indicesShape.getDimSize(1);
3883 if (n == ShapedType::kDynamic)
3885 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3886 return emitOpError() <<
"requires indices dimension 0 to have size " << n
3887 <<
", got " << indicesN;
// input is checked against n, w and c on dimensions 0, 1 and 2.
3889 if (inputShape.hasRank()) {
3890 const int64_t inputN = inputShape.getDimSize(0);
3891 const int64_t inputW = inputShape.getDimSize(1);
3892 const int64_t inputC = inputShape.getDimSize(2);
3893 if (n == ShapedType::kDynamic)
3895 else if (inputN != ShapedType::kDynamic && n != inputN)
3896 return emitOpError() <<
"requires input dimension 0 to have size " << n
3897 <<
", got " << inputN;
3898 if (w == ShapedType::kDynamic)
3900 else if (inputW != ShapedType::kDynamic && w != inputW)
3901 return emitOpError() <<
"requires input dimension 1 to have size " << w
3902 <<
", got " << inputW;
3904 if (c == ShapedType::kDynamic)
3906 else if (inputC != ShapedType::kDynamic && c != inputC)
3907 return emitOpError() <<
"requires input dimension 2 to have size " << c
3908 <<
", got " << inputC;
// values_out is checked against n, k and c on dimensions 0, 1 and 2.
3910 if (outputShape.hasRank()) {
3911 const int64_t outputN = outputShape.getDimSize(0);
3912 const int64_t outputK = outputShape.getDimSize(1);
3913 const int64_t outputC = outputShape.getDimSize(2);
3914 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3916 return emitOpError() <<
"requires values_out dimension 0 to have size "
3917 << n <<
", got " << outputN;
3918 if (k == ShapedType::kDynamic)
3920 else if (outputK != ShapedType::kDynamic && k != outputK)
3921 return emitOpError() <<
"requires values_out dimension 1 to have size "
3922 << k <<
", got " << outputK;
3923 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3925 return emitOpError() <<
"requires values_out dimension 2 to have size "
3926 << c <<
", got " << outputC;
// K must be at least W when both are static.
3928 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3929 return emitOpError() <<
"requires dimensions K >= W, got K=" << k
3938 int64_t axisVal = axis.getValue().getSExtValue();
3939 if (!operandShape.
hasRank() || operandShape.
getRank() <= axisVal) {
3945 operandShape.
getDims(outputShape);
3946 outputShape[axisVal] = 1;
// Defines OP::isCompatibleReturnTypes(l, r). The visible conditions test that
// both ranges have the same size and exactly one element, and that the element
// types of l[0] and r[0] are equal; the final result is whether
// verifyCompatibleShape(l[0], r[0]) succeeds.
// NOTE(review): the statements executed when the first two conditions hold
// (presumably `return false;`) are on continuation lines not visible in this
// excerpt.
3951#define COMPATIBLE_RETURN_TYPES(OP) \
3952 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3953 if (l.size() != r.size() || l.size() != 1) \
3955 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3957 return succeeded(verifyCompatibleShape(l[0], r[0])); \
// Defines OP::inferReturnTypeComponents for a reduce op by forwarding the
// input shape, the input element type and the `axis` property to
// ReduceInferReturnTypes, then expands COMPATIBLE_RETURN_TYPES(OP) for the
// same op.
// NOTE(review): the declaration that names the element type (`inputType`) and
// the closing brace are on continuation lines not visible in this excerpt.
3960#define REDUCE_SHAPE_INFER(OP) \
3961 LogicalResult OP::inferReturnTypeComponents( \
3962 MLIRContext *context, ::std::optional<Location> location, \
3963 OP::Adaptor adaptor, \
3964 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3966 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3967 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3968 const Properties &prop = adaptor.getProperties(); \
3969 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3970 inferredReturnShapes); \
3972 COMPATIBLE_RETURN_TYPES(OP)
3980#undef REDUCE_SHAPE_INFER
3982#undef COMPATIBLE_RETURN_TYPES
// Shared verifier for reduce-style ops, templated on the op type T. It reads
// the op's input type, output type and axis, and emits an op error when the
// axis is negative, when the axis is not below the input or output rank, when
// the ranks differ, or when the reduced output dimension is static and not 1.
// NOTE(review): the signature line is not visible in this excerpt; presumably
// this is the `verifyReduceOp` helper called by the Reduce*Op::verify()
// methods - confirm. The return statements after each diagnostic are also not
// visible.
3984template <
typename T>
3987 TensorType inputType = op.getInput().getType();
3988 TensorType outputType = op.getOutput().getType();
3989 int32_t reduceAxis = op.getAxis();
3991 if (reduceAxis < 0) {
3992 op.emitOpError(
"reduce axis must not be negative");
3996 int64_t inputRank = inputType.getRank();
// The axis must be below the input rank; axis 0 on a rank-0 input is the one
// combination exempted by the second clause.
3999 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
4000 op.emitOpError(
"expect input tensor rank (")
4001 << inputRank <<
") to be larger than reduce axis (" << reduceAxis
4007 int64_t outputRank = outputType.getRank();
4008 if (inputType.
hasRank() && outputRank != inputType.getRank()) {
4010 "expect output tensor rank to be equal to input tensor rank");
// Same axis-versus-rank rule, applied to the output.
4013 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
4014 op.emitOpError(
"expect output tensor rank (")
4015 << outputRank <<
") to be larger than reduce axis (" << reduceAxis
// For a non-scalar output, a static reduced dimension must have size 1.
4021 if (outputRank != 0) {
4022 auto outputShape = outputType.
getShape();
4023 if (!outputType.isDynamicDim(reduceAxis) &&
4024 outputShape[reduceAxis] != 1) {
4025 op.emitOpError(
"expect reduced dimension size to be 1, got ")
4026 << outputShape[reduceAxis];
// Op verifiers for the six TOSA reduce ops. Each one forwards the op itself to
// the shared verifyReduceOp helper and returns its result.
4034LogicalResult tosa::ReduceAllOp::verify() {
return verifyReduceOp(*
this); }
4035LogicalResult tosa::ReduceAnyOp::verify() {
return verifyReduceOp(*
this); }
4036LogicalResult tosa::ReduceMaxOp::verify() {
return verifyReduceOp(*
this); }
4037LogicalResult tosa::ReduceMinOp::verify() {
return verifyReduceOp(*
this); }
4038LogicalResult tosa::ReduceProductOp::verify() {
return verifyReduceOp(*
this); }
4039LogicalResult tosa::ReduceSumOp::verify() {
return verifyReduceOp(*
this); }
// Defines OP::inferReturnTypeComponents for an n-ary op by forwarding the
// operands and the output list to NAryInferReturnTypes. The context, location,
// attributes, properties and regions parameters are not used by the visible
// body.
// NOTE(review): the closing-brace continuation line is not visible in this
// excerpt.
4053#define NARY_SHAPE_INFER(OP) \
4054 LogicalResult OP::inferReturnTypeComponents( \
4055 MLIRContext *context, ::std::optional<Location> location, \
4056 ValueShapeRange operands, DictionaryAttr attributes, \
4057 PropertyRef properties, RegionRange regions, \
4058 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
4059 return NAryInferReturnTypes(operands, inferredReturnShapes); \
4099#undef PRED_SHAPE_INFER
4101LogicalResult tosa::NegateOp::inferReturnTypeComponents(
4102 MLIRContext *context, ::std::optional<Location> location,
4103 NegateOp::Adaptor adaptor,
4105 ShapeAdaptor inputShape(adaptor.getInput1().getType());
// Verifies tosa.negate: input1 and output must have the same shape, each
// tensor's element type must equal the element type of its zero point, and any
// zero point that can be obtained is passed to the matching zero-point
// verifier.
// NOTE(review): the definitions of input1EType/outputEType, the right-hand
// sides of the zero-point element types, the shape-compatibility call and the
// return statements after the zero-point checks are not visible in this
// excerpt.
4110LogicalResult tosa::NegateOp::verify() {
4112 const Type input1Type = getInput1().getType();
4113 const Type outputType = getOutput().getType();
4118 const SmallVector<Type, 2> types = {input1Type, outputType};
4120 return emitOpError() <<
"requires the same shape for input1 and output";
// Element type of input1 versus element type of its zero point.
4123 const Type input1ZpEType =
4125 if (input1EType != input1ZpEType) {
4126 return emitOpError(
"expect both input1 and its zero point are the same "
4127 "element type, got ")
4128 << input1EType <<
" and " << input1ZpEType;
// Element type of output versus element type of its zero point.
4131 const Type outputZpEType =
4133 if (outputEType != outputZpEType) {
4134 return emitOpError(
"expect both output and its zero point are the same "
4135 "element type, got ")
4136 << outputEType <<
" and " << outputZpEType;
// A zero point that could not be obtained (failure) is not verified.
4139 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
4140 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
4143 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4144 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4155 outputShape.resize(4, ShapedType::kDynamic);
4170 if (ShapedType::isStatic(height)) {
4171 int64_t padded = height + pad[0] + pad[1] - kernel[0];
4172 outputShape[1] = padded / stride[0] + 1;
4175 if (ShapedType::isStatic(width)) {
4176 int64_t padded = width + pad[2] + pad[3] - kernel[1];
4177 outputShape[2] = padded / stride[1] + 1;
4184template <
typename AdaptorT>
4190 if (ShapedType::isDynamic(current))
4191 current = candidate;
4200 : adaptor(adaptor) {}
4204 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4212 outputShape[0] = outputBatch;
4213 inputSpatial[0] = inputHeight;
4214 inputSpatial[1] = inputWidth;
4219 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4227 outputShape[3] = outputChannels;
4228 weightSpatial[0] = kernelHeight;
4229 weightSpatial[1] = kernelWidth;
4238 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4239 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4240 dilationValues.assign(adaptor.getDilation().begin(),
4241 adaptor.getDilation().end());
4246 Conv2DOp::Adaptor adaptor;
4254 : adaptor(adaptor) {}
4258 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
4259 if (inputDataShape.
hasRank()) {
4264 outputShape[0] = outputBatch;
4265 inputSpatial[0] = inputHeight;
4266 inputSpatial[1] = inputWidth;
4269 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
4270 if (!inputScaleShape.
hasRank())
4284 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
4285 if (weightDataShape.
hasRank()) {
4290 outputShape[3] = outputChannels;
4291 weightSpatial[0] = kernelHeight;
4292 weightSpatial[1] = kernelWidth;
4295 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
4296 if (!weightScaleShape.
hasRank())
4325 Conv2DBlockScaledOp::Adaptor adaptor;
4333 : adaptor(adaptor) {}
4337 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4346 outputShape[0] = outputBatch;
4347 inputSpatial[0] = inputDepth;
4348 inputSpatial[1] = inputHeight;
4349 inputSpatial[2] = inputWidth;
4354 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4363 outputShape[4] = outputChannels;
4364 weightSpatial[0] = kernelDepth;
4365 weightSpatial[1] = kernelHeight;
4366 weightSpatial[2] = kernelWidth;
4375 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4376 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4377 dilationValues.assign(adaptor.getDilation().begin(),
4378 adaptor.getDilation().end());
4383 Conv3DOp::Adaptor adaptor;
4386template <
typename AdaptorT>
4392 ShapedType::kDynamic);
4394 ShapedType::kDynamic);
4396 ShapedType::kDynamic);
4398 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
4399 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
4401 const ShapeAdaptor biasShape = adaptor.getBias().getType();
4404 if (biasSize != 1) {
4405 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
4406 outputShape[outputChannelDim] =
4407 ShapedType::isDynamic(outputShape[outputChannelDim])
4409 : outputShape[outputChannelDim];
4416 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
4422 for (
int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
4423 if (!ShapedType::isStatic(inputSpatial[dim]) ||
4424 !ShapedType::isStatic(weightSpatial[dim]))
4427 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
4429 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
4430 const int64_t unstridedResult = inputSize - filterSize + 1;
4431 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
4438LogicalResult Conv2DOp::inferReturnTypeComponents(
4439 MLIRContext *context, ::std::optional<Location> location,
4440 Conv2DOp::Adaptor adaptor,
4441 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4445LogicalResult Conv2DOp::verify() {
4452LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
4453 MLIRContext *context, ::std::optional<Location> location,
4454 Conv2DBlockScaledOp::Adaptor adaptor,
4455 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Verifies tosa Conv2DBlockScaledOp: agreement between data and scale tensor
// shapes, block-size constraints on the input channel count, value ranges of
// pad/stride/dilation, output spatial sizes, and the bias channel count.
// NOTE(review): many lines (the helper calls that take the operand and
// dimension labels, the code that fills padValues/strideValues/dilationValues,
// the heads of several diagnostics, closing braces) are not visible in this
// excerpt; comments describe the visible statements only.
4459LogicalResult Conv2DBlockScaledOp::verify() {
4461 getWeightData().
getType(),
"input_data",
4464 getWeightScale().
getType(),
"input_scale",
4467 getOutput().
getType(),
"bias",
"output")))
// Expected sizes; kDynamic means "not known yet".
4471 int64_t N = ShapedType::kDynamic;
4472 int64_t IH = ShapedType::kDynamic;
4473 int64_t IW = ShapedType::kDynamic;
4474 int64_t IC = ShapedType::kDynamic;
4475 int64_t multiplesOfIC = ShapedType::kDynamic;
4476 int64_t OC = ShapedType::kDynamic;
4477 int64_t KH = ShapedType::kDynamic;
4478 int64_t KW = ShapedType::kDynamic;
// input_data dimensions are read as [N, IH, IW, IC].
4480 const ShapeAdaptor inputDataShape(getInputData().
getType());
4481 if (inputDataShape.hasRank()) {
4482 N = inputDataShape.getDimSize(0);
4483 IH = inputDataShape.getDimSize(1);
4484 IW = inputDataShape.getDimSize(2);
4485 IC = inputDataShape.getDimSize(3);
// input_scale is checked on batch, height and width; its dimension 3 gives
// the number of input-channel blocks.
4488 const ShapeAdaptor inputScaleShape(getInputScale().
getType());
4489 if (inputScaleShape.hasRank()) {
4491 "input_scale",
"batch size")) ||
4493 "input_scale",
"input height")) ||
4495 "input_scale",
"input width")))
4497 multiplesOfIC = inputScaleShape.getDimSize(3);
// weight_data dimensions 0..2 are read as [OC, KH, KW].
4500 const ShapeAdaptor weightDataShape(getWeightData().
getType());
4501 if (weightDataShape.hasRank()) {
4502 OC = weightDataShape.getDimSize(0);
4503 KH = weightDataShape.getDimSize(1);
4504 KW = weightDataShape.getDimSize(2);
4506 "weight_data",
"input channels")))
4510 const ShapeAdaptor weightScaleShape(getWeightScale().
getType());
4511 if (weightScaleShape.hasRank()) {
4513 "weight_scale",
"output channels")) ||
4515 "weight_scale",
"kernel height")) ||
4517 "weight_scale",
"kernel width")) ||
4519 weightScaleShape.getDimSize(3),
4520 "weight_scale",
"input channel blocks")))
// Only block size 32 is accepted, and a static IC must be a multiple of it.
4524 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(
getBlockSize());
4525 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4526 return emitOpError(
"expect block size to be 32, got ") << blockSize;
4528 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
4529 return emitOpError(
"expect IC to be a multiple of block size, got IC=")
4530 << IC <<
", block_size=" << blockSize;
// NOTE(review): multiplesOfIC is read from dimension 3 of input_scale above,
// while the message below says "dimension 2" - confirm the intended wording.
4533 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
4534 multiplesOfIC != IC / blockSize)
4536 "expect scale operands dimension 2 to equal IC/block_size (")
4537 << IC <<
"/" << blockSize <<
")"
4538 <<
", got " << multiplesOfIC;
// Padding must be non-negative; stride and dilation must be at least 1.
4541 SmallVector<int64_t> padValues;
4543 if (llvm::any_of(padValues, [](int64_t p) {
return p < 0; }))
4544 return emitOpError(
"expect all padding values to be >= 0, got ")
4548 SmallVector<int64_t> strideValues;
4550 if (llvm::any_of(strideValues, [](int64_t s) {
return s < 1; }))
4551 return emitOpError(
"expect all stride values to be >= 1, got ")
4555 SmallVector<int64_t> dilationValues;
4558 if (llvm::any_of(dilationValues, [](int64_t d) {
return d < 1; }))
4559 return emitOpError(
"expect all dilation values to be >= 1, got ")
// Output spatial sizes are only checked when pad, stride and dilation values
// are all available and the output is ranked.
4564 const ShapeAdaptor outputShape(getOutput().
getType());
4565 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
4566 outputShape.hasRank()) {
4568 padValues[0], padValues[1], strideValues[0],
4569 dilationValues[0],
"height",
"y",
"top",
4572 padValues[2], padValues[3], strideValues[1],
4573 dilationValues[1],
"width",
"x",
"left",
// A static bias channel count must equal the static output channel count
// (the last output dimension) or be 1.
4579 const ShapeAdaptor biasShape(getBias().
getType());
4580 if (biasShape.hasRank() && outputShape.hasRank()) {
4581 const int64_t biasChannels = biasShape.getDimSize(0);
4582 const int64_t outputChannels =
4583 outputShape.getDimSize(outputShape.getRank() - 1);
4584 if (biasChannels == ShapedType::kDynamic ||
4585 outputChannels == ShapedType::kDynamic)
4589 if (biasChannels != outputChannels && biasChannels != 1)
4591 "bias channels expected to be equal to output channels (")
4592 << outputChannels <<
") or 1, got " << biasChannels;
4598LogicalResult Conv3DOp::inferReturnTypeComponents(
4599 MLIRContext *context, ::std::optional<Location> location,
4600 Conv3DOp::Adaptor adaptor,
4601 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4605LogicalResult Conv3DOp::verify() {
4612LogicalResult AvgPool2dOp::inferReturnTypeComponents(
4613 MLIRContext *context, ::std::optional<Location> location,
4614 AvgPool2dOp::Adaptor adaptor,
4615 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4616 ShapeAdaptor inputShape(adaptor.getInput().getType());
4617 const Properties &prop = adaptor.getProperties();
4619 inferredReturnShapes);
4622LogicalResult AvgPool2dAdaptiveOp::inferReturnTypeComponents(
4623 MLIRContext *context, ::std::optional<Location> location,
4624 AvgPool2dAdaptiveOp::Adaptor adaptor,
4625 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4626 ShapeAdaptor inputShape(adaptor.getInput().getType());
4628 llvm::SmallVector<int64_t> kernelValues;
4629 llvm::SmallVector<int64_t> strideValues;
4630 llvm::SmallVector<int64_t> padValues;
4637 padValues, inferredReturnShapes);
4640 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4641 if (inputShape.hasRank()) {
4643 outputShape[0] = inputShape.getDimSize(0);
4644 outputShape[3] = inputShape.getDimSize(3);
4647 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4651LogicalResult MaxPool2dOp::inferReturnTypeComponents(
4652 MLIRContext *context, ::std::optional<Location> location,
4653 MaxPool2dOp::Adaptor adaptor,
4654 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4655 ShapeAdaptor inputShape(adaptor.getInput().getType());
4656 const Properties &prop = adaptor.getProperties();
4658 inferredReturnShapes);
4661LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
4662 MLIRContext *context, ::std::optional<Location> location,
4663 MaxPool2dAdaptiveOp::Adaptor adaptor,
4664 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4665 ShapeAdaptor inputShape(adaptor.getInput().getType());
4667 llvm::SmallVector<int64_t> kernelValues;
4668 llvm::SmallVector<int64_t> strideValues;
4669 llvm::SmallVector<int64_t> padValues;
4676 padValues, inferredReturnShapes);
4679 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4680 if (inputShape.hasRank()) {
4681 outputShape[0] = inputShape.getDimSize(0);
4682 outputShape[3] = inputShape.getDimSize(3);
4684 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4688LogicalResult MaxPool2dOp::verify() {
4699LogicalResult MaxPool2dAdaptiveOp::verify() {
4704 AdaptivePoolingConstShapeValues values;
4708 values.pad, getInput(), getOutput())))
4714LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
4715 MLIRContext *context, ::std::optional<Location> location,
4716 DepthwiseConv2DOp::Adaptor adaptor,
4717 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4718 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4720 int64_t inputWidth = ShapedType::kDynamic;
4721 int64_t inputHeight = ShapedType::kDynamic;
4722 int64_t inputChannels = ShapedType::kDynamic;
4724 int64_t weightWidth = ShapedType::kDynamic;
4725 int64_t weightHeight = ShapedType::kDynamic;
4726 int64_t depthChannels = ShapedType::kDynamic;
4729 ShapeAdaptor inputShape(adaptor.getInput().getType());
4730 if (inputShape.hasRank()) {
4731 outputShape[0] = inputShape.getDimSize(0);
4732 inputHeight = inputShape.getDimSize(1);
4733 inputWidth = inputShape.getDimSize(2);
4734 inputChannels = inputShape.getDimSize(3);
4738 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4739 if (weightShape.hasRank()) {
4740 weightHeight = weightShape.getDimSize(0);
4741 weightWidth = weightShape.getDimSize(1);
4742 inputChannels = ShapedType::isDynamic(inputChannels)
4743 ? weightShape.getDimSize(2)
4745 depthChannels = weightShape.getDimSize(3);
4750 if (ShapedType::isStatic(inputChannels) &&
4751 ShapedType::isStatic(depthChannels)) {
4752 outputShape[3] = inputChannels * depthChannels;
4756 ShapeAdaptor biasShape(adaptor.getBias().getType());
4757 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4758 int64_t bc = biasShape.getDimSize(0);
4759 if (bc != ShapedType::kDynamic && bc != 1)
4760 outputShape[3] = bc;
4763 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4764 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4765 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4767 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4768 int64_t inputSize = inputHeight + padding[0] + padding[1];
4769 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4770 int64_t unstridedResult = inputSize - filterSize + 1;
4771 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4774 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4775 int64_t inputSize = inputWidth + padding[2] + padding[3];
4776 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4777 int64_t unstridedResult = inputSize - filterSize + 1;
4778 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4781 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4785LogicalResult DepthwiseConv2DOp::verify() {
4792LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4793 MLIRContext *context, ::std::optional<Location> location,
4794 TransposeConv2DOp::Adaptor adaptor,
4795 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4796 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4798 int64_t inputWidth = ShapedType::kDynamic;
4799 int64_t inputHeight = ShapedType::kDynamic;
4800 int64_t weightWidth = ShapedType::kDynamic;
4801 int64_t weightHeight = ShapedType::kDynamic;
4804 ShapeAdaptor inputShape(adaptor.getInput().getType());
4805 if (inputShape.hasRank()) {
4806 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4807 ? inputShape.getDimSize(0)
4809 inputHeight = inputShape.getDimSize(1);
4810 inputWidth = inputShape.getDimSize(2);
4814 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4815 if (weightShape.hasRank()) {
4816 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4817 ? weightShape.getDimSize(0)
4819 weightHeight = weightShape.getDimSize(1);
4820 weightWidth = weightShape.getDimSize(2);
4824 ShapeAdaptor biasShape(adaptor.getBias().getType());
4825 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4826 int64_t bc = biasShape.getDimSize(0);
4827 if (bc != ShapedType::kDynamic && bc != 1)
4828 outputShape[3] = bc;
4831 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4832 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4834 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4835 int64_t calculateSize =
4836 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4838 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4841 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4842 int64_t calculateSize =
4843 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4845 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4848 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4852LogicalResult TransposeConv2DOp::verify() {
4856 const llvm::ArrayRef<int64_t> strides = getStride();
4857 const int64_t strideY = strides[0];
4858 const int64_t strideX = strides[1];
4860 if (strideY < 1 || strideX < 1)
4861 return emitOpError(
"expect all stride values to be >= 1, got [")
4864 const auto checkPadAgainstKernelDim =
4865 [
this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4866 llvm::StringRef kernelDimName) -> LogicalResult {
4867 if (padValue <= -kernelDimSize)
4869 << padName <<
" > -" << kernelDimName <<
", but got: " << padName
4870 <<
"=" << padValue <<
" and " << kernelDimName <<
"="
4875 const llvm::ArrayRef<int64_t> padding = getOutPad();
4876 const int64_t outPadTop = padding[0];
4877 const int64_t outPadBottom = padding[1];
4878 const int64_t outPadLeft = padding[2];
4879 const int64_t outPadRight = padding[3];
4881 const auto weightType =
4882 llvm::dyn_cast<RankedTensorType>(getWeight().
getType());
4885 const int64_t kernelHeight = weightType.getDimSize(1);
4886 if (ShapedType::isStatic(kernelHeight)) {
4887 if (
failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4888 "out_pad_top",
"KH")))
4891 if (
failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4892 "out_pad_bottom",
"KH")))
4896 const int64_t kernelWidth = weightType.getDimSize(2);
4897 if (ShapedType::isStatic(kernelWidth)) {
4898 if (
failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4899 "out_pad_left",
"KW")))
4902 if (
failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4903 "out_pad_right",
"KW")))
4909 const auto outputType =
4910 llvm::dyn_cast<RankedTensorType>(getOutput().
getType());
4914 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
4915 if (inputType && weightType) {
4916 const int64_t inputHeight = inputType.getDimSize(1);
4917 const int64_t kernelHeight = weightType.getDimSize(1);
4918 const int64_t outputHeight = outputType.getDimSize(1);
4920 if (ShapedType::isStatic(inputHeight) &&
4921 ShapedType::isStatic(outputHeight)) {
4923 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4925 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4926 "+ out_pad_top + out_pad_bottom + KH, but got ")
4927 << outputHeight <<
" != (" << inputHeight <<
" - 1) * "
4928 << strideY <<
" + " << outPadTop <<
" + " << outPadBottom
4929 <<
" + " << kernelHeight;
4932 const int64_t inputWidth = inputType.getDimSize(2);
4933 const int64_t kernelWidth = weightType.getDimSize(2);
4934 const int64_t outputWidth = outputType.getDimSize(2);
4936 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4938 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4940 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4941 "+ out_pad_left + out_pad_right + KW, but got ")
4942 << outputWidth <<
" != (" << inputWidth <<
" - 1) * " << strideX
4943 <<
" + " << outPadLeft <<
" + " << outPadRight <<
" + "
4948 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().
getType());
4953 const int64_t biasChannels = biasType.getDimSize(0);
4956 if (biasChannels == ShapedType::kDynamic)
4959 const int64_t outputChannels = outputType.getDimSize(3);
4960 if (!ShapedType::isDynamic(outputChannels) &&
4961 biasChannels != outputChannels && biasChannels != 1)
4963 "bias channels expected to be equal to output channels (")
4964 << outputChannels <<
") or 1, got " << biasChannels;
4969LogicalResult RescaleOp::verify() {
4970 const auto inputType = llvm::cast<ShapedType>(getInput().
getType());
4971 auto inputElementType =
4973 if (!mlir::isa<IntegerType>(inputElementType)) {
4974 emitOpError(
"expect input to have integer element type, got ")
4975 << inputElementType;
4979 const auto outputType = llvm::cast<ShapedType>(getOutput().
getType());
4980 auto outputElementType =
4982 if (!mlir::isa<IntegerType>(outputElementType)) {
4983 emitOpError(
"expect output to have integer element type, got ")
4984 << outputElementType;
4996 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4997 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
5000 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
5001 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
5004 const auto multiplierType = llvm::cast<ShapedType>(getMultiplier().
getType());
5006 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
5007 emitOpError(
"expect i32 element type for multiplier for scale32=true, got ")
5008 << multiplierType.getElementType();
5013 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
5015 "expect i16 element type for multiplier for scale32=false, got ")
5016 << multiplierType.getElementType();
5020 if (!inputType.hasRank())
5026 int64_t numChannels = 1;
5027 if (getPerChannel()) {
5028 if (inputType.getRank() < 1) {
5029 emitOpError(
"requires input to be at least rank 1 when per_channel is "
5030 "true, but got rank ")
5031 << inputType.getRank();
5034 numChannels = inputType.getDimSize(inputType.getRank() - 1);
5037 if (outputType.hasRank()) {
5039 getOperation(), outputType, inputType.getShape())))
5043 if (multiplierType.hasRank()) {
5044 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
5046 if (multiplierShape[0] != ShapedType::kDynamic &&
5047 multiplierShape[0] != numChannels) {
5049 << numChannels <<
" } for multiplier input, got { "
5050 << multiplierShape[0] <<
" }";
5055 const auto shiftType = llvm::cast<ShapedType>(getShift().
getType());
5056 if (shiftType.hasRank()) {
5057 ArrayRef<int64_t> shiftShape = shiftType.getShape();
5059 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
5061 << numChannels <<
" } for shift input, got { " << shiftShape[0]
5070LogicalResult RescaleOp::inferReturnTypeComponents(
5071 MLIRContext *context, ::std::optional<Location> location,
5072 RescaleOp::Adaptor adaptor,
5073 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5074 ShapeAdaptor inputShape(adaptor.getInput().getType());
5075 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5079LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
5080 MLIRContext *context, ::std::optional<Location> location,
5081 CastFromBlockScaledOp::Adaptor adaptor,
5082 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5083 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5084 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5088LogicalResult CastFromBlockScaledOp::verify() {
5089 const Type inputDataType = getInputData().getType();
5090 const Type outputDataType = getResult().getType();
5092 return emitOpError() <<
"require compatible shapes for input_data ("
5093 << inputDataType <<
") and " <<
"output_data ("
5094 << outputDataType <<
")";
5096 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5098 if (inputDataShape.
hasRank()) {
5099 const unsigned int blockSize =
5101 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5102 return emitOpError(
"expect block size to be 32, got ") << blockSize;
5103 const int64_t inputDataLastDim =
5105 if (inputDataLastDim % blockSize != 0)
5106 return emitOpError() <<
"expect last dimension of input_data ("
5108 <<
") to be divisible by block_size (" << blockSize
5111 const Type inputScaleType = getInputScale().getType();
5112 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
5114 if (inputScaleShape.
hasRank()) {
5115 SmallVector<int64_t> inputDataDims, inputScaleDims;
5116 inputDataShape.
getDims(inputDataDims);
5117 inputScaleShape.
getDims(inputScaleDims);
5119 if (inputDataDims.size() != inputScaleDims.size() ||
5121 ArrayRef<int64_t>(inputDataDims).drop_back(1),
5122 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
5124 <<
"require compatible shapes for input_data (" << inputDataType
5125 <<
") and " <<
"input_scale (" << inputScaleType
5126 <<
") except for the last dimension";
5128 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
5129 inputScaleDims.back()};
5130 if (ShapedType::isStatic(inputDataLastDim) &&
5133 <<
"expect last dimension of input_scale ("
5134 << inputScaleDims.back()
5135 <<
") to be equal to last dimension of input_data / block_size ("
5136 << inputDataDims.back() / blockSize <<
")";
5143LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
5144 MLIRContext *context, ::std::optional<Location> location,
5145 CastToBlockScaledOp::Adaptor adaptor,
5146 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5147 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5148 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5149 if (!inputShape.hasRank())
5153 SmallVector<int64_t> outputScaleShape;
5154 inputShape.getDims(outputScaleShape);
5155 const int64_t lastDimLoc = inputShape.getRank() - 1;
5156 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
5157 if (ShapedType::isStatic(lastDimSize)) {
5158 const unsigned int blockSize =
5159 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
5160 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
5162 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
5166LogicalResult CastToBlockScaledOp::verify() {
5167 const Type inputDataType = getInputData().getType();
5168 const Type outputDataType = getResult(0).getType();
5170 return emitOpError() <<
"require compatible shapes for input_data ("
5171 << inputDataType <<
") and " <<
"output_data ("
5172 << outputDataType <<
")";
5174 const unsigned int blockSize =
5176 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5177 return emitOpError(
"expect block size to be 32, got ") << blockSize;
5178 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5179 if (inputDataShape.
hasRank()) {
5180 const int64_t inputDataLastDim =
5182 if (ShapedType::isStatic(inputDataLastDim) &&
5183 inputDataLastDim % blockSize != 0)
5184 return emitOpError() <<
"expect last dimension of input_data ("
5186 <<
") to be divisible by block_size (" << blockSize
5190 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
5191 const Type outputScaleType = getResult(1).getType();
5192 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
5194 SmallVector<int64_t> outputDataDims, outputScaleDims;
5195 outputDataShape.
getDims(outputDataDims);
5196 outputScaleShape.
getDims(outputScaleDims);
5198 if (outputDataDims.size() != outputScaleDims.size() ||
5200 ArrayRef<int64_t>(outputDataDims).drop_back(1),
5201 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
5202 return emitOpError() <<
"require compatible shapes for output_data ("
5203 << outputDataType <<
") and " <<
"output_scale ("
5205 <<
") except for the last dimension";
5207 const int64_t outputDataLastDim = outputDataDims.back();
5208 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
5209 outputScaleDims.back()};
5210 if (ShapedType::isStatic(outputDataLastDim) &&
5213 <<
"expect last dimension of output_scale ("
5214 << outputScaleDims.back()
5215 <<
") to be equal to last dimension of output_data / block_size ("
5216 << outputDataDims.back() / blockSize <<
")";
5222LogicalResult IfOp::inferReturnTypeComponents(
5223 MLIRContext *context, ::std::optional<Location> location,
5224 IfOp::Adaptor adaptor,
5225 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5226 llvm::SmallVector<tosa::YieldOp> yieldOps;
5227 for (Region *region : adaptor.getRegions()) {
5228 for (
auto &block : *region)
5229 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5230 yieldOps.push_back(returnOp);
5233 if (yieldOps.empty())
5237 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5238 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5239 for (
auto operand : yieldOps.front().getOperands()) {
5240 resultKnowledge.push_back(
5244 for (
auto yieldOp : yieldOps) {
5245 if (resultKnowledge.size() != yieldOp.getNumOperands())
5248 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5249 int32_t index = it.index();
5251 resultKnowledge[index],
5255 resultKnowledge[index] = meet;
5259 for (
const ValueKnowledge &
result : resultKnowledge) {
5260 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
5266LogicalResult WhileOp::inferReturnTypeComponents(
5267 MLIRContext *context, ::std::optional<Location> location,
5268 WhileOp::Adaptor adaptor,
5269 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5270 llvm::SmallVector<tosa::YieldOp> yieldOps;
5271 for (
auto &block : adaptor.getBodyGraph())
5272 if (
auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5273 yieldOps.push_back(returnOp);
5277 if (yieldOps.empty())
5281 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5282 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5283 for (
auto operand : yieldOps.front().getOperands()) {
5284 resultKnowledge.push_back(
5288 for (
auto yieldOp : yieldOps) {
5289 if (resultKnowledge.size() != yieldOp.getNumOperands())
5292 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5293 int32_t index = it.index();
5295 resultKnowledge[index],
5297 resultKnowledge[index] = meet;
5302 for (
const ValueKnowledge &
result : resultKnowledge) {
5303 inferredReturnShapes.push_back(
result.getShapedTypeComponents());
5309std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
5310 if (
auto vt = llvm::dyn_cast<VectorType>(
getType()))
5311 return llvm::to_vector<4>(vt.getShape());
5312 return std::nullopt;
5318 StringRef prefix =
"") {
5319 assert(blocksArgs.size() == initializers.size() &&
5320 "expected same length of arguments and initializers");
5321 if (initializers.empty())
5324 parser << prefix <<
'(';
5325 llvm::interleaveComma(
5326 llvm::zip(blocksArgs, initializers), parser,
5327 [&](
auto it) { parser << std::get<0>(it) <<
" = " << std::get<1>(it); });
5332ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
5334 result.regions.reserve(2);
5335 Region *thenRegion =
result.addRegion();
5336 Region *elseRegion =
result.addRegion();
5338 OpAsmParser::UnresolvedOperand cond;
5343 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5344 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5347 OptionalParseResult listResult =
5355 "expected type for condition operand");
5361 "expected type for condition operand");
5369 FunctionType functionType;
5373 <<
"expected list of types for block arguments "
5374 <<
"followed by arrow type and list of return types";
5376 result.addTypes(functionType.getResults());
5378 if (functionType.getNumInputs() != operands.size()) {
5380 <<
"expected as many input types as operands " <<
"(expected "
5381 << operands.size() <<
" got " << functionType.getNumInputs()
5412void IfOp::print(OpAsmPrinter &p) {
5413 p <<
" " << getCondition();
5416 getInputList(),
" ");
5418 p << getCondition().getType();
5420 if (!getInputList().empty()) {
5422 llvm::interleaveComma(getInputList().getTypes(), p);
5431 auto &elseRegion = getElseGraph();
5432 if (!elseRegion.
empty()) {
5440LogicalResult IfOp::verify() {
5442 "'then_graph' arguments", getInputList(),
5448 "'else_graph' arguments", getInputList(),
5454 if (getThenGraph().front().mightHaveTerminator()) {
5456 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
5458 *
this, thenYield.getInputs(),
"'then_graph' results",
5459 getOutputList(),
"'output_list'")
5465 if (getElseGraph().front().mightHaveTerminator()) {
5467 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
5469 *
this, elseYield.getInputs(),
"'else_graph' results",
5470 getOutputList(),
"'output_list'")
5475 auto condType = getCondition().getType();
5477 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
5483LogicalResult WhileOp::verify() {
5485 getOutputList(),
"'output_list'")
5490 "'cond_graph' arguments", getInputList(),
5496 "'body_graph' arguments", getInputList(),
5501 if (getBodyGraph().front().mightHaveTerminator()) {
5503 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
5505 "'body_graph' results",
5506 getInputList(),
"'input_list'")
5513 if (!getCondGraph().front().mightHaveTerminator())
5517 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
5521 if (condYield.getInputs().size() != 1)
5522 return emitOpError() <<
"require 'cond_graph' only have one result";
5524 auto condOutType = condYield.getInputs()[0].getType();
5526 return emitOpError() <<
"'cond_graph' result must be a size 1 tensor, got "
5530 return emitOpError() <<
"'cond_graph' result must be a boolean tensor, got "
5536LogicalResult ReverseOp::verify() {
5537 TensorType inputType = getInput1().getType();
5538 int32_t reverseAxis = getAxis();
5540 if (reverseAxis < 0)
5541 return emitOpError(
"expected non-negative reverse axis");
5543 int64_t inputRank = inputType.getRank();
5546 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
5548 << inputRank <<
") to be larger than reverse axis (" << reverseAxis
5555LogicalResult tosa::SelectOp::verify() {
5566 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().
getType());
5567 if (!predicateType) {
5568 return emitOpError(
"expect shaped tensor for input1, got ")
5569 << getInput1().getType();
5571 auto predicateElementType = predicateType.getElementType();
5572 if (!predicateElementType.isInteger(1)) {
5573 return emitOpError(
"expect element type of bool for input1, got ")
5574 << predicateElementType;
5580LogicalResult tosa::VariableReadOp::verify() {
5588LogicalResult tosa::VariableWriteOp::verify() {
5597ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
5598 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5599 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5600 Region *cond =
result.addRegion();
5601 Region *body =
result.addRegion();
5603 OptionalParseResult listResult =
5608 FunctionType functionType;
5613 result.addTypes(functionType.getResults());
5615 if (functionType.getNumInputs() != operands.size()) {
5617 <<
"expected as many input types as operands " <<
"(expected "
5618 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
5628 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
5629 regionArgs[i].type = functionType.getInput(i);
5631 return failure(parser.
parseRegion(*cond, regionArgs) ||
5636void WhileOp::print(OpAsmPrinter &parser) {
5638 getInputList(),
" ");
5641 getResults().getTypes());
5655 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
5656 if (llvm::isa<FloatType>(srcElemType)) {
5658 zpType, builder.
getFloatAttr(srcElemType,
static_cast<double>(zp)));
5659 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5661 if (llvm::isa<IntegerType>(srcElemType)) {
5664 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5666 llvm::errs() <<
"zero point is not allowed for unsupported data types\n";
5667 return std::nullopt;
5675 return mlir::isa<tosa::shapeType>(t);
5682 return emitError() <<
"invalid rank (must be >= 0): " << rank;
5688 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
5689 Operation *definingOp = v.getDefiningOp();
5691 return op->
emitOpError(
"shape operand is not compile time resolvable");
5704 auto getRank = [](
const Type type) {
5705 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
5711 for (
auto type : operandTypes) {
5712 if (getRank(type) != rank) {
5713 return op->
emitOpError(
"operands don't have matching ranks");
5716 for (
auto type : resultTypes) {
5717 if (getRank(type) != rank) {
5718 return op->
emitOpError(
"result shape has different rank than operands");
5728LogicalResult tosa::ConstShapeOp::verify() {
5730 auto valuesRank = getValues().getType().getRank();
5731 if (valuesRank != 1)
5732 return emitOpError(
"expect elements in attribute values with rank 1");
5734 auto count = getValues().getNumElements();
5735 auto rank = (cast<tosa::shapeType>(getResult().
getType())).getRank();
5736 if (count != rank && (count != 1 || rank != 0)) {
5737 return emitOpError(
"expect number of elements in attribute values (")
5738 << count <<
") to be equal to the rank (" << rank
5739 <<
") for the result shape type";
5744LogicalResult tosa::DimOp::verify() {
5745 const tosa::shapeType outShapeType =
5746 cast<tosa::shapeType>(getResult().
getType());
5747 if (outShapeType.getRank() != 1)
5748 return emitOpError(
"expect output shape type to contain one element, got ")
5753 const int64_t inputRank = inputType.getRank();
5754 const int64_t axis = getAxisAttr().getInt();
5755 if (axis < 0 || axis >= inputRank)
5756 return emitOpError(
"expect axis to be in the range [0, ")
5757 << inputRank <<
"), got " << axis;
5762LogicalResult tosa::ConcatShapeOp::verify() {
5763 const tosa::shapeType outShapeType =
5764 cast<tosa::shapeType>(getResult().
getType());
5765 const int64_t outputRank = outShapeType.getRank();
5768 if (inputList.size() == 0)
5769 return emitOpError(
"requires at least one input shape");
5771 if (llvm::any_of(inputList, [](Value v) {
5772 return cast<tosa::shapeType>(v.
getType()).getRank() == 0;
5774 return emitOpError(
"requires all inputs shapes have a rank greater than 0");
5776 const int64_t inputsRank =
5777 llvm::accumulate(inputList, 0, [](int64_t acc,
const Value &input) {
5778 const tosa::shapeType inShapeType =
5779 cast<tosa::shapeType>(input.
getType());
5780 return acc + inShapeType.getRank();
5782 if (outputRank != inputsRank)
5783 return emitOpError(
"requires output shape rank to be equal to the sum of "
5784 "the input shape ranks (")
5785 << inputsRank <<
"), got " << outputRank;
5790LogicalResult tosa::SliceShapeOp::verify() {
5791 std::optional<int32_t> start;
5792 DenseIntElementsAttr startAttr;
5794 start = startAttr.getValues<int32_t>()[0];
5795 if (start && start.value() < 0)
5796 return emitOpError(
"expected non-negative start index, got ")
5799 std::optional<int32_t> size;
5800 DenseIntElementsAttr sizeAttr;
5802 size = sizeAttr.getValues<int32_t>()[0];
5803 if (size && size.value() <= 0)
5804 return emitOpError(
"expected positive size, got ") << size.value();
5809 const tosa::shapeType outShapeType =
5810 cast<tosa::shapeType>(getResult().
getType());
5811 const int64_t outputRank = outShapeType.getRank();
5812 if (outputRank != size)
5814 "expected output type size to be equal to size attribute, got ")
5815 << outputRank <<
" vs " << size.value();
5820 const tosa::shapeType inShapeType =
5821 cast<tosa::shapeType>(getInput().
getType());
5822 const int64_t inputRank = inShapeType.getRank();
5823 const int64_t sliceSize = start.value() + size.value();
5824 if (sliceSize > inputRank)
5825 return emitOpError(
"expected start + size to be less than or equal to "
5826 "input shape rank (")
5827 << inputRank <<
"), got " << sliceSize;
5836#define GET_ATTRDEF_CLASSES
5837#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5842#define GET_TYPEDEF_CLASSES
5843#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5849#define GET_OP_CLASSES
5850#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
true
Given two iterators into the same block, return "true" if a is before `b.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void printShapeToDiagnostic(InFlightDiagnostic &diag, ArrayRef< int64_t > shape)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildAvgPool2dAdaptiveOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad, TypeAttr accType)
This builder mirrors avg_pool2d quant-info handling and materializes kernel/stride/pad as const_shape...
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static LogicalResult verifyPoolingOpImpl(Operation *op, ArrayRef< int64_t > kernel, ArrayRef< int64_t > strides, ArrayRef< int64_t > padding, Value input, Value output)
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void extractAdaptivePoolingConstShapeOperands(T op, AdaptivePoolingConstShapeValues &values)
static LogicalResult verifyConvOpErrorIf(T op)
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static Type getStorageElementTypeOrSelf(Type type)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static LogicalResult verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType, ArrayRef< int64_t > expectedShape, StringRef outputName="output")
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static void updateIfDynamic(int64_t &current, int64_t candidate)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
int64_t getOutputRank() const
int64_t getNumSpatialDims() const
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
int64_t getNumSpatialDims() const
int64_t getOutputRank() const
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
int64_t getNumSpatialDims() const
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
int64_t getOutputRank() const
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class indicates that op operates on tosa shape types.
Operation is the basic unit of execution within MLIR.
ResultRange result_range
Support result iteration.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperandRange operand_range
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ArrayRef< T > asArrayRef() const
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
FailureOr< T > getConstantScalarIntValue(Value val)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)