#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));

  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
    return ::writeAttribute(attr, writer);

  LogicalResult writeType(Type type,
    return ::writeType(type, writer);

  std::unique_ptr<DialectVersion>
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,

void TosaDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
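
// Constant materialization for the dialect (below): a DenseIntElementsAttr
// with a tosa.shape result type is materialized as tosa.const_shape, and any
// other ElementsAttr as tosa.const.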
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));

  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));

           << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {

  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
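
// verifyConvOp (below) is shared by the convolution-style ops: input and
// weight must be ranked tensors, quantized element types are compared through
// their storage types, and integer zero points may only be nonzero for int8
// operands (int8 operands in turn require explicit zero points).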
template <typename T>
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();

  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
        "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
          "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;

  if (!op.getInputZp() && !op.getWeightZp())
    return inputEType.isInteger(8) ? failure() : success();

  ElementsAttr inputZpAttr;
  ElementsAttr weightZpAttr;
        "bail out if the actual value of zero points cannot be determined");
      op.emitOpError("input zero point must be zero for non-int8 integer types");
      op.emitOpError("weight zero point must be zero for non-int8 integer types");

  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");

template <typename T>
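// verifyConvOpModes: the accumulator type must follow the input element type
// (i8 -> i32, i16 -> i48, f8 -> f16, f16 -> f16 or f32, bf16 -> f32,
// f32 -> f32), and the input/result element types must form one of the
// pairings checked below.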
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
      (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
      (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
      (inputEType.isF16() && resultEType.isF16()) ||
      (inputEType.isBF16() && resultEType.isBF16()) ||
      (inputEType.isF32() && resultEType.isF32()))

  return op.emitOpError("input/output element types are incompatible.");

template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;

  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  auto inputType = llvm::cast<ShapedType>(getInput().getType());

  auto inputETy = inputType.getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if ((inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16)))

  return emitOpError("input/output element types are incompatible.");

      llvm::cast<ShapedType>(getInput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

                            static_cast<int32_t>(quantAttr.getAZp())));
                            static_cast<int32_t>(quantAttr.getBZp())));

  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
  assert(inputType && "Input must be a shaped tensor type!");

  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
      inputType.getElementType());
  assert(inputQType && "Tensor must have quantized datatype!");

  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  auto accType = outputShapedType.clone(accElementType);

                                     DenseArrayAttr kernel, DenseArrayAttr stride,
                                     DenseArrayAttr pad, TypeAttr accType) {

                            static_cast<int32_t>(quantAttr.getInputZp())));
                            static_cast<int32_t>(quantAttr.getOutputZp())));
  result.types.push_back(outputType);

                            static_cast<int32_t>(quantAttr.getInputZp())));
                            static_cast<int32_t>(quantAttr.getOutputZp())));
  result.types.push_back(outputType);

                            static_cast<int32_t>(quantAttr.getInputZp())));
  result.types.push_back(outputType);

                            static_cast<int32_t>(quantAttr.getInputZp())));
  result.types.push_back(outputType);
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {
    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;
      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {
      outShape[i + rankDiff] = resolvedDim;

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();
  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outShape.push_back(inputShape.getDimSize(i));

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
            "Cannot concat tensors with different sizes"
            " on the non-axis dimension ",

    hasRankedInput = true;

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;

LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);
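
// PadOp shape inference (below): every static output dimension is
// inputDim + padFront + padBack; a dynamic input dimension or a negative
// (not-yet-resolved) padding value yields a dynamic output dimension.
// For example, padding a size-3 dim by 1 on each side infers size 5.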
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);

  RankedTensorType inputType = getInput1().getType();
  RankedTensorType outputType = getOutput().getType();
  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError() << "expect same input and output tensor rank.";

  if (paddingRank != inputType.getRank() * 2)
    return emitOpError() << "expected padding tensor dim 0 to have size "
                         << inputType.getRank() * 2
                         << " (2*rank(shape1)) but got size " << paddingRank;

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];

  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());

  auto startShapeRank =
      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
  if (inputType.getRank() != startShapeRank)
        "length of start attribute is not equal rank of input shape");

  auto sizeShapeRank =
      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
  if (inputType.getRank() != sizeShapeRank)
        "length of size attribute is not equal rank of input shape");

LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

  for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

  ElementsAttr shift_elem;
    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
      return emitOpError() << "require shift to be 0 for float type";
  auto hasRank = [](const Type type) {
    if (auto shaped_type = dyn_cast<ShapedType>(type))
      return shaped_type.hasRank();

  auto rankedOperandTypes =
      llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));

  auto rankedResultTypes =
      llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);

  if (rankedOperandTypes.empty() && rankedResultTypes.empty())

  auto getRank = [](const Type type) {
    return cast<ShapedType>(type).getRank();

  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
                                          : getRank(*rankedResultTypes.begin());

  for (size_t i = 0; i < 2; ++i) {
    if (rank != getRank(rankedOperandTypes[i])) {
      return emitOpError("operands don't have matching ranks");

  for (const auto type : rankedResultTypes) {
    if (rank != getRank(type)) {
      return emitOpError("result type has different rank than operands");

    return mlir::cast<ShapedType>(type).getShape();

    return emitOpError("operands don't have broadcast-compatible shapes");

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;

  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int64_t dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamic)
      dim *= multiples[i];
    outputShape.push_back(dim);

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");

  if (l.size() != r.size() || l.size() != 1)
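
// ReshapeOp shape inference (below): when the input shape is static, a
// single -1 (dynamic) entry in the new shape is resolved as
// numElements / product(static dims); e.g. reshaping 6 elements with
// shape {2, -1} infers {2, 3}.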
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();
  RankedTensorType outputType = getType();

    return mlir::success();

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  return mlir::success();

  for (auto v : permsAttr.getValues<APInt>())
    perms.push_back(v.getSExtValue());

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (permsShape.hasRank() && permsShape.getRank() == 0)

  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {

  if (permsShape.getDimSize(0) != inputShape.getRank()) {

  if (inputShape.getRank() == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {

    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

      attr.getType().getRank() == 1) {
    if (inputShape.getRank() != permShape.getRank())
          "constant permutation must be the same length"
          " as the input rank");

    for (int i = 0, e = inputShape.getRank(); i < e; i++) {
      if (inputShape.getRank() <= permShape.getDimSize(i))

    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

  if (permType.hasRank() && permType.getRank() != 1)
    return emitOpError()
           << "expected permutation tensor to be rank 1 but got rank "
           << permType.getRank();

  if (!permType.isDynamicDim(0) &&
      permType.getDimSize(0) != inputType.getRank())
    return emitOpError() << "expected permutation tensor dim 0 to have size "
                         << inputType.getRank()
                         << " (input rank) but got size "
                         << permType.getDimSize(0);

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (!permType.isDynamicDim(0) &&
      permType.getDimSize(0) != outputType.getRank())
    return emitOpError() << "expected permutation tensor dim 0 to have size "
                         << outputType.getRank()
                         << " (output rank) but got size "
                         << permType.getDimSize(0);

  if (succeeded(getConstantPerms(constantPerms))) {
           "Unexpectedly found permutation tensor without rank");
    if (!llvm::all_of(constantPerms,
                      [&constantPerms](int32_t s) {
                        static_cast<size_t>(s) < constantPerms.size();
            constantPerms, [](int32_t v) -> int64_t { return v; }))))
      return emitOpError() << "expected valid permutation tensor";

    assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
           inputType.getRank() == outputType.getRank());

    for (auto i = 0; i < outputType.getRank(); i++) {
      if (inputType.isDynamicDim(constantPerms[i]) ||
          outputType.isDynamicDim(i))

      if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputType.getDimSize(constantPerms[i]);

  if (getConstantPerms(transposePerms).failed())

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
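
// ResizeOp shape inference (below) keeps the batch and channel dimensions and
// derives each spatial dimension from the scale/offset/border operands,
// roughly ((inDim - 1) * scale_n - offset + border) / scale_d + 1 per the
// TOSA resize definition.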
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);

  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);

  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {

  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;

#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER

#undef COMPATIBLE_RETURN_TYPES

template <typename T>
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");

  int64_t inputRank = inputType.getRank();

  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
        "expect output tensor rank to be equal to input tensor rank");

  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis

  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \

#undef PRED_SHAPE_INFER
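
// poolingInferReturnTypes (below): each spatial output size is
// (in + padBefore + padAfter - kernel) / stride + 1, e.g. height 8 with
// pad {1, 1}, kernel 3 and stride 2 gives (8 + 1 + 1 - 3) / 2 + 1 = 4.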
  outputShape.resize(4, ShapedType::kDynamic);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
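
// The Conv2D/Conv3D/DepthwiseConv2D inference below uses the standard
// convolution arithmetic: with effective filter size
// (kernel - 1) * dilation + 1, each spatial output size is
// (in + padBefore + padAfter - effectiveFilter) / stride + 1; e.g. in = 8,
// kernel = 3, dilation = 1, pad = 1 + 1, stride = 2 gives (8 + 2 - 3) / 2 + 1 = 4.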
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);

  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);

  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
                                 inferredReturnShapes);

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
                                 inferredReturnShapes);

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
    depthChannels = weightShape.getDimSize(3);

  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;

  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
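
// TransposeConv2D inference (below) runs the convolution arithmetic in
// reverse: each spatial output size is
// (in - 1) * stride + padBefore + padAfter + kernel, used only when the
// corresponding output dimension is not already static.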
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);

  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);

  if (yieldOps.empty())

  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      int32_t index = it.index();
          resultKnowledge[index],
        resultKnowledge[index] = meet;

    inferredReturnShapes.push_back(result.getShapedTypeComponents());

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
  for (auto &block : adaptor.getBody())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())

  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      int32_t index = it.index();
          resultKnowledge[index],
        resultKnowledge[index] = meet;

    inferredReturnShapes.push_back(result.getShapedTypeComponents());

std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;

  bool printBlockTerminators = false;

  p << " " << getCond();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;

                printBlockTerminators);

  auto &elseRegion = getElseBranch();
  if (!elseRegion.empty()) {
                  printBlockTerminators);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();

  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  FunctionType functionType;

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";

  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||

                                     StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  Type zpElemType = zpAttr.getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
    zp = quantType.getZeroPoint();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (!zpAttr.getValues<APFloat>()[0].isZero())

  if (llvm::isa<IntegerType>(zpElemType)) {
    zp = zpAttr.getValues<APInt>()[0].getSExtValue();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
    srcElemType = quantType.getStorageType();

  if (llvm::isa<FloatType>(srcElemType)) {
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);

  if (llvm::isa<IntegerType>(srcElemType)) {
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);

  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
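
// The tosa.shape helpers below check that every tosa.shape operand is
// produced by a compile-time resolvable defining op and that all shape-typed
// operands and results of an op share the same rank.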
  return mlir::isa<tosa::shapeType>(t);

    return emitError() << "invalid rank (must be >= 0): " << rank;

    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
        return op->emitOpError("shape operand is not compile time resolvable");

    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");

    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");

  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");

  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");

  auto valuesRank = getValue().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute value with rank 1");

  auto count = getValue().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute value (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static void buildUnaryOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter unary operators that have scale relationship between their...
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings, Value padConst)
This builder is called on TOSA pad operator when an explicit pad_const value is passed in.
static LogicalResult verifyConvOpModes(T op)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr but avg_pool operator has...
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
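The matching printer side, again as an assumed sketch rather than code from this file (printExampleOp is a hypothetical name):
static void printExampleOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperand(0);
  p.printOptionalAttrDict(op->getAttrs());
  p << " : ";
  // Print "(operand types) -> (result types)" for the whole op.
  p.printFunctionalType(op);
}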
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through APIs.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
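These Operation queries are the usual building blocks of verifier-style checks; a small assumed example (verifyAllResultsRanked is illustrative, not part of this file):
static LogicalResult verifyAllResultsRanked(Operation *op) {
  for (Type t : op->getResultTypes())
    if (!llvm::isa<RankedTensorType>(t))
      return op->emitOpError("expected ranked tensor results, got ") << t;
  return success();
}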
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeComponents.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
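A short sketch of how a ShapeAdaptor is commonly consumed during shape inference; collectStaticDims is an assumed helper name:
static void collectStaticDims(ShapeAdaptor shape, SmallVectorImpl<int64_t> &dims) {
  if (!shape.hasRank())
    return; // Unranked shape: nothing static to record.
  for (int64_t i = 0, e = shape.getRank(); i < e; ++i)
    dims.push_back(shape.getDimSize(i));
  // shape.getDims(dims) would populate the same values in one call.
}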
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperator(Operation *op)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcastable.
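For example (an assumed standalone snippet; getBroadcastedShape lives in the OpTrait::util namespace):
SmallVector<int64_t> lhsShape = {1, 3}, rhsShape = {2, 1}, resultShape;
if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) {
  // resultShape now holds {2, 3}; a false return means the shapes do not broadcast.
}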
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder.
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with correct bitwidth based on input/weight width.
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &attr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint, bZp: input b zeropoint.
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr, called from UnaryOpQuantInfoBuilder: inputZp: input zeropoint, outputZp: output zeropoint.
bool getConstShapeValue(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
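getElementTypeOrSelf and verifyCompatibleShape are often paired when comparing tensor values; a hedged sketch (checkCompatible is an assumed name, and only ranked tensors are handled):
static LogicalResult checkCompatible(Value a, Value b) {
  auto aTy = llvm::dyn_cast<RankedTensorType>(a.getType());
  auto bTy = llvm::dyn_cast<RankedTensorType>(b.getType());
  if (!aTy || !bTy)
    return failure(); // Only ranked tensors are compared in this sketch.
  if (getElementTypeOrSelf(aTy) != getElementTypeOrSelf(bTy))
    return failure(); // Element types must match exactly here.
  return verifyCompatibleShape(aTy.getShape(), bTy.getShape());
}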
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
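Typical use combines m_Constant with matchPattern from above; 'value' here is an assumed SSA Value:
DenseElementsAttr elems;
if (matchPattern(value, m_Constant(&elems))) {
  // 'elems' now holds the payload of the constant op that defines 'value'.
}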
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
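A sketch of assembling an operation generically through OperationState and OpBuilder::create; the op name, attribute name, and the variables builder, loc, inputValue, and resultType are assumptions for illustration:
OperationState state(loc, "dialect.example_op"); // Hypothetical op name.
state.addOperands(inputValue);
state.addAttribute("axis", builder.getI32IntegerAttr(0)); // Hypothetical attribute.
state.addTypes(resultType);
Operation *newOp = builder.create(state); // Materialize the op from the assembled state.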
Statically known information for a particular Value.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)
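A small assumed example of how these two ValueKnowledge entry points fit together; valueA and valueB are placeholder Values, and the snippet assumes the surrounding mlir::tosa namespace:
ValueKnowledge lhs = ValueKnowledge::getKnowledgeFromType(valueA.getType());
ValueKnowledge rhs = ValueKnowledge::getKnowledgeFromType(valueB.getType());
// Combine the statically known facts (rank, sizes, element type) about the two values.
ValueKnowledge merged = ValueKnowledge::meet(lhs, rhs);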