#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
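// Inliner interface fragment: a region is only legal to inline when the
// destination region's parent op is a tosa::IfOp or a tosa::WhileOp.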
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
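// Bytecode interface for the TOSA dialect: attribute and type serialization is
// delegated to the generated readers/writers in TosaDialectBytecode.cpp.inc,
// and versioned bytecode is not supported yet.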
  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
    return ::writeAttribute(attr, writer);

  LogicalResult writeType(Type type,
    return ::writeType(type, writer);

  std::unique_ptr<DialectVersion>
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

  Type elementType = variableOp.getType();
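// Dialect initialization: registers the generated TOSA types, ops, and
// attributes, hooks up the bytecode and inliner interfaces, and declares the
// promised sharding-interface implementations for the listed ops.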
void TosaDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));

  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
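// Custom assembly helpers for the variable op: its type may be given either as
// a shaped type after the colon or inferred from a typed initial-value
// attribute, and both paths funnel through getShapeAndElementType().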
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
         << "expected shaped type";

       << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
       << "expected Typed attr";

  initialValueAttr = nullptr;
       << "expected type after colon";
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
                                   TypeAttr typeAttr,
                                   Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
  if (initialValueAttr) {
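// idivCheck: judging from the verifiers below that call it, this returns
// lhs / rhs only when rhs evenly divides lhs and std::nullopt otherwise,
// e.g. idivCheck(12, 4) -> 3 while idivCheck(13, 4) -> std::nullopt.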
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();

                                Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

                      Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
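// Shared verifier for the convolution ops (CONV2D/CONV3D/DEPTHWISE_CONV2D):
// checks that input, weight, bias and result element types are mutually
// consistent (including fp8 and quantized storage types) and that the declared
// input and weight zero points are valid.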
template <typename T>
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
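// Convolution output-extent check. For each spatial dimension the verifier
// below recomputes
//   out = (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) / stride + 1
// and requires the division to be exact and the result to match any static
// output dimension.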
template <typename T>
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                                StringRef name1, Type type2,
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
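// Variable bookkeeping: ops that read or write a variable look up the
// declaring tosa.variable by symbol name, and referencing a name that was
// never declared is a verification error.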
  tosa::VariableOp varOp = nullptr;

  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
    if (symName == tosaOp.getName()) {

template <typename T>
  StringRef symName = op.getName();
    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";
template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
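  // Pooling output-extent check. For each spatial dimension the lambda below
  // recomputes
  //   out = (in + pad_before + pad_after - kernel) / stride + 1
  // requiring exact division, and compares against any static output size.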
  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
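// ClampOp verification: input and output must share an element type (after
// peeling quantized storage types), the min_val/max_val attributes must match
// that element type, NaN bounds are rejected for floats, and min_val must not
// exceed max_val (using unsigned comparison for unsigned or boolean types).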
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});
  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});
  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});
  Type finalOutputType{outputType};
    auto inputBits = eType.getIntOrFloatBitWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
    finalOutputType = outputShapedType.clone(accElementType);
                             DenseArrayAttr kernel, DenseArrayAttr stride,
                             DenseArrayAttr pad, TypeAttr accType) {
  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);
  int64_t input1Zp{0};
  int64_t outputZp{0};
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);
    zp = static_cast<int32_t>(quantAttr.getInputZp());
  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);
                            StringRef name, Type variableType,
  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");

  auto elementType = shapedType.getElementType();
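// Result-shape resolution for the elementwise ops uses numpy-style
// broadcasting: operand shapes are right-aligned to the widest rank and a
// dimension of size 1 stretches to the other operand's size, e.g. shapes
// [1, 3] and [2, 4, 1] resolve to a broadcast shape of [2, 4, 3].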
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {
    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;
      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {
      outShape[i + rankDiff] = resolvedDim;
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outShape.push_back(inputShape.getDimSize(i));
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

                                     const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
          outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  // Prefer the statically known dimension; fall back to the other operand.
  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? b : a;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
            "Cannot concat tensors with different sizes"
            " on the non-axis dimension ",

    hasRankedInput = true;

  if (adaptor.getInput1().empty())

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
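// ConcatOp verification: at least one input is required, every input must
// share the output's element type, all ranked inputs must agree on rank and on
// every non-axis dimension, and the static sizes along the concat axis must
// sum to the output's axis dimension.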
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

  int64_t axisSum = 0;
  for (const auto &input : inputList) {
    if (inputShape.isDynamicDim(axis)) {
    axisSum += inputShape.getDimSize(axis);

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

    if (aElementType != bElementType) {
      return emitOpError("expect same element type for inputs a and b, got ")
             << aElementType << " and " << bElementType;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
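// PadOp shape inference: each output dimension is the corresponding input
// dimension plus its leading and trailing pad amounts, so an input dim of 5
// with padding [1, 2] yields an output dim of 8; unknown or negative padding
// values propagate a dynamic dimension instead.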
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];
  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";
  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");

  if (l.size() != r.size() || l.size() != 1)
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(
  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  return mlir::success();
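// Zero-point helpers. getZeroPoint reads the first element of a zero-point
// tensor (float zero points must be 0.0; integer ones are sign- or
// zero-extended on request), verifyZeroPoint enforces the per-type rules
// (non-zero zero points only for int8, with 0 or 32768 allowed for unsigned
// int16), and ZERO_POINT_HELPER stamps out the get/verify pair per operand.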
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();

template <typename T>
                                const std::string &operand) {
  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;
  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

#undef ZERO_POINT_HELPER
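// TransposeOp shape inference simply permutes the input dimensions by perms:
// output[i] = input[perms[i]]. For example, perms = [0, 2, 1] on a ?x3x4
// input yields a ?x4x3 result.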
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                             static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;
    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /

  if (outputHeight < 0 || outputWidth < 0) {
        "calculated output height and width must be non-negative, "
        outputHeight, ", width = ", outputWidth);

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
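// ResizeOp::verify recomputes the spatial extents with the same arithmetic as
// the shape inference above:
//   output_height = ((input_height - 1) * scale_y_n - offset_y + border_y)
//                   / scale_y_d + 1
//   output_width  = ((input_width  - 1) * scale_x_n - offset_x + border_x)
//                   / scale_x_d + 1
// and additionally requires both divisions to be exact.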
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);

  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
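// ScatterOp verification names the dimensions N (batch), K (values_in rows),
// W (indices per batch) and C (channels): values_in, indices, input and
// values_out must agree on them, and K must be at least W.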
2691 int64_t N = ShapedType::kDynamic;
2692 int64_t K = ShapedType::kDynamic;
2693 int64_t
W = ShapedType::kDynamic;
2694 int64_t
C = ShapedType::kDynamic;
2695 if (valuesInShape.hasRank()) {
2696 N = valuesInShape.getDimSize(0);
2697 K = valuesInShape.getDimSize(1);
2698 C = valuesInShape.getDimSize(2);
2700 if (indicesShape.hasRank()) {
2701 const int64_t indicesN = indicesShape.getDimSize(0);
2702 W = indicesShape.getDimSize(1);
2703 if (N == ShapedType::kDynamic)
2705 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2706 return emitOpError() <<
"requires indices dimension 0 to have size " << N
2707 <<
", got " << indicesN;
2709 if (inputShape.hasRank()) {
2710 const int64_t inputN = inputShape.getDimSize(0);
2711 const int64_t inputW = inputShape.getDimSize(1);
2712 const int64_t inputC = inputShape.getDimSize(2);
2713 if (N == ShapedType::kDynamic)
       N = inputN;
2715 else if (inputN != ShapedType::kDynamic && N != inputN)
2716   return emitOpError() << "requires input dimension 0 to have size " << N
2717                        << ", got " << inputN;
2718 if (W == ShapedType::kDynamic)
       W = inputW;
2720 else if (inputW != ShapedType::kDynamic && W != inputW)
2721   return emitOpError() << "requires input dimension 1 to have size " << W
2722                        << ", got " << inputW;
2724 if (C == ShapedType::kDynamic)
       C = inputC;
2726 else if (inputC != ShapedType::kDynamic && C != inputC)
2727   return emitOpError() << "requires input dimension 2 to have size " << C
2728                        << ", got " << inputC;
2730 if (outputShape.hasRank()) {
2731 const int64_t outputN = outputShape.getDimSize(0);
2732 const int64_t outputK = outputShape.getDimSize(1);
2733 const int64_t outputC = outputShape.getDimSize(2);
2734 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
         N != outputN)
2736   return emitOpError() << "requires values_out dimension 0 to have size "
2737                        << N << ", got " << outputN;
2738 if (K == ShapedType::kDynamic)
       K = outputK;
2740 else if (outputK != ShapedType::kDynamic && K != outputK)
2741   return emitOpError() << "requires values_out dimension 1 to have size "
2742                        << K << ", got " << outputK;
2743 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
         C != outputC)
2745   return emitOpError() << "requires values_out dimension 2 to have size "
2746                        << C << ", got " << outputC;
2748 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2749   return emitOpError() << "requires dimensions K >= W, got K=" << K
                            << " and W=" << W;
2758 int64_t axisVal = axis.getValue().getSExtValue();
2759 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2765 operandShape.getDims(outputShape);
2766 outputShape[axisVal] = 1;
2771 #define COMPATIBLE_RETURN_TYPES(OP) \
2772 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2773 if (l.size() != r.size() || l.size() != 1) \
2775 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2777 return succeeded(verifyCompatibleShape(l[0], r[0])); \
2780 #define REDUCE_SHAPE_INFER(OP) \
2781 LogicalResult OP::inferReturnTypeComponents( \
2782 MLIRContext *context, ::std::optional<Location> location, \
2783 OP::Adaptor adaptor, \
2784 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2786 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2787 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2788 const Properties &prop = adaptor.getProperties(); \
2789 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2790 inferredReturnShapes); \
2792 COMPATIBLE_RETURN_TYPES(OP)
2800 #undef REDUCE_SHAPE_INFER
2802 #undef COMPATIBLE_RETURN_TYPES
2804 template <typename T>
2807 TensorType inputType = op.getInput().getType();
2808 TensorType outputType = op.getOutput().getType();
2809 int32_t reduceAxis = op.getAxis();
2811 if (reduceAxis < 0) {
2812 op.emitOpError("reduce axis must not be negative");
2816 int64_t inputRank = inputType.getRank();
2819 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2820 op.emitOpError("expect input tensor rank (")
2821     << inputRank << ") to be larger than reduce axis (" << reduceAxis
         << ")";
2827 int64_t outputRank = outputType.getRank();
2828 if (inputType.hasRank() && outputRank != inputType.getRank()) {
       op.emitOpError(
2830       "expect output tensor rank to be equal to input tensor rank");
2833 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2834 op.emitOpError("expect output tensor rank (")
2835     << outputRank << ") to be larger than reduce axis (" << reduceAxis
         << ")";
2841 if (outputRank != 0) {
2842 auto outputShape = outputType.getShape();
2843 if (!outputType.isDynamicDim(reduceAxis) &&
2844     outputShape[reduceAxis] != 1) {
2845   op.emitOpError("expect reduced dimension size to be 1, got ")
2846       << outputShape[reduceAxis];
2873 #define NARY_SHAPE_INFER(OP) \
2874 LogicalResult OP::inferReturnTypeComponents( \
2875 MLIRContext *context, ::std::optional<Location> location, \
2876 ValueShapeRange operands, DictionaryAttr attributes, \
2877 OpaqueProperties properties, RegionRange regions, \
2878 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2879 return NAryInferReturnTypes(operands, inferredReturnShapes); \
2919 #undef PRED_SHAPE_INFER
2921 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2922 MLIRContext *context, ::std::optional<Location> location,
2923 NegateOp::Adaptor adaptor,
2925 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2932 const Type input1Type = getInput1().getType();
2933 const Type outputType = getOutput().getType();
2940 return emitOpError() << "requires the same shape for input1 and output";
2943 const Type input1ZpEType =
         getStorageElementTypeOrSelf(getInput1Zp().getType());
2945 if (input1EType != input1ZpEType) {
2946   return emitOpError("expect both input1 and its zero point are the same "
2947                      "element type, got ")
2948          << input1EType << " and " << input1ZpEType;
2951 const Type outputZpEType =
         getStorageElementTypeOrSelf(getOutputZp().getType());
2953 if (outputEType != outputZpEType) {
2954   return emitOpError("expect both output and its zero point are the same "
2955                      "element type, got ")
2956          << outputEType << " and " << outputZpEType;
2959 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2960 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2963 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2964 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2975 outputShape.resize(4, ShapedType::kDynamic);
2990 if (ShapedType::isStatic(height)) {
2991 int64_t padded = height + pad[0] + pad[1] - kernel[0];
2992 outputShape[1] = padded / stride[0] + 1;
2995 if (ShapedType::isStatic(width)) {
2996 int64_t padded = width + pad[2] + pad[3] - kernel[1];
2997 outputShape[2] = padded / stride[1] + 1;
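// Illustrative note (not part of the original source): for a pooled dimension
// of size 8 with pad = [1, 1], kernel = 3 and stride = 2, the formula above
// gives padded = 8 + 1 + 1 - 3 = 7 and 7 / 2 + 1 = 4 output elements
// (integer division).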
3004 LogicalResult Conv2DOp::inferReturnTypeComponents(
3005 MLIRContext *context, ::std::optional<Location> location,
3006 Conv2DOp::Adaptor adaptor,
3010 int64_t inputWidth = ShapedType::kDynamic;
3011 int64_t inputHeight = ShapedType::kDynamic;
3012 int64_t weightWidth = ShapedType::kDynamic;
3013 int64_t weightHeight = ShapedType::kDynamic;
3018 if (inputShape.hasRank()) {
3019 outputShape[0] = inputShape.getDimSize(0);
3020 inputHeight = inputShape.getDimSize(1);
3021 inputWidth = inputShape.getDimSize(2);
3025 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3026 if (weightShape.hasRank()) {
3027 outputShape[3] = weightShape.getDimSize(0);
3028 weightHeight = weightShape.getDimSize(1);
3029 weightWidth = weightShape.getDimSize(2);
3034 if (biasShape.hasRank()) {
3035 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3036                      ? biasShape.getDimSize(0)
                          : outputShape[3];
3044 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3045 int64_t inputSize = inputHeight + padding[0] + padding[1];
3046 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3047 int64_t unstridedResult = inputSize - filterSize + 1;
3048 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3051 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3052 int64_t inputSize = inputWidth + padding[2] + padding[3];
3053 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3054 int64_t unstridedResult = inputSize - filterSize + 1;
3055 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
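// Illustrative note (not part of the original source): for inputHeight = 32,
// padding = [1, 1], weightHeight = 3, dilation = 1 and stride = 2, the code
// above computes inputSize = 34, filterSize = 3, unstridedResult = 32 and
// outputShape[1] = (32 - 1) / 2 + 1 = 16.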
3069 LogicalResult Conv3DOp::inferReturnTypeComponents(
3070 MLIRContext *context, ::std::optional<Location> location,
3071 Conv3DOp::Adaptor adaptor,
3075 int64_t inputWidth = ShapedType::kDynamic;
3076 int64_t inputHeight = ShapedType::kDynamic;
3077 int64_t inputDepth = ShapedType::kDynamic;
3079 int64_t weightWidth = ShapedType::kDynamic;
3080 int64_t weightHeight = ShapedType::kDynamic;
3081 int64_t weightDepth = ShapedType::kDynamic;
3085 if (inputShape.hasRank()) {
3086 outputShape[0] = inputShape.getDimSize(0);
3087 inputDepth = inputShape.getDimSize(1);
3088 inputHeight = inputShape.getDimSize(2);
3089 inputWidth = inputShape.getDimSize(3);
3093 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3094 if (weightShape.hasRank()) {
3095 outputShape[4] = weightShape.getDimSize(0);
3096 weightDepth = weightShape.getDimSize(1);
3097 weightHeight = weightShape.getDimSize(2);
3098 weightWidth = weightShape.getDimSize(3);
3103 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3104 outputShape[4] = biasShape.getDimSize(0);
3111 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3112 int32_t inputSize = inputDepth + pad[0] + pad[1];
3113 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3114 int32_t unstridedResult = inputSize - filterSize + 1;
3115 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3118 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3119 int32_t inputSize = inputHeight + pad[2] + pad[3];
3120 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3121 int32_t unstridedResult = inputSize - filterSize + 1;
3122 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3125 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3126 int32_t inputSize = inputWidth + pad[4] + pad[5];
3127 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3128 int32_t unstridedResult = inputSize - filterSize + 1;
3129 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3143 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3144 MLIRContext *context, ::std::optional<Location> location,
3145 AvgPool2dOp::Adaptor adaptor,
3148 const Properties &prop = adaptor.getProperties();
3150 inferredReturnShapes);
3153 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3154 MLIRContext *context, ::std::optional<Location> location,
3155 MaxPool2dOp::Adaptor adaptor,
3158 const Properties &prop = adaptor.getProperties();
3160 inferredReturnShapes);
3174 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3175 MLIRContext *context, ::std::optional<Location> location,
3176 DepthwiseConv2DOp::Adaptor adaptor,
3180 int64_t inputWidth = ShapedType::kDynamic;
3181 int64_t inputHeight = ShapedType::kDynamic;
3182 int64_t inputChannels = ShapedType::kDynamic;
3184 int64_t weightWidth = ShapedType::kDynamic;
3185 int64_t weightHeight = ShapedType::kDynamic;
3186 int64_t depthChannels = ShapedType::kDynamic;
3190 if (inputShape.hasRank()) {
3191 outputShape[0] = inputShape.getDimSize(0);
3192 inputHeight = inputShape.getDimSize(1);
3193 inputWidth = inputShape.getDimSize(2);
3194 inputChannels = inputShape.getDimSize(3);
3198 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3199 if (weightShape.hasRank()) {
3200 weightHeight = weightShape.getDimSize(0);
3201 weightWidth = weightShape.getDimSize(1);
3202 inputChannels = ShapedType::isDynamic(inputChannels)
3203                     ? weightShape.getDimSize(2)
                         : inputChannels;
3205 depthChannels = weightShape.getDimSize(3);
3210 if (ShapedType::isStatic(inputChannels) &&
3211 ShapedType::isStatic(depthChannels)) {
3212 outputShape[3] = inputChannels * depthChannels;
3217 if (biasShape.hasRank()) {
3218 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3219                      ? biasShape.getDimSize(0)
                          : outputShape[3];
3227 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3228 int64_t inputSize = inputHeight + padding[0] + padding[1];
3229 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3230 int64_t unstridedResult = inputSize - filterSize + 1;
3231 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3234 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3235 int64_t inputSize = inputWidth + padding[2] + padding[3];
3236 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3237 int64_t unstridedResult = inputSize - filterSize + 1;
3238 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
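// Illustrative note (not part of the original source): depthwise output
// channels are inputChannels * depthChannels, e.g. an 8-channel input with a
// channel multiplier of 2 yields outputShape[3] = 16; the spatial dimensions
// follow the same padded/dilated formula as Conv2D above.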
3252 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3253 MLIRContext *context, ::std::optional<Location> location,
3254 TransposeConv2DOp::Adaptor adaptor,
3258 int64_t inputWidth = ShapedType::kDynamic;
3259 int64_t inputHeight = ShapedType::kDynamic;
3260 int64_t weightWidth = ShapedType::kDynamic;
3261 int64_t weightHeight = ShapedType::kDynamic;
3265 if (inputShape.hasRank()) {
3266 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3267                      ? inputShape.getDimSize(0)
                          : outputShape[0];
3269 inputHeight = inputShape.getDimSize(1);
3270 inputWidth = inputShape.getDimSize(2);
3274 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3275 if (weightShape.hasRank()) {
3276 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3277                      ? weightShape.getDimSize(0)
                          : outputShape[3];
3279 weightHeight = weightShape.getDimSize(1);
3280 weightWidth = weightShape.getDimSize(2);
3285 if (biasShape.hasRank()) {
3286 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3287                      ? biasShape.getDimSize(0)
                          : outputShape[3];
3294 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3295 int64_t calculateSize =
3296 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3298 outputShape[1] =
         ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3301 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3302 int64_t calculateSize =
3303 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3305 outputShape[2] =
         ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
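// Illustrative note (not part of the original source): for inputHeight = 4,
// stride_y = 2, out_pad_top = out_pad_bottom = 0 and weightHeight = 3, the
// computed output height above is (4 - 1) * 2 + 0 + 0 + 3 = 9.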
3317 const int64_t strideY = strides[0];
3318 const int64_t strideX = strides[1];
3320 if (strideY < 1 || strideX < 1)
3321 return emitOpError("expect all stride values to be >= 1, got [")
        << strideY << ", " << strideX << "]";
3324 const auto checkPadAgainstKernelDim =
3325     [this](int64_t pad_value, int64_t kernel_dim_size,
3326            llvm::StringRef pad_name,
3327            llvm::StringRef kernel_dim_name) -> LogicalResult {
3328   if (pad_value <= -kernel_dim_size)
3329     return emitOpError("expected ")
3330            << pad_name << " > -" << kernel_dim_name
3331            << ", but got: " << pad_name << "=" << pad_value << " and "
3332            << kernel_dim_name << "=" << kernel_dim_size;
3337 const int64_t outPadTop = padding[0];
3338 const int64_t outPadBottom = padding[1];
3339 const int64_t outPadLeft = padding[2];
3340 const int64_t outPadRight = padding[3];
3342 const auto weightType =
3343     llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3346 const int64_t kernelHeight = weightType.getDimSize(1);
3347 if (ShapedType::isStatic(kernelHeight)) {
3348 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3349                                     "out_pad_top", "KH")))
       return failure();
3352 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3353                                     "out_pad_bottom", "KH")))
       return failure();
3357 const int64_t kernelWidth = weightType.getDimSize(2);
3358 if (ShapedType::isStatic(kernelWidth)) {
3359 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3360                                     "out_pad_left", "KW")))
       return failure();
3363 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3364                                     "out_pad_right", "KW")))
       return failure();
3370 const auto outputType =
3371     llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3375 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3376 if (inputType && weightType) {
3377 const int64_t inputHeight = inputType.getDimSize(1);
3378 const int64_t kernelHeight = weightType.getDimSize(1);
3379 const int64_t outputHeight = outputType.getDimSize(1);
3381 if (ShapedType::isStatic(inputHeight) &&
3382 ShapedType::isStatic(outputHeight)) {
     if (outputHeight !=
3384     (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
       return emitOpError(
3386              "dimension mismatch: expected OH == (IH - 1) * stride_y "
3387              "+ out_pad_top + out_pad_bottom + KH, but got ")
3388          << outputHeight << " != (" << inputHeight << " - 1) * "
3389          << strideY << " + " << outPadTop << " + " << outPadBottom
3390          << " + " << kernelHeight;
3393 const int64_t inputWidth = inputType.getDimSize(2);
3394 const int64_t kernelWidth = weightType.getDimSize(2);
3395 const int64_t outputWidth = outputType.getDimSize(2);
3397 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
     if (outputWidth !=
3399     (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
       return emitOpError(
3401              "dimension mismatch: expected OW == (IW - 1) * stride_x "
3402              "+ out_pad_left + out_pad_right + KW, but got ")
3403          << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3404          << " + " << outPadLeft << " + " << outPadRight << " + "
             << kernelWidth;
3409 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3414 const int64_t biasChannels = biasType.getDimSize(0);
3417 if (biasChannels == ShapedType::kDynamic)
3420 const int64_t outputChannels = outputType.getDimSize(3);
3421 if (!ShapedType::isDynamic(outputChannels) &&
3422 biasChannels != outputChannels && biasChannels != 1)
3424 "bias channels expected to be equal to output channels (")
3425 << outputChannels <<
") or 1, got " << biasChannels;
3431 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
     if (!inputType) {
3433   emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3437 auto inputElementType =
         getStorageElementTypeOrSelf(inputType.getElementType());
3439 if (!mlir::isa<IntegerType>(inputElementType)) {
3440   emitOpError("expect input to have integer element type, got ")
3441       << inputElementType;
3445 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
     if (!outputType) {
3447   emitOpError("expect shaped tensor for output, got ")
3448       << getOutput().getType();
3452 auto outputElementType =
         getStorageElementTypeOrSelf(outputType.getElementType());
3454 if (!mlir::isa<IntegerType>(outputElementType)) {
3455   emitOpError("expect output to have integer element type, got ")
3456       << outputElementType;
3468 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3469 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3472 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3473 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3476 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3477 if (!multiplierType) {
3478   emitOpError("expect shaped tensor for multiplier, got ")
3479       << getMultiplier().getType();
3483 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
     if (!shiftType) {
3485   emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3490 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3491   emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3492       << multiplierType.getElementType();
3497 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3499 "expect i16 element type for multiplier for scale32=false, got ")
3500 << multiplierType.getElementType();
3504 if (!inputType.hasRank())
3510 int64_t numChannels = 1;
3511 if (getPerChannel()) {
3512 if (inputType.getRank() < 1) {
3513     emitOpError("requires input to be at least rank 1 when per_channel is "
3514                 "true, but got rank ")
3515         << inputType.getRank();
3518 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3521 if (!multiplierType.hasRank())
3526 if (multiplierShape[0] != ShapedType::kDynamic &&
3527 multiplierShape[0] != numChannels) {
3528   emitOpError("expect shape of { ")
3529       << numChannels << " } for multiplier input, got { "
3530       << multiplierShape[0] << " }";
3534 if (!shiftType.hasRank())
3539 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3540   emitOpError("expect shape of { ")
3541       << numChannels << " } for shift input, got { " << shiftShape[0]
           << " }";
3548 LogicalResult RescaleOp::inferReturnTypeComponents(
3549 MLIRContext *context, ::std::optional<Location> location,
3550 RescaleOp::Adaptor adaptor,
3557 LogicalResult IfOp::inferReturnTypeComponents(
3558 MLIRContext *context, ::std::optional<Location> location,
3559 IfOp::Adaptor adaptor,
3562 for (Region *region : adaptor.getRegions()) {
3563   for (auto &block : *region)
3564     if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3565       yieldOps.push_back(returnOp);
3568 if (yieldOps.empty())
3573 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3574 for (auto operand : yieldOps.front().getOperands()) {
3575   resultKnowledge.push_back(
           ValueKnowledge::getKnowledgeFromType(operand.getType()));
3579 for (auto yieldOp : yieldOps) {
3580 if (resultKnowledge.size() != yieldOp.getNumOperands())
3584 int32_t index = it.index();
3586 resultKnowledge[index],
3590 resultKnowledge[index] = meet;
3595 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3601 LogicalResult WhileOp::inferReturnTypeComponents(
3602 MLIRContext *context, ::std::optional<Location> location,
3603 WhileOp::Adaptor adaptor,
3606 for (auto &block : adaptor.getBodyGraph())
3607   if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3608     yieldOps.push_back(returnOp);
3612 if (yieldOps.empty())
3617 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3618 for (auto operand : yieldOps.front().getOperands()) {
3619   resultKnowledge.push_back(
           ValueKnowledge::getKnowledgeFromType(operand.getType()));
3623 for (auto yieldOp : yieldOps) {
3624 if (resultKnowledge.size() != yieldOp.getNumOperands())
3628 int32_t index = it.index();
3630 resultKnowledge[index],
3632 resultKnowledge[index] = meet;
3638 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3644 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3645   if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3646 return llvm::to_vector<4>(vt.getShape());
3647 return std::nullopt;
3684 bool printBlockTerminators = false;
3686 p << " " << getCondition();
3687 if (!getResults().empty()) {
3688   p << " -> (" << getResultTypes() << ")";
3690   printBlockTerminators = true;
3695 printBlockTerminators);
3698 auto &elseRegion = getElseGraph();
3699 if (!elseRegion.empty()) {
3703 printBlockTerminators);
3711 "'then_graph' arguments", getInputList(),
3717 "'else_graph' arguments", getInputList(),
3722 auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3724 "'then_graph' results", getOutputList(),
3729 auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3731 "'else_graph' results", getOutputList(),
3736 auto condType = getCondition().getType();
3738   return emitOpError() << "'condition' must be a size 1 tensor, got "
                            << condType;
3746     getOutputList(), "'output_list'")
3751 "'cond_graph' arguments", getInputList(),
3757 "'body_graph' arguments", getInputList(),
3762 auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3764 "'body_graph' results", getInputList(),
3771 auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3772 if (condYield.getInputs().size() != 1)
3773   return emitOpError() << "require 'cond_graph' only have one result";
3775 auto condOutType = condYield.getInputs()[0].getType();
3777   return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                            << condOutType;
3781   return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                            << condOutType;
3792 TensorType inputType = getInput1().getType();
3793 TensorType outputType = getOutput().getType();
3794 int32_t reverseAxis = getAxis();
3796 if (reverseAxis < 0)
3797   return emitOpError("expected non-negative reverse axis");
3799 int64_t inputRank = inputType.getRank();
3802 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3803   return emitOpError("expect input tensor rank (")
3804          << inputRank << ") to be larger than reverse axis (" << reverseAxis
              << ")";
3808 int64_t outputRank = outputType.getRank();
3809 if (inputType.hasRank() && outputRank != inputType.getRank())
       return emitOpError(
3811       "expect output tensor rank to be equal to input tensor rank");
3812 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3813   return emitOpError("expect output tensor rank (")
3814          << outputRank << ") to be larger than reverse axis ("
3815          << reverseAxis << ")";
3831 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
3832 if (!predicateType) {
3833   return emitOpError("expect shaped tensor for input1, got ")
3834          << getInput1().getType();
3836 auto predicateElementType = predicateType.getElementType();
3837 if (!predicateElementType.isInteger(1)) {
3838   return emitOpError("expect element type of bool for input1, got ")
3839          << predicateElementType;
3846 StringRef symName = getName();
3848 if (succeeded(varOp))
3849   return emitOpError("illegal to have multiple declaration of '")
          << symName << "'";
3883 FunctionType functionType;
3888 result.addTypes(functionType.getResults());
3890 if (functionType.getNumInputs() != operands.size()) {
3892          << "expected as many input types as operands "
3893          << "(expected " << operands.size() << " got "
3894          << functionType.getNumInputs() << ")";
3904 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3905   regionArgs[i].type = functionType.getInput(i);
3907 return failure(parser.parseRegion(*cond, regionArgs) ||
3915                                     StringRef prefix = "") {
3916 assert(blocksArgs.size() == initializers.size() &&
3917 "expected same length of arguments and initializers");
3918 if (initializers.empty())
3921 parser << prefix << '(';
3922 llvm::interleaveComma(
3923     llvm::zip(blocksArgs, initializers), parser,
3924     [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3930 getInputList(),
" ");
3933 getResults().getTypes());
3948 if (llvm::isa<FloatType>(srcElemType)) {
     auto zpAttr = DenseElementsAttr::get(
3950     zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3951 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
3953 if (llvm::isa<IntegerType>(srcElemType)) {
3956 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
3958 llvm::errs() << "zero point is not allowed for unsupported data types\n";
3959 return std::nullopt;
3967 return mlir::isa<tosa::shapeType>(t);
3974   return emitError() << "invalid rank (must be >= 0): " << rank;
3980 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3981 Operation *definingOp = v.getDefiningOp();
3983     return op->emitOpError("shape operand is not compile time resolvable");
3992 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3993   return op->emitOpError("must have operands with tosa shape type");
3997 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3998   return op->emitOpError("must have result with tosa shape type");
4011 auto getRank = [](const Type type) {
4012 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4018 for (auto type : operandTypes) {
4019   if (getRank(type) != rank) {
4020     return op->emitOpError("operands don't have matching ranks");
4023 for (auto type : resultTypes) {
4024   if (getRank(type) != rank) {
4025     return op->emitOpError("result shape has different rank than operands");
4037 auto valuesRank = getValues().getType().getRank();
4038 if (valuesRank != 1)
4039   return emitOpError("expect elements in attribute values with rank 1");
4041 auto count = getValues().getNumElements();
4042 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4043 if (!(count == rank || (count == 1 && rank == 0))) {
4044   return emitOpError("expect number of elements in attribute values (")
4045          << count << ") to be equal to the rank (" << rank
4046          << ") for the result shape type";
4055 #define GET_ATTRDEF_CLASSES
4056 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4061 #define GET_TYPEDEF_CLASSES
4062 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4068 #define GET_OP_CLASSES
4069 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"