29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
  return {&getBodyGraph()};

void TosaDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));

           << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {

  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {

                               Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
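
// Note (added for clarity, not part of the original source): idivCheck below
// is a small division helper used by the convolution, pooling and resize
// verifiers later in this file. As used there, it is expected to yield
// lhs / rhs only when rhs divides lhs exactly, and std::nullopt otherwise, so
// callers can emit the "wholly divisible" diagnostics.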
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(elementType))
    elementType = quantType.getStorageType();

                                     Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

template <typename T>
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");

template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
      (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
      (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
      (inputEType.isF16() && resultEType.isF16()) ||
      (inputEType.isBF16() && resultEType.isBF16()) ||
      (inputEType.isF32() && resultEType.isF32()))

  return op.emitOpError("input/output element types are incompatible.");

template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;

  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (axis < 0 || axis >= inputType.getRank())
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())

      llvm::cast<ShapedType>(getInput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

  result.addOperands({input, weight, bias, zps.first, zps.second});
  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});
  Type finalOutputType = outputType;

  Type finalOutputType{outputType};
  auto inputBits = eType.getIntOrFloatBitWidth();
  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");
  IntegerType accElementType;
  finalOutputType = outputShapedType.clone(accElementType);

                          DenseArrayAttr kernel, DenseArrayAttr stride,
                          DenseArrayAttr pad, TypeAttr accType) {
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.types.push_back(outputType);
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {
    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {
      outShape[i + rankDiff] = resolvedDim;
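
  // Worked example (added, not in the original source): broadcasting shapes
  // [3, 1, 5] and [4, 1] right-aligns the shorter shape via rankDiff, pairing
  // it as [1, 4, 1], then resolves each dimension pair to give [3, 4, 5]; the
  // dim1 != dim2 branch (its body is elided here) covers the mismatched
  // non-unit case, e.g. 3 vs. 4 in the same position.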
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    outShape.push_back(inputShape.getDimSize(i));
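
  // Example (added for illustration; the axis-skip inside the loop is elided
  // here): argmax drops the reduced axis, as the reserve of rank - 1 above
  // suggests, so an argmax over axis 1 of a ?x3x4 input infers a ?x4 output.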
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

                                        const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (!ShapedType::isDynamic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))

      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
                             "Cannot concat tensors with different sizes"
                             " on the non-axis dimension ",

    hasRankedInput = true;

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {

    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;

    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
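
  // Example (added, not from the original source): concatenating 2x3 and 2x5
  // tensors on axis 1 accumulates 3 + 5 into concatDimSize, inferring a 2x8
  // result; non-axis dimensions must agree, and any unranked or dynamic-axis
  // operand makes the concatenated dimension dynamic.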
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
                 "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))

        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
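
  // Shape summary (added for illustration): TOSA matmul is batched, so for a
  // of shape [N, H, C] and b of shape [N, C, W] the inferred output is
  // [N, H, W]; dims 0 and 1 come from the lhs when it is ranked, and dims 0
  // and 2 fall back to the rhs otherwise.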
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);

  if (auto padConst = getPadConst()) {

  RankedTensorType inputType = getInput1().getType();
  RankedTensorType outputType = getOutput().getType();
  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError() << "expect same input and output tensor rank.";

  if (paddingRank != inputType.getRank() * 2)
    return emitOpError() << "expected padding tensor dim 0 to have size "
                         << inputType.getRank() * 2
                         << " (2*rank(shape1)) but got size " << paddingRank;
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {

        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];

  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());

  auto startShapeRank =
      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
  if (inputType.getRank() != startShapeRank)
    return emitOpError("length of start is not equal to rank of input shape");

  auto sizeShapeRank =
      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
  if (inputType.getRank() != sizeShapeRank)
    return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

  for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

  ElementsAttr shift_elem;
    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
      return emitOpError() << "require shift to be 0 for float type";

  auto hasRank = [](const Type type) {
    if (auto shaped_type = dyn_cast<ShapedType>(type))
      return shaped_type.hasRank();

  auto rankedOperandTypes =
      llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));

  auto rankedResultTypes =
      llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);

  if (rankedOperandTypes.empty() && rankedResultTypes.empty())

  auto getRank = [](const Type type) {
    return cast<ShapedType>(type).getRank();

  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
                                          : getRank(*rankedResultTypes.begin());

  for (size_t i = 0; i < 2; ++i) {
    if (rank != getRank(rankedOperandTypes[i])) {
      return emitOpError("operands don't have matching ranks");

  for (const auto type : rankedResultTypes) {
    if (rank != getRank(type)) {
      return emitOpError("result type has different rank than operands");

    return mlir::cast<ShapedType>(type).getShape();

    return emitOpError("operands don't have broadcast-compatible shapes");

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int64_t dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamic)
      dim *= multiples[i];
    outputShape.push_back(dim);

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();
  RankedTensorType outputType = getType();

    return mlir::success();

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  return mlir::success();
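
// Example (added, not from the original source): reshaping a 2x3x4 input
// (24 elements) with a new shape of [6, -1] infers the -1 dimension as
// 24 / 6 = 4 in the shape inference above; the verifier allows at most one
// -1 entry, and a fully static new shape must multiply out to exactly the
// input element count.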
template <typename T>
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
    return zpAttr.getValues<APInt>()[0].getSExtValue();

template <typename T>
                                const std::string &operand) {

  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;

  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(*this, get##OPERAND_NAME##Zp());                      \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

#undef ZERO_POINT_HELPER
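
// Expansion sketch (added; the instantiation list between the macro and the
// #undef is elided here): an instantiation such as
// ZERO_POINT_HELPER(MatMulOp, A) would generate
// tosa::MatMulOp::getAZeroPoint(), forwarding to getZeroPoint(*this, getAZp()),
// and tosa::MatMulOp::verifyAZeroPoint(int64_t), forwarding to
// verifyZeroPoint(); that is how the getAZeroPoint()/verifyAZeroPoint() calls
// used in MatMulOp::verify() earlier in this file are provided.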
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

      constantPerms.size() != static_cast<size_t>(inputType.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputType.getRank() << " (input rank) but got size "
                         << constantPerms.size();

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

      constantPerms.size() != static_cast<size_t>(outputType.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputType.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();
  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::to_vector(llvm::map_range(
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";
  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
         inputType.getRank() == outputType.getRank());

  for (auto i = 0; i < outputType.getRank(); i++) {
    if (inputType.isDynamicDim(constantPerms[i]) ||
        outputType.isDynamicDim(i))

    if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
      return emitOpError()
             << "expected output tensor dim " << i << " to match "
             << "input dim " << constantPerms[i] << " with value of "
             << inputType.getDimSize(constantPerms[i]);

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
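
  // Shape summary (added for illustration): gather produces [N, W, C], where N
  // and C come from the values operand ([N, K, C]) and W from the indices
  // operand ([N, W]); each operand fills in whichever of those dims the other
  // left dynamic.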
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

    return emitOpError("expect a ranked input tensor");
    return emitOpError("expect a ranked output tensor");

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);

  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {

  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;

#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER
#undef COMPATIBLE_RETURN_TYPES
template <typename T>
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");

  int64_t inputRank = inputType.getRank();

  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
        "expect output tensor rank to be equal to input tensor rank");

  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis

  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \

#undef PRED_SHAPE_INFER

LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1ZpEType =
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;

  const Type outputZpEType =
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())

  outputShape.resize(4, ShapedType::kDynamic);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);

  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
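
  // Illustrative numbers (added, not from the original source): for
  // inputHeight = 8, pad top/bottom = 1/1, a 3x3 kernel with dilation 1 and
  // stride 2, the code above computes inputSize = 10, filterSize = 3,
  // unstridedResult = 8, and an output height of (8 - 1) / 2 + 1 = 4.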
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return emitOpError("expect all padding values to be >= 0, got ") << padding;

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return emitOpError("expect all stride values to be >= 1, got ") << strides;

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [this](const int64_t inputSize, const int64_t kernelSize,
               const int64_t outputSize, const int64_t padBefore,
               const int64_t padAfter, const int64_t stride,
               const int64_t dilation, const llvm::StringRef dimName,
               const llvm::StringRef dimAxis,
               const llvm::StringRef padBeforeName,
               const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      if (!calculatedOutSizeMinusOne.has_value())
        return emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

    if (failed(verifyOutputSize(
            inputType.getDimSize(1), weightType.getDimSize(1),
            outputType.getDimSize(1), padding[0], padding[1], strides[0],
            dilations[0], "height", "y", "top", "bottom")))

    if (failed(verifyOutputSize(
            inputType.getDimSize(2), weightType.getDimSize(2),
            outputType.getDimSize(2), padding[2], padding[3], strides[1],
            dilations[1], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels = outputType.getDimSize(3);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);

  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
                                 inferredReturnShapes);

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
                                 inferredReturnShapes);
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes the batch, input height/width and channels.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter height/width and the channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // Output channels are the input channels times the channel multiplier.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
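// Illustrative channel arithmetic for the depthwise case (made-up values):
// a weight of shape {kH, kW, 8, 2} describes inputChannels = 8 with a channel
// multiplier of depthChannels = 2, so outputShape[3] = 8 * 2 = 16. The output
// height/width use the same padded/dilated/strided formula as the forward
// convolutions above.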
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes the batch and input height/width.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter height/width and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
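// Illustrative output-size arithmetic for the transposed case (hypothetical
// values): inputHeight = 4, stride = 2, out_pad = {0, 0}, weightHeight = 3
// gives outputShape[1] = (4 - 1) * 2 + 0 + 0 + 3 = 9, i.e. the inverse of the
// forward convolution reduction rather than the padded/strided formula above.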
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // Multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // Multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // With per_channel set, multiplier and shift carry one element per channel
  // of the input's innermost dimension; otherwise a single element.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the yield operands.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}

std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yields explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenGraph(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
      return emitOpError("expect input tensor rank (")
             << inputRank << ") to be larger than reverse axis (" << reverseAxis
             << ")";
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank())
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
      return emitOpError("expect output tensor rank (")
             << outputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }
  return success();
}
  // Verify the predicate operand: input1 must be a shaped tensor of booleans.
  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }
  return success();
}
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                       Location loc,
                                                       Type srcElemType,
                                                       int64_t zp) {
  // Zero points are carried as rank-1, single-element tensors.
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult
mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
                              int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}

LogicalResult mlir::OpTrait::tosa::verifyTosaResolvableShapeOperands(
    Operation *op) {
  // TOSA shape operands must be produced by a compile-time resolvable op.
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp)
        return op->emitOpError("shape operand is not compile time resolvable");
    }
  }
  return success();
}

LogicalResult mlir::OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult
mlir::OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  // Delegate function that returns the rank of a shape type.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };
  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();
  auto rank = getRank(*operandTypes.begin());

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}

LogicalResult tosa::ConstShapeOp::verify() {
  // The values attribute must be one-dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // The element count must match the rank of the result shape type.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"