#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
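// The check below only admits destination regions whose parent op is
// tosa.cond_if or tosa.while_loop.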
  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));

  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
    return ::writeAttribute(attr, writer);

  LogicalResult writeType(Type type,
    return ::writeType(type, writer);

  std::unique_ptr<DialectVersion>
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,

  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

  Type elementType = variableOp.getType();
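// Dialect initialization: register the generated types, ops, and attributes,
// then attach the bytecode/inliner interfaces and the promised sharding
// interfaces for the elementwise ops listed below.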
void TosaDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
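// Constant materialization: shape-typed DenseIntElementsAttr values become
// tosa.const_shape ops, any other ElementsAttr becomes tosa.const.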
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));

ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
             << "expected ranked type";
    auto elementType = shapedType.getElementType();
         << "expected shaped type";

           << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
           << "expected Typed attr";
  initialValueAttr = nullptr;
           << "expected type after colon";
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);

                                       TypeAttr typeAttr,
                                       Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
  if (initialValueAttr) {

std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();

                                   Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

                               Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
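// Shared convolution verifier: checks that input/weight/bias/result element
// types are consistent (unwrapping quantized storage types first) and that the
// declared input/weight zero points are valid.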
template <typename T>
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())

  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
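// Convolution shape checks: padding must be non-negative, strides/dilations
// strictly positive, and each spatial output dimension must equal
//   (input - 1 + pad_before + pad_after - (kernel - 1) * dilation) / stride + 1
// with an exact division.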
template <typename T>
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                            StringRef name1, Type type2,
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();

  tosa::VariableOp varOp = nullptr;

    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
      if (symName == tosaOp.getName()) {

template <typename T>
  StringRef symName = op.getName();
    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";

template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;

  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())

      llvm::cast<ShapedType>(getInput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});
  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});
  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});
  Type finalOutputType{outputType};
    auto inputBits = eType.getIntOrFloatBitWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
    finalOutputType = outputShapedType.clone(accElementType);

                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

  int64_t input1Zp{0};
  int64_t outputZp{0};
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());
  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);

                       StringRef name, Type variableType,
  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");

  auto elementType = shapedType.getElementType();
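// Resolve a broadcast output shape across all operands: pad lower-rank shapes
// with leading 1s, then combine dimensions pairwise, letting size-1 dims
// broadcast against the other operand's size.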
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {
    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;
      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {
      outShape[i + rankDiff] = resolvedDim;

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outShape.push_back(inputShape.getDimSize(i));

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;
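// RFFT2D keeps the batch and height dimensions and produces width
// input_width / 2 + 1; the verifier below additionally requires static
// height/width to be powers of two.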
                                       const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
                            outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (!ShapedType::isDynamic(width) &&

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))

      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
                          "Cannot concat tensors with different sizes"
                          " on the non-axis dimension ",

    hasRankedInput = true;

  if (adaptor.getInput1().empty())

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

    int64_t axisSum = 0;
    for (const auto &input : inputList) {
      if (inputShape.isDynamicDim(axis)) {
      axisSum += inputShape.getDimSize(axis);

    if (axisSum >= 0 && outputShape.hasRank() &&
        !outputShape.isDynamicDim(axis) &&
        axisSum != outputShape.getDimSize(axis))
      return emitOpError("requires sum of axis dimensions of input1 "
                         "equal to output axis dimension, got ")
             << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);

  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);

  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");

LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
            "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";

  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;

    multiples = llvm::to_vector(
        llvm::map_range(multiplesAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));
LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  return mlir::success();

  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();
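// Zero-point verification: non-int8 integer operands must use a zero point of
// zero; the unsigned-int16 case below additionally accepts 32768.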
template <typename T>
                                 const std::string &operand) {
  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                 const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

        !(zpElemType.isInteger(16) && tensorUnsigned)) {
      return op.emitOpError()
             << "expect " << tensorName << "_zp of 0, got " << zp;

    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
      return op.emitOpError() << "expect " << tensorName
                              << "_zp of 0 or 32768 for unsigned int16 "
                              << tensorName << ", got " << zp;

#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

#undef ZERO_POINT_HELPER
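// Transpose shape inference: output dim i takes the input dim selected by
// perms[i]; the verifier then checks perms is a valid permutation whose size
// matches the input/output ranks.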
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                          static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);

  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);

  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;

  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;

    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
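// Resize output size per spatial dim:
//   ((input - 1) * scale_n - offset + border) / scale_d + 1,
// and the verifier below requires the division to be exact.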
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2635 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2636 MLIRContext *context, ::std::optional<Location> location,
2637 ScatterOp::Adaptor adaptor,
2640 outputShape.resize(3, ShapedType::kDynamic);
2642 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2643 if (valuesInShape.hasRank()) {
2644 outputShape[0] = valuesInShape.getDimSize(0);
2645 outputShape[1] = valuesInShape.getDimSize(1);
2646 outputShape[2] = valuesInShape.getDimSize(2);
2649 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2650 if (indicesShape.hasRank()) {
2651 if (outputShape[0] == ShapedType::kDynamic)
2652 outputShape[0] = indicesShape.getDimSize(0);
2656 if (inputShape.hasRank()) {
2657 if (outputShape[0] == ShapedType::kDynamic)
2658 outputShape[0] = inputShape.getDimSize(0);
2659 if (outputShape[2] == ShapedType::kDynamic)
2660 outputShape[2] = inputShape.getDimSize(2);
LogicalResult tosa::ScatterOp::verify() {
  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  int64_t N = ShapedType::kDynamic;
  int64_t K = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesInShape.hasRank()) {
    N = valuesInShape.getDimSize(0);
    K = valuesInShape.getDimSize(1);
    C = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (N == ShapedType::kDynamic)
      N = inputN;
    else if (inputN != ShapedType::kDynamic && N != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << N
                           << ", got " << inputN;
    if (W == ShapedType::kDynamic)
      W = inputW;
    else if (inputW != ShapedType::kDynamic && W != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << W
                           << ", got " << inputW;
    if (C == ShapedType::kDynamic)
      C = inputC;
    else if (inputC != ShapedType::kDynamic && C != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << C
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << N << ", got " << outputN;
    if (K == ShapedType::kDynamic)
      K = outputK;
    else if (outputK != ShapedType::kDynamic && K != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << K << ", got " << outputK;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << C << ", got " << outputC;
  }
  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
    return emitOpError() << "requires dimensions K >= W, got K=" << K
                         << " and W=" << W;

  return success();
}
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER

#undef COMPATIBLE_RETURN_TYPES
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // A rank-0 input is only valid together with axis == 0.
    if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }

  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // The reduced dimension must have size 1, unless the output is rank 0.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }
  return success();
}
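// Illustration only (not part of TosaOps.cpp): the axis rule from
// verifyReduceOp as a standalone predicate. The axis must be non-negative and
// smaller than the rank, except that rank-0 tensors accept axis == 0.
#include <cstdint>

static bool reduceAxisIsValid(int32_t axis, int64_t rank) {
  if (axis < 0)
    return false;
  return axis < rank || (axis == 0 && rank == 0);
}
// Example: axis=2 on a rank-2 input is rejected ("expect input tensor rank (2)
// to be larger than reduce axis (2)"); axis=1 on the same input is accepted.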
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

  // Input and output must agree on shape.
  if (failed(verifyCompatibleShape(input1Type, outputType)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
  outputShape.resize(4, ShapedType::kDynamic);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
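// Illustration only (not part of TosaOps.cpp): the spatial-size arithmetic used
// above, collected into one standalone helper. `convOutputDim` is a
// hypothetical name introduced here for clarity:
//   out = ((in + pad_before + pad_after - ((kernel - 1) * dilation + 1)) / stride) + 1
#include <cstdint>

static int64_t convOutputDim(int64_t in, int64_t padBefore, int64_t padAfter,
                             int64_t kernel, int64_t dilation, int64_t stride) {
  const int64_t filterSize = (kernel - 1) * dilation + 1;
  const int64_t unstrided = in + padBefore + padAfter - filterSize + 1;
  return (unstrided - 1) / stride + 1;
}
// Example: in=32, pad=1/1, kernel=3, dilation=1, stride=1 -> 32 (same size);
// with stride=2 the result is 16.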
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input depth/width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/width/height and output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // Output channels are the product of input channels and the multiplier.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
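// Illustration only (not part of TosaOps.cpp): the inverse relation used above
// for transpose_conv2d, written as a standalone helper.
//   out = (in - 1) * stride + out_pad_before + out_pad_after + kernel
#include <cstdint>

static int64_t transposeConvOutputDim(int64_t in, int64_t stride,
                                      int64_t outPadBefore, int64_t outPadAfter,
                                      int64_t kernel) {
  return (in - 1) * stride + outPadBefore + outPadAfter + kernel;
}
// Example: in=16, stride=2, out_pad=0/0, kernel=3 -> 33, which is exactly the
// OH/OW equality the verifier below insists on.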
LogicalResult TransposeConv2DOp::verify() {
  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (!ShapedType::isDynamic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (!ShapedType::isDynamic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // The remaining checks depend on the output being a ranked tensor type.
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (!ShapedType::isDynamic(inputHeight) &&
        !ShapedType::isDynamic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (!ShapedType::isDynamic(inputWidth) &&
        !ShapedType::isDynamic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  // Skip further checks if the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // The multiplier element type must be i32 when scale32 is true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // The multiplier element type must be i16 when scale32 is false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // The multiplier and shift operands must have shape {numChannels}, where
  // numChannels is 1 when per_channel is false and the size of the input's
  // last dimension otherwise.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the first yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // The body graph must be terminated by a tosa.yield.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the first yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yields explicitly if the op defines values.
    printBlockTerminators = true;
  }

  p << ' ';
  p.printRegion(getThenGraph(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }
}
LogicalResult IfOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'").failed())
    return failure();

  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
                                 "'then_graph' results", getOutputList(),
                                 "'output_list'").failed())
    return failure();

  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
                                 "'else_graph' results", getOutputList(),
                                 "'output_list'").failed())
    return failure();

  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;
  return success();
}

LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'").failed())
    return failure();

  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                 "'body_graph' results", getInputList(),
                                 "'input_list'").failed())
    return failure();

  // The condition graph must yield a single size-1 boolean tensor.
  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;
  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;
  return success();
}
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // A rank-0 input is only valid together with axis == 0.
    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
      return emitOpError("expect input tensor rank (")
             << inputRank << ") to be larger than reverse axis (" << reverseAxis
             << ")";
  }

  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank())
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
      return emitOpError("expect output tensor rank (")
             << outputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }
  return success();
}
  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }

  StringRef symName = getName();
  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
  if (succeeded(varOp))
    return emitOpError("illegal to have multiple declaration of '")
           << symName << "'";
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                       Location loc,
                                                       Type srcElemType,
                                                       int64_t zp) {
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
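// Usage sketch (not code from this file): building a zero-point constant for an
// i8 operand with the helper above. It assumes the caller already has an
// OpBuilder positioned at the desired insertion point; `makeI8ZeroPoint` is a
// hypothetical wrapper name introduced only for illustration.
static mlir::Value makeI8ZeroPoint(mlir::OpBuilder &builder,
                                   mlir::Location loc) {
  std::optional<mlir::Value> zp = mlir::tosa::createZeroPointTensor(
      builder, loc, builder.getIntegerType(8), /*zp=*/0);
  // A null Value signals that the element type was not supported.
  return zp ? *zp : mlir::Value();
}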
bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult mlir::tosa::shapeType::verify(
    function_ref<InFlightDiagnostic()> emitError, int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}

  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
    Operation *definingOp = v.getDefiningOp();
      return op->emitOpError("shape operand is not compile time resolvable");

  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
    return op->emitOpError("must have operands with tosa shape type");
  }
  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
    return op->emitOpError("must have result with tosa shape type");
  }
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
LogicalResult tosa::ConstShapeOp::verify() {
  // The values attribute must be one-dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // The number of elements in 'values' must equal the rank of the result
  // shape type (a single element is accepted for the rank-0 case).
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static LogicalResult verifyConvOpErrorIf(T op)
static LogicalResult verifyConvOpModes(T op)
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Type getStorageElementTypeOrSelf(Type type)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperator(Operation *op)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)