29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
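// Inlining into TOSA control-flow regions is only legal when the destination
// block's parent operation is a tosa::IfOp or tosa::WhileOp, as checked below.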
  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }

  return {&getBodyGraph()};
void TosaDialect::initialize() {

#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"

#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
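// Constant materialization: shape-typed DenseIntElementsAttr values become
// tosa::ConstShapeOp, while other ElementsAttr values become tosa::ConstOp.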
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  }

  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
             << "expected attribute";

    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {

  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
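// Exact integer division helper used by the convolution, pooling and resize
// verifiers below; an empty result signals that the divisor does not evenly
// divide the dividend.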
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();
                                             Value valZp, StringRef name) {

  bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
    return op->emitOpError()
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
                                  Value src, int32_t val) {

  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
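// Shared verification for convolution-style ops: input, weight, bias and
// result element types must be consistent once quantized storage types are
// stripped, fp8 inputs require matching weight element types, and each zero
// point must agree with the element type of its operand.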
template <typename T>

  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    op.emitOpError(
        "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
          "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }
template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
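// Spatial output sizes must satisfy
//   out = (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) / stride + 1
// with an exact division; the lambda below emits a diagnostic otherwise.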
template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)
        return success();

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
          stride);
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
               << stride;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                         StringRef name1, Type type2,

  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

    return op->emitOpError()
           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
  tosa::VariableOp varOp = nullptr;

    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
      if (symName == tosaOp.getName()) {

template <typename T>

  StringRef symName = op.getName();

    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";

  Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
template <typename T>

  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);

    op.emitOpError("expect shaped tensor for input, got ") << inType;

    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {

    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);

    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>

  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();
  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

    return success();
  };

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }

  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  Type finalOutputType{outputType};

  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;

  finalOutputType = outputShapedType.clone(accElementType);

                                 DenseArrayAttr kernel, DenseArrayAttr stride,
                                 DenseArrayAttr pad, TypeAttr accType) {

  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

  int64_t input1Zp{0};
  int64_t outputZp{0};

    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);
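// Result broadcasting: the output rank is the maximum operand rank, and each
// output dimension is resolved across operands, treating size-1 dimensions as
// broadcastable.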
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {

    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {

      } else if (dim1 != dim2) {

      outShape[i + rankDiff] = resolvedDim;
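// ArgMax shape inference: the output shape is the input shape with the
// reduced axis removed.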
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {

    outShape.push_back(inputShape.getDimSize(i));
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())
    return failure();

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

                                          const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;

           << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();

    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                                  outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;
  };

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                             inputImagType.getDimSize(1));
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                            inputImagType.getDimSize(2));
  if (!ShapedType::isDynamic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())
      continue;

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {

    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();

    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;

      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

  int64_t axisSum = 0;
  for (const auto &input : inputList) {

    if (inputShape.isDynamicDim(axis)) {

    axisSum += inputShape.getDimSize(axis);
  }

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)
    return false;

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    }

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;
  }

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;
  }

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;
  }

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto paddingRank =
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)
    return success();

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {

        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {

          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {

          outputShape[i] = size[i];
  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      return emitOpError(
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

  for (int i = 0; i < 2; ++i) {
        "requires the same element type for all operands and results");

  ElementsAttr shift_elem;
    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
      return emitOpError() << "require shift to be 0 for float type";
  auto hasRank = [](const Type type) {
    if (auto shaped_type = dyn_cast<ShapedType>(type))
      return shaped_type.hasRank();

  auto rankedOperandTypes =
      llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));

  auto rankedResultTypes =
      llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);

  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
    return success();

  auto getRank = [](const Type type) {
    return cast<ShapedType>(type).getRank();
  };

  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
                                          : getRank(*rankedResultTypes.begin());

  for (size_t i = 0; i < 2; ++i) {
    if (rank != getRank(rankedOperandTypes[i])) {
      return emitOpError("operands don't have matching ranks");
    }
  }

  for (const auto type : rankedResultTypes) {
    if (rank != getRank(type)) {
      return emitOpError("result type has different rank than operands");
    }
  }

    return mlir::cast<ShapedType>(type).getShape();

    return emitOpError("operands don't have broadcast-compatible shapes");
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();

    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto rank =
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(

  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
    return failure();

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
    } else {
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
    }
  }

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());
  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
    return emitOpError(
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)
    return false;

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();
  RankedTensorType outputType = getType();

    return mlir::success();

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;
      }
    }

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
    }
  }

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  return mlir::success();
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
    return zpAttr.getValues<APInt>()[0].getSExtValue();
  }

template <typename T>
                          const std::string &operand) {

  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                          const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;

  if (zpElemType.isInteger(16) && tensorUnsigned &&
      zp != static_cast<int16_t>(32768)) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp());                              \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      static_cast<size_t>(s) < constantPerms.size();

          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)

          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
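// Gather shapes are [N, W, C]: N and C come from the values operand, W from
// the indices operand, and the verifier below cross-checks them against the
// output.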
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }

  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;

    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        W != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /

      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

    return emitOpError("expect a ranked input tensor");

    return emitOpError("expect a ranked output tensor");

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues;

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {

  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;

#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER

#undef COMPATIBLE_RETURN_TYPES
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output has rank 0 and axis is 0.
  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
    op.emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
    return failure();
  }
  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
      return failure();
    }
  }
  return success();
}
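// Illustrative example (types assumed): reducing a tensor<2x3x4xf32> along
// axis = 1 must yield tensor<2x1x4xf32>; declaring the result as
// tensor<2x3x4xf32> would trip the "expect reduced dimension size to be 1"
// check above, and axis = 3 would trip the input-rank check.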
#define NARY_SHAPE_INFER(OP)                                                  \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      OpaqueProperties properties, RegionRange regions,                       \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    return NAryInferReturnTypes(operands, inferredReturnShapes);              \
  }

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

  if (failed(verifyCompatibleShape(input1Type, outputType)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(input1Type);
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }

  const Type outputEType = getStorageElementTypeOrSelf(outputType);
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
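// Illustrative example (op syntax assumed, not taken from the original
// source): a quantized negate whose zero points match the element types of
// input1/output passes the checks above, e.g.
//   %0 = tosa.negate %in, %in_zp, %out_zp
//        : (tensor<2x3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x3xi8>
// whereas an i16 input1_zp paired with an i8 input1 would be rejected.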
  outputShape.resize(4, ShapedType::kDynamic);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }
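  // Illustrative worked example (values assumed): a 32x32 spatial input with
  // kernel = [2, 2], stride = [2, 2] and zero padding gives
  //   padded = 32 + 0 + 0 - 2 = 30,  outputShape[1] = 30 / 2 + 1 = 16,
  // i.e. a 16x16 pooled output.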
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
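// Illustrative worked example (shapes assumed): input tensor<1x16x16x8xf32>,
// weight tensor<4x3x3x8xf32>, pad = [0, 0, 0, 0], stride = [1, 1],
// dilation = [1, 1]:
//   inputSize = 16, filterSize = (3 - 1) * 1 + 1 = 3,
//   unstridedResult = 16 - 3 + 1 = 14, outputShape[1] = (14 - 1) / 1 + 1 = 14,
// giving an inferred output of tensor<1x14x14x4xf32>.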
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both channel values are known we can determine the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
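// Illustrative example (values assumed): with inputChannels = 8 and a channel
// multiplier depthChannels = 2, the inferred output channel count is
// outputShape[3] = 8 * 2 = 16; the spatial dimensions follow the same
// arithmetic as Conv2DOp above.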
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape =
      convertToMlirShape(adaptor.getOutShape());

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
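// Illustrative worked example (values assumed): for inputHeight = 4,
// stride = [2, 2], out_pad = [0, 0, 0, 0] and weightHeight = 3, the computed
// height is (4 - 1) * 2 + 0 + 0 + 3 = 9, matching the OH relation that
// TransposeConv2DOp::verify() checks below.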
LogicalResult TransposeConv2DOp::verify() {
  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };
  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (!ShapedType::isDynamic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (!ShapedType::isDynamic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (!ShapedType::isDynamic(inputHeight) &&
        !ShapedType::isDynamic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (!ShapedType::isDynamic(inputWidth) &&
        !ShapedType::isDynamic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  // Skip further checks if the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }
3395 if (multiplierShape[0] != ShapedType::kDynamic &&
3396 multiplierShape[0] != numChannels) {
3397 emitOpError(
"expect shape of { ")
3398 << numChannels <<
" } for multiplier input, got { "
3399 << multiplierShape[0] <<
" }";
3403 if (!shiftType.hasRank())
3408 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3409 emitOpError(
"expect shape of { ")
3410 << numChannels <<
" } for shift input, got { " << shiftShape[0] <<
" }";
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInput().getType())));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  return success();
}
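// Illustrative example (shapes assumed): if one yield produces
// tensor<2x3xf32> and another produces tensor<2x?xf32>, the meet of the two
// ValueKnowledge entries keeps only the information common to both, so the
// inferred result component is [2, ?] with element type f32.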
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yields explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenGraph(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion, /*printEntryBlockArgs=*/false,
                  printBlockTerminators);
  }
LogicalResult IfOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
          "'then_graph' arguments", getInputList(), "'input_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
          "'else_graph' arguments", getInputList(), "'input_list'").failed())
    return failure();
  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
          "'then_graph' results", getOutputList(), "'output_list'").failed())
    return failure();
  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
          "'else_graph' results", getOutputList(), "'output_list'").failed())
    return failure();
  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;
  return success();
}

LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
          getOutputList(), "'output_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
          "'cond_graph' arguments", getInputList(), "'input_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
          "'body_graph' arguments", getInputList(), "'input_list'").failed())
    return failure();
  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
          "'body_graph' results", getInputList(), "'input_list'").failed())
    return failure();
  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";
  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;
  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;
  return success();
}
LogicalResult tosa::ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output shape has rank 0 and the
  // axis is also 0.
  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }
  StringRef symName = getName();
  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
  if (succeeded(varOp))
    return emitOpError("illegal to have multiple declaration of '")
           << symName << "'";
  FunctionType functionType;
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the parsed types onto the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
  return mlir::isa<tosa::shapeType>(t);
}

  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp)
        return op->emitOpError("shape operand is not compile time resolvable");
    }

  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };

  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();
  auto rank = getRank(*operandTypes.begin());

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");

  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"