29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));

  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,

  return {&getBodyGraph()};
void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
           << "expected attribute";

  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {

  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();
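// idivCheck (above) yields lhs / rhs only when rhs divides lhs exactly and
// std::nullopt otherwise; callers below turn the nullopt case into a "wholly
// divisible" diagnostic. The unwrapping just above reduces a quantized element
// type to its integer storage type before any comparisons.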
                                         Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

                                      Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
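// The conv-style operand verifier below checks that input and weight are
// ranked tensors, that input/weight/bias/result element types agree once
// quantized types are reduced to their storage types, that float and integer
// operands are not mixed, and that any zero points share their tensor's
// element type and pass the per-operand zero-point checks.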
template <typename T>
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();

  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
      (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
      (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
      (inputEType.isF16() && resultEType.isF16()) ||
      (inputEType.isBF16() && resultEType.isBF16()) ||
      (inputEType.isF32() && resultEType.isF32()))

  return op.emitOpError("input/output element types are incompatible.");
template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  Type finalOutputType{outputType};
    auto inputBits = eType.getIntOrFloatBitWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    finalOutputType = outputShapedType.clone(accElementType);

                                   DenseArrayAttr kernel, DenseArrayAttr stride,
                                   DenseArrayAttr pad, TypeAttr accType) {
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.types.push_back(outputType);
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {
    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {

      outShape[i + rankDiff] = resolvedDim;
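// Standard broadcast resolution: ranks are aligned from the right, a size-1
// dim stretches to the other operand's size, and mismatched static sizes are
// rejected; e.g. shapes [1, 5] and [3, 1] resolve to [3, 5].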
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outShape.push_back(inputShape.getDimSize(i));
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

                                         const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;
  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
                             outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (!ShapedType::isDynamic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
                          "Cannot concat tensors with different sizes"
                          " on the non-axis dimension ",

    hasRankedInput = true;

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;

    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
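// The concatenated axis is the sum of the operand sizes on that axis (dynamic
// as soon as one operand is unranked or dynamic there); every other dim must
// agree across operands. E.g. concatenating [2, 3] and [2, 5] on axis 1 gives
// [2, 8].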
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
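// Each padded output dim is input_dim + pad_front + pad_back, e.g. a dim of 4
// with pads {1, 2} becomes 7; negative pad values and dynamic input dims
// propagate as dynamic.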
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError() << "expect same input and output tensor rank.";

  if (paddingRank != inputType.getRank() * 2)
    return emitOpError() << "expected padding tensor dim 0 to have size "
                         << inputType.getRank() * 2
                         << " (2*rank(shape1)) but got size " << paddingRank;
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];
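// A slice size of -1 selects everything from `start` to the end of the dim
// (input_dim - start); otherwise the requested size is used as long as
// start + size stays within the static input dim.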
  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());

  auto startShapeRank =
      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
  if (inputType.getRank() != startShapeRank)
    return emitOpError("length of start is not equal to rank of input shape");

  auto sizeShapeRank =
      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
  if (inputType.getRank() != sizeShapeRank)
    return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";

  auto hasRank = [](const Type type) {
    if (auto shaped_type = dyn_cast<ShapedType>(type))
      return shaped_type.hasRank();

  auto rankedOperandTypes =
      llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));

  auto rankedResultTypes =
      llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);

  if (rankedOperandTypes.empty() && rankedResultTypes.empty())

  auto getRank = [](const Type type) {
    return cast<ShapedType>(type).getRank();

  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
                                          : getRank(*rankedResultTypes.begin());

  for (size_t i = 0; i < 2; ++i) {
    if (rank != getRank(rankedOperandTypes[i])) {
      return emitOpError("operands don't have matching ranks");

  for (const auto type : rankedResultTypes) {
    if (rank != getRank(type)) {
      return emitOpError("result type has different rank than operands");

    return mlir::cast<ShapedType>(type).getShape();

    return emitOpError("operands don't have broadcast-compatible shapes");
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;

  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));
LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int64_t dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamic)
      dim *= multiples[i];
    outputShape.push_back(dim);

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();
  RankedTensorType outputType = getType();

    return mlir::success();

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  return mlir::success();
template <typename T>
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
    return zpAttr.getValues<APInt>()[0].getSExtValue();

template <typename T>
                                   const std::string &operand) {
  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                   const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;

  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(*this, get##OPERAND_NAME##Zp());                       \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

      constantPerms.size() != static_cast<size_t>(inputType.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputType.getRank() << " (input rank) but got size "
                         << constantPerms.size();

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

      constantPerms.size() != static_cast<size_t>(outputType.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputType.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                          static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
         inputType.getRank() == outputType.getRank());

  for (auto i = 0; i < outputType.getRank(); i++) {
    if (inputType.isDynamicDim(constantPerms[i]) ||
        outputType.isDynamicDim(i))

    if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
      return emitOpError()
             << "expected output tensor dim " << i << " to match "
             << "input dim " << constantPerms[i] << " with value of "
             << inputType.getDimSize(constantPerms[i]);
  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

    return emitOpError("expect a ranked input tensor");
    return emitOpError("expect a ranked output tensor");

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);

  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);

  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {

  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER
#undef COMPATIBLE_RETURN_TYPES
template <typename T>
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");

  int64_t inputRank = inputType.getRank();
  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
        "expect output tensor rank to be equal to input tensor rank");

  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis

  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1ZpEType =
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;

  const Type outputZpEType =
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
  outputShape.resize(4, ShapedType::kDynamic);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
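// Shape inference for pooling uses plain floor division here; the pooling
// verifier earlier in the file is what insists the division be exact.
// E.g. height=7, pads 0/0, kernel=3, stride=2 gives (7 + 0 + 0 - 3) / 2 + 1 = 3.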
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);

  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
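// Conv2D spatial dims:
//   out = (in + pad_before + pad_after - (kernel - 1) * dilation - 1) / stride + 1.
// E.g. in=224, pads 3/3, kernel=7, dilation=1, stride=2 gives
// (224 + 3 + 3 - 6 - 1) / 2 + 1 = 112 (floor division here; the verifier below
// requires exact divisibility).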
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return emitOpError("expect all padding values to be >= 0, got ") << padding;

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return emitOpError("expect all stride values to be >= 1, got ") << strides;

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [this](const int64_t inputSize, const int64_t kernelSize,
               const int64_t outputSize, const int64_t padBefore,
               const int64_t padAfter, const int64_t stride,
               const int64_t dilation, const llvm::StringRef dimName,
               const llvm::StringRef dimAxis,
               const llvm::StringRef padBeforeName,
               const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      if (!calculatedOutSizeMinusOne.has_value())
        return emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

    if (failed(verifyOutputSize(
            inputType.getDimSize(1), weightType.getDimSize(1),
            outputType.getDimSize(1), padding[0], padding[1], strides[0],
            dilations[0], "height", "y", "top", "bottom")))

    if (failed(verifyOutputSize(
            inputType.getDimSize(2), weightType.getDimSize(2),
            outputType.getDimSize(2), padding[2], padding[3], strides[1],
            dilations[1], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels = outputType.getDimSize(3);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
        "bias channels expected to be equal to output channels (")
        << outputChannels << ") or 1, got " << biasChannels;
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);

  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
                                 inferredReturnShapes);

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
                                 inferredReturnShapes);
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the depth multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are known, the output channel
  // count is their product.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
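// Illustrative note (not from the original source): for depthwise convolution
// the inferred output channel count is inputChannels * depthChannels, where
// the depth multiplier is the last weight dimension. For example, an input of
// type tensor<1x32x32x8xf32> with a weight of type tensor<3x3x8x2xf32> gives
// outputShape[3] = 8 * 2 = 16; the spatial dimensions use the same
// padded/dilated formula as the other convolutions above.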
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
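// Illustrative note (not from the original source): transpose convolution
// grows rather than shrinks the spatial dimensions, using
//   out = (in - 1) * stride + pad_begin + pad_end + kernel
// as computed in calculateSize above. For example, in = 4, stride = 2,
// zero padding on both sides and kernel = 3 gives out = 3 * 2 + 0 + 0 + 3 = 9.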
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }
  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // Multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // Multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // numChannels is the size of the input's last dimension when per_channel is
  // set, and 1 otherwise.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Seed the per-result knowledge from the first yield's operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }
  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // The while body must terminate with tosa.yield; otherwise it is invalid.
  if (yieldOps.empty())
    return failure();

  // Seed the per-result knowledge from the first yield's operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
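// Illustrative note (not from the original source): when apply_scale produces
// a vector result, the unroll shape reported above is simply the vector's
// shape, e.g. a vector<4xi32> result yields {4}; non-vector results report
// std::nullopt, i.e. no native unroll shape is provided.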
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yields explicitly if the op defines values.
    printBlockTerminators = true;
  }

  p << ' ';
  p.printRegion(getThenGraph(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region only if it is non-empty.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // Allow the special case where both the rank and the axis are 0.
  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
LogicalResult tosa::SelectOp::verify() {
  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }

  // The predicate input must have a boolean (i1) element type.
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }

  return success();
}
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the parsed input types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
  parser << ' ';
  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
  parser << " do ";
  parser.printRegion(getBodyGraph());
  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                       Location loc,
                                                       Type srcElemType,
                                                       int64_t zp) {
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}
LogicalResult mlir::tosa::shapeType::verify(
    function_ref<InFlightDiagnostic()> emitError, int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}
LogicalResult mlir::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp)
        return op->emitOpError("shape operand is not compile time resolvable");
    }
  }
  return success();
}

LogicalResult mlir::tosa::verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult mlir::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  // Delegate that returns the rank carried by a tosa shape type.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };

  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();
  auto rank = getRank(*operandTypes.begin());

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}
LogicalResult tosa::ConstShapeOp::verify() {
  // The values attribute must be one-dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");

  // The number of elements must match the rank of the result shape type.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"