17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
30 #include "llvm/ADT/StringExtras.h"
34 #define GEN_PASS_DEF_TOSAVALIDATION
35 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
46 for (
const auto index : operandIndices) {
49 return op->
emitOpError(
"expected compile time resolvable constant, but "
50 "got variable value for operand #")
57 static LogicalResult checkConstantOperandMul(
Operation *op,
59 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
61 return checkConstantOperands(op, {2});
66 static LogicalResult checkConstantOperandTable(
Operation *op,
68 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
70 return checkConstantOperands(op, {1});
75 static LogicalResult checkConstantOperandPad(
Operation *op,
77 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
79 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
82 return checkConstantOperands(op, {2});
87 static LogicalResult checkConstantOperandRescale(
Operation *op,
89 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
91 return checkConstantOperands(op, {1, 2, 3, 4});
97 static LogicalResult checkConstantOperandConvOps(
Operation *op,
99 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
101 return checkConstantOperands(op, {3, 4});
106 static LogicalResult checkConstantOperandMatMul(
Operation *op,
108 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
110 return checkConstantOperands(op, {2, 3});
115 static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
117 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
119 return checkConstantOperands(op, {1, 2});
124 static LogicalResult checkConstantOperandNegate(
Operation *op,
126 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
128 return checkConstantOperands(op, {1, 2});
// Per-level limits consumed by the level checks below; concrete values are
// supplied by the TOSA_LEVEL_* constants. Fields default-initialize to 0.
// Bounds tensor rank (see levelCheckRank).
134 int32_t MAX_RANK = 0;
// Bounds kernel-like quantities: kernel, pad, dilation (levelCheckKernel).
135 int32_t MAX_KERNEL = 0;
// Bounds strides (levelCheckStride).
136 int32_t MAX_STRIDE = 0;
// Bounds resize scale ratios (levelCheckScale).
137 int32_t MAX_SCALE = 0;
// Bounds tensor size in bytes: size <= (1 << MAX_LOG2_SIZE) - 1
// (see levelCheckSize).
138 int32_t MAX_LOG2_SIZE = 0;
// Bounds control-flow nesting depth (levelCheckMaxNesting).
139 int32_t MAX_NESTING = 0;
// Bounds operand/result list lengths (levelCheckListSize).
140 int32_t MAX_TENSOR_LIST_SIZE = 0;
// Field-by-field equality (operator== body; the signature line is elided
// from this extraction).
143 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
144 MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
145 MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
146 MAX_NESTING == rhs.MAX_NESTING &&
147 MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
// "8K" level limits; aggregate initializer follows TosaLevel field order:
// {MAX_RANK, MAX_KERNEL, MAX_STRIDE, MAX_SCALE, MAX_LOG2_SIZE, MAX_NESTING,
//  MAX_TENSOR_LIST_SIZE}.
151 static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
// Sentinel "no level" configuration: applyLevelCheck() returns early (i.e.
// skips all level checks) when tosaLevel == TOSA_LEVEL_NONE. NOTE(review):
// the initializer continues on a line elided from this extraction.
152 static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
159 struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
161 explicit TosaValidation() { populateConstantOperandChecks(); }
163 explicit TosaValidation(
const TosaValidationOptions &
options)
165 this->profile =
options.profile;
166 this->extension =
options.extension;
167 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
168 this->allowInvalidOpDatatypeCombinations =
169 options.allowInvalidOpDatatypeCombinations;
172 void runOnOperation() final;
174 LogicalResult applyConstantOperandCheck(
Operation *op) {
175 for (
auto &checker : constCheckers) {
176 if (failed(checker(op, targetEnv)))
182 LogicalResult applyLevelCheck(
Operation *op);
183 LogicalResult applyAttributeCheck(
Operation *op);
186 LogicalResult applyVariableCheck(
Operation *op);
189 LogicalResult applyErrorIfCheck(
Operation *op);
// Registers the per-operator compile-time-constant operand checks that
// applyConstantOperandCheck() later runs against every visited op. Each
// checker is a no-op for ops of other types, so registration order does
// not matter. NOTE(review): the closing brace (original line 206) is
// elided from this extraction.
192 void populateConstantOperandChecks() {
193 constCheckers.emplace_back(checkConstantOperandMul);
194 constCheckers.emplace_back(checkConstantOperandTable);
195 constCheckers.emplace_back(checkConstantOperandPad);
196 constCheckers.emplace_back(checkConstantOperandRescale);
// The conv-style ops share one templated checker, instantiated per op type
// (all check that operands #3 and #4 are compile-time constants).
197 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
198 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
199 constCheckers.emplace_back(
200 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
201 constCheckers.emplace_back(
202 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
203 constCheckers.emplace_back(checkConstantOperandMatMul);
204 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
205 constCheckers.emplace_back(checkConstantOperandNegate);
208 bool levelCheckKernel(
Operation *op, int32_t v,
const StringRef checkDesc) {
209 if (v > tosaLevel.MAX_KERNEL) {
210 op->
emitOpError() <<
"failed level check: " << checkDesc;
216 bool levelCheckStride(
Operation *op, int32_t v,
const StringRef checkDesc) {
217 if (v > tosaLevel.MAX_STRIDE) {
218 op->
emitOpError() <<
"failed level check: " << checkDesc;
224 bool levelCheckScale(
Operation *op, int32_t v,
const StringRef checkDesc) {
225 if (v > tosaLevel.MAX_SCALE) {
226 op->
emitOpError() <<
"failed level check: " << checkDesc;
232 bool levelCheckListSize(
Operation *op, int32_t v,
const StringRef checkDesc) {
233 if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
234 op->
emitOpError() <<
"failed level check for MAX_TENSOR_LIST_SIZE: "
243 const StringRef operandOrResult, int32_t highest_rank) {
244 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
245 if (!type.hasRank()) {
246 op->
emitOpError() <<
"failed level check: unranked tensor";
249 if (type.getRank() > highest_rank) {
250 op->
emitOpError() <<
"failed level check: " << operandOrResult
251 <<
" rank(shape) <= MAX_RANK";
260 const StringRef operandOrResult, int32_t highest_rank) {
261 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
266 const StringRef operandOrResult);
270 const StringRef operandOrResult) {
271 return levelCheckSize(op, v.
getType(), operandOrResult);
275 template <
typename T>
276 bool levelCheckSizes(T tosaOp) {
277 auto op = tosaOp.getOperation();
279 if (!levelCheckSize(op, v,
"operand"))
284 if (!levelCheckSize(op, v,
"result"))
291 template <
typename T>
292 bool levelCheckRanks(T tosaOp) {
293 auto op = tosaOp.getOperation();
295 if (!levelCheckRank(op, v,
"operand", tosaLevel.MAX_RANK))
300 if (!levelCheckRank(op, v,
"result", tosaLevel.MAX_RANK))
307 bool levelCheckRanksAndSizes(
Operation *op);
310 template <
typename T>
312 if (
auto poolOp = dyn_cast<T>(op)) {
313 for (
auto k : poolOp.getKernel()) {
314 if (!levelCheckKernel(op, k,
"kernel <= MAX_KERNEL")) {
318 for (
auto s : poolOp.getStride()) {
319 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
323 for (
auto p : poolOp.getPad()) {
324 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
333 template <
typename T>
335 if (
auto convOp = dyn_cast<T>(op)) {
337 for (
auto k : convOp.getDilation()) {
338 if (!levelCheckKernel(op, k,
"dilation <= MAX_KERNEL")) {
342 for (
auto p : convOp.getPad()) {
343 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
347 for (
auto s : convOp.getStride()) {
348 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
352 auto dilation = convOp.getDilation();
353 if (ShapedType weightType =
355 auto shape = weightType.getShape();
356 if (isa<tosa::Conv2DOp>(op)) {
357 assert(shape.size() == 4);
358 assert(dilation.size() == 2);
359 if (!levelCheckKernel(op, dilation[0] * shape[1],
360 "dilation_y * KH <= MAX_KERNEL)") ||
361 !levelCheckKernel(op, dilation[1] * shape[2],
362 "dilation_x * KW <= MAX_KERNEL)"))
364 }
else if (isa<tosa::Conv3DOp>(op)) {
365 assert(shape.size() == 5);
366 assert(dilation.size() == 3);
367 if (!levelCheckKernel(op, dilation[0] * shape[1],
368 "dilation_d * KD <= MAX_KERNEL)") ||
369 !levelCheckKernel(op, dilation[1] * shape[2],
370 "dilation_y * KH <= MAX_KERNEL)") ||
371 !levelCheckKernel(op, dilation[2] * shape[3],
372 "dilation_x * KW <= MAX_KERNEL)"))
374 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
375 assert(shape.size() == 4);
376 assert(dilation.size() == 2);
377 if (!levelCheckKernel(op, dilation[0] * shape[0],
378 "dilation_y * KH <= MAX_KERNEL)") ||
379 !levelCheckKernel(op, dilation[1] * shape[1],
380 "dilation_x * KW <= MAX_KERNEL)"))
389 template <
typename T>
393 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
394 auto shape = type.getShape();
395 assert(shape.size() == 3);
396 if (!levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL") ||
397 !levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL")) {
407 bool levelCheckTransposeConv2d(
Operation *op) {
408 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
409 if (ShapedType filterType =
410 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
411 auto shape = filterType.getShape();
412 assert(shape.size() == 4);
414 if (!levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL") ||
415 !levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL")) {
419 for (
auto p : transpose.getOutPad()) {
420 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
424 for (
auto s : transpose.getStride()) {
425 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
435 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
441 const int64_t scaleYN = scale[0];
442 const int64_t scaleYD = scale[1];
443 const int64_t scaleXN = scale[2];
444 const int64_t scaleXD = scale[3];
445 if (!levelCheckScale(op, scaleYN / scaleYD,
446 "scale_y_n/scale_y_d <= MAX_SCALE") ||
447 !levelCheckScale(op, scaleXN / scaleXD,
448 "scale_x_n/scale_x_d <= MAX_SCALE")) {
459 static void getMaxNestedDepth(
Operation *op, int32_t &depth) {
460 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
468 getMaxNestedDepth(op, depth);
471 bool levelCheckMaxNesting(
Operation *op) {
472 int32_t maxNestedDepth = 0;
473 getMaxNestedDepth(op, maxNestedDepth);
475 if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
476 op->
emitOpError() <<
"failed level check: " << maxNestedDepth
477 <<
" >= MAX_NESTING";
484 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
485 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
487 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
488 if (!levelCheckListSize(op, custom.getInputList().size(),
"input_list") ||
489 !levelCheckListSize(op, custom.getOutputList().size(),
494 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
495 if (!levelCheckListSize(op, condIf.getInputList().size(),
"inputs") ||
496 !levelCheckListSize(op, condIf.getOutputList().size(),
"outputs")) {
500 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
501 if (!levelCheckListSize(op, w.getInputList().size(),
"inputs") ||
502 !levelCheckListSize(op, w.getOutputList().size(),
"outputs")) {
509 bool attributeCheckRescale(
Operation *op) {
510 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
511 if (rescale.getRoundingMode() ==
"DOUBLE_ROUND" &&
512 !targetEnv.allows(Extension::doubleround)) {
514 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
515 <<
"requires extension [doubleround]";
517 }
else if (rescale.getRoundingMode() ==
"INEXACT_ROUND" &&
518 !targetEnv.allows(Extension::inexactround)) {
520 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
521 <<
"requires extension [inexactround]";
530 void configLevelAndProfile() {
531 tosaLevel = TOSA_LEVEL_NONE;
532 if (level == TosaLevelEnum::EightK) {
533 tosaLevel = TOSA_LEVEL_EIGHTK;
536 if (!profile.empty()) {
537 for (std::string &prof : profile) {
538 auto profSymbol = symbolizeProfile(prof);
540 targetEnv.addProfile(profSymbol.value());
542 llvm::errs() <<
"unknown TOSA profile name passed in: " << prof
543 <<
", supported profiles are `pro_int` and `pro_fp`\n";
544 return signalPassFailure();
549 if (!extension.empty()) {
550 for (std::string &ext : extension) {
551 auto extSymbol = symbolizeExtension(ext);
553 targetEnv.addExtension(extSymbol.value());
555 llvm::errs() <<
"unknown TOSA extension name passed in: " << ext
556 <<
", supported extension are int16, int4, bf16, "
557 <<
"fp8e4m3, fp8e5m2, fft, variable, controlflow, "
558 <<
"doubleround, inexactround and dynamic\n";
559 return signalPassFailure();
566 bool CheckVariableReadOrWrite(
Operation *op);
567 bool isValidElementType(
Type type,
const bool allowUnsigned =
false);
// Rank limits for tosa.argmax: the input operand may use the full MAX_RANK,
// while the result is limited to MAX_RANK - 1 — presumably because argmax
// collapses the indexed axis, so the output rank is one less than the
// input's; confirm against the TOSA spec. NOTE(review): interior lines
// (early returns / closing braces) are elided from this extraction.
579 bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
580 auto op = tosaOp.getOperation();
581 if (!levelCheckRank(op, tosaOp.getInput(),
"operand", tosaLevel.MAX_RANK))
585 if (!levelCheckRank(op, tosaOp.getOutput(),
"result", tosaLevel.MAX_RANK - 1))
592 bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
593 auto op = tosaOp.getOperation();
596 if (!levelCheckRank(op, tosaOp.getCondition(),
"operand", tosaLevel.MAX_RANK))
603 bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
604 auto op = tosaOp.getOperation();
606 if (!levelCheckRank(op, variableType,
"variable type", tosaLevel.MAX_RANK))
613 bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
614 auto op = tosaOp.getOperation();
616 if (!levelCheckSize(op, variableType,
"variable type"))
622 bool TosaValidation::levelCheckRanksAndSizes(
Operation *op) {
623 #define CHECK_RANKS_AND_SIZES(tosaOp) \
624 if (isa<tosa::tosaOp##Op>(op)) { \
625 if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \
627 if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
631 #define CHECK_SIZES(tosaOp) \
632 if (isa<tosa::tosaOp##Op>(op)) { \
633 if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
735 #undef CHECK_RANKS_AND_SIZES
741 bool TosaValidation::levelCheckSize(
Operation *op,
const Type &typeToCheck,
742 const StringRef operandOrResult) {
743 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
744 if (!type.hasRank()) {
745 op->
emitOpError() <<
"failed level check: unranked tensor";
748 auto shape = type.getShape();
749 for (
auto dim : shape) {
750 if (mlir::ShapedType::isDynamic(dim)) {
751 op->
emitOpError() <<
"failed level check: " << operandOrResult
752 <<
" shape dimension cannot be dynamic";
757 int64_t element_bits = type.getElementTypeBitWidth();
758 int64_t element_bytes =
std::max(INT64_C(1), element_bits / 8);
759 int64_t size = element_bytes * type.getNumElements();
766 const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
767 if (size > max_size) {
769 <<
"failed level check: " << operandOrResult
770 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
// Entry point for TOSA "level" validation: checks `op` against the limits
// held in `tosaLevel` (rank, total tensor size, kernel/stride/scale, list
// sizes, control-flow nesting). NOTE(review): interior lines (returns and
// closing braces) are elided from this extraction.
777 LogicalResult TosaValidation::applyLevelCheck(
Operation *op) {
// TOSA_LEVEL_NONE disables level checking entirely.
778 if (tosaLevel == TOSA_LEVEL_NONE) {
// Per-op rank and tensor-size limits.
784 if (!levelCheckRanksAndSizes(op))
// Kernel/stride/pad/dilation/scale limits for pooling, convolution, FFT,
// transpose-conv, and resize ops; each helper is a no-op for other op
// types.
788 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
789 !levelCheckConv<tosa::Conv2DOp>(op) ||
790 !levelCheckConv<tosa::Conv3DOp>(op) ||
791 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
792 !levelCheckFFT<tosa::FFT2dOp>(op) ||
793 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
794 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
795 !levelCheckResize(op)) {
// MAX_TENSOR_LIST_SIZE limit on operand/result lists.
800 if (!levelCheckListSize(op)) {
// Control-flow ops additionally respect the MAX_NESTING depth limit.
804 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
805 if (!levelCheckMaxNesting(op)) {
813 LogicalResult TosaValidation::applyAttributeCheck(
Operation *op) {
814 if (!attributeCheckRescale(op))
819 inline bool CompatibleTypes(
const mlir::Type &type,
822 return type == declaredType;
825 bool TosaValidation::CheckVariable(
Operation *op) {
826 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
827 mlir::StringAttr nameAttr = variableOp.getNameAttr();
829 if (variablesMap.count(nameAttr)) {
830 op->
emitOpError() <<
"name has already been declared";
834 auto elementType = variableOp.getType();
837 RankedTensorType variableType =
840 variablesMap[nameAttr] = variableType;
846 bool TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
847 if (isa<mlir::tosa::VariableReadOp>(op) ||
848 isa<mlir::tosa::VariableWriteOp>(op)) {
849 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
850 if (!variablesMap.count(nameAttr)) {
855 auto varType = variablesMap[nameAttr];
858 auto type = v.getType();
859 if (!CompatibleTypes(type, varType)) {
860 op->
emitOpError() <<
"operand type does not equal variable type";
866 auto type = v.getType();
867 if (!CompatibleTypes(type, varType)) {
868 op->
emitOpError() <<
"result type does not equal variable type";
877 LogicalResult TosaValidation::applyVariableCheck(
Operation *op) {
878 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
885 auto resize = dyn_cast<tosa::ResizeOp>(op);
889 const Value input = resize.getInput();
890 const Value output = resize.getOutput();
891 const RankedTensorType inputType =
892 llvm::dyn_cast<RankedTensorType>(input.
getType());
893 const RankedTensorType outputType =
894 llvm::dyn_cast<RankedTensorType>(output.
getType());
896 if (!inputType || !outputType) {
897 op->
emitOpError(
"expect ranked input/output tensor");
903 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
905 outputType.getDimSize(1), outputType.getDimSize(2),
906 inputType.getDimSize(1), inputType.getDimSize(2)};
907 const int64_t *maxDim = llvm::max_element(sizes);
908 if (maxDim != sizes.end() && *maxDim >= 16384) {
909 op->
emitOpError(
"expect input/output height/width dims to be < 16384, ")
910 <<
"got [OH, OW, IH, IW] = " << sizes;
920 const int64_t scaleYN = scale[0];
921 const int64_t scaleYD = scale[1];
922 const int64_t scaleXN = scale[2];
923 const int64_t scaleXD = scale[3];
926 if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
927 op->
emitOpError(
"expect all scale numerator values to be <= (1 << 11), "
929 << scaleYN <<
", scale_x_n=" << scaleXN;
933 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
934 op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
935 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
946 const int64_t offsetY = offset[0];
947 const int64_t offsetX = offset[1];
950 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
952 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
953 << offsetY <<
"/" << scaleYN;
956 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
958 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
959 << offsetX <<
"/" << scaleXN;
963 const int64_t borderY = border[0];
964 const int64_t borderX = border[1];
965 if (borderY < -16 * scaleYN || borderY >= scaleYN) {
967 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
968 << borderY <<
"/" << scaleYN;
971 if (borderX < -16 * scaleXN || borderX >= scaleXN) {
973 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
974 << borderX <<
"/" << scaleXN;
989 const int64_t rhs) -> std::optional<int64_t> {
995 const int64_t oh = outputType.getDimSize(1);
996 const int64_t ow = outputType.getDimSize(2);
997 const int64_t ih = inputType.getDimSize(1);
998 const int64_t iw = inputType.getDimSize(2);
1000 if (ih != ShapedType::kDynamic) {
1001 const std::optional<int64_t> calculatedOutHeightMinusOne =
1002 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1003 if (!calculatedOutHeightMinusOne.has_value()) {
1004 op->
emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
1006 <<
"to be wholly divisible by scale_y_d, got ((" << ih <<
" - 1) * "
1007 << scaleYN <<
" - " << offsetY <<
" + " << borderY <<
") / "
1011 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1012 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
1013 op->
emitOpError(
"calculated output height did not match expected: ")
1014 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1019 if (iw != ShapedType::kDynamic) {
1020 const std::optional<int64_t> calculatedOutWidthMinusOne =
1021 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1022 if (!calculatedOutWidthMinusOne.has_value()) {
1023 op->
emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
1025 <<
"to be wholly divisible by scale_x_d, got ((" << iw <<
" - 1) * "
1026 << scaleXN <<
" - " << offsetX <<
" + " << borderX <<
") / "
1030 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1031 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
1032 op->
emitOpError(
"calculated output width did not match expected: ")
1033 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
1042 auto mul = dyn_cast<tosa::MulOp>(op);
1048 ElementsAttr shift_elem;
1052 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1054 if (inputElemType.isInteger(32)) {
1056 if (shift < 0 || shift > 63) {
1057 op->
emitOpError() <<
"requires 0 <= shift && shift <= 63, but got: "
1064 op->
emitOpError() <<
"requires shift = 0 for all input data types that "
1065 "are not int32_t, but got: "
1075 auto table = dyn_cast<tosa::TableOp>(op);
1081 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1084 if (tableShape.hasStaticShape()) {
1085 const auto numElements = tableShape.getNumElements();
1086 if (numElements != tableSize) {
1087 op->
emitOpError() <<
"requires table size of " << tableSize <<
", got "
1096 bool checkErrorIfRescale(
Operation *op) {
1097 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1101 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1102 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1103 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1104 !outputType.getElementType().isInteger())
1107 auto inElemType = inputType.getElementType();
1108 auto outElemType = outputType.getElementType();
1109 auto inWidth = inElemType.getIntOrFloatBitWidth();
1110 auto outWidth = outElemType.getIntOrFloatBitWidth();
1112 bool inputUnsigned = rescale.getInputUnsigned();
1113 bool outputUnsigned = rescale.getOutputUnsigned();
1115 bool scale32 = rescale.getScale32();
1116 auto roundingMode = rescale.getRoundingMode();
1119 if (scale32 && inWidth == 48) {
1120 op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1125 if (!scale32 && roundingMode ==
"DOUBLE_ROUND") {
1126 op->
emitOpError() <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1131 if (inputUnsigned && outputUnsigned) {
1132 op->
emitOpError() <<
"input and output cannot be both unsigned.";
1137 if (outWidth == 32 && inputUnsigned) {
1138 op->
emitOpError() <<
"i32 output type is not allowed with unsigned input.";
1143 if (inWidth == 32 && outputUnsigned) {
1144 op->
emitOpError() <<
"i32 input type is not allowed with unsigned output.";
1149 if (inWidth == 48 && outputUnsigned) {
1150 op->
emitOpError() <<
"i48 input type is not allowed with unsigned output.";
1155 if (inWidth == 48 && inputUnsigned) {
1156 op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1161 if (inWidth == 32 && inputUnsigned) {
1162 op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1167 if (outWidth == 32 && outputUnsigned) {
1168 op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1176 auto pad = dyn_cast<tosa::PadOp>(op);
1185 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1186 if (val.getSExtValue() < 0) {
1187 op->
emitOpError() <<
"padding value must all be non-negative, got "
1188 << val.getSExtValue();
1197 static bool isNullaryOperation(
Operation *op) {
1198 if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
1199 isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
1204 bool checkErrorIfCondIf(
Operation *op) {
1205 auto ifOp = dyn_cast<tosa::IfOp>(op);
1218 auto isNullaryRegion = [](
Region ®ion) ->
bool {
1219 bool noLiveInValue =
true;
1220 region.walk([&noLiveInValue](
Operation *op) {
1221 if (!isNullaryOperation(op)) {
1222 noLiveInValue =
false;
1227 return noLiveInValue;
1232 bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
1233 bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
1234 bool isInputListEmpty = ifOp.getInputList().size() == 0;
1236 if ((isInputListEmpty != isThenGraphNullaryRegion) ||
1237 (isInputListEmpty != isElseGraphNullaryRegion)) {
1239 <<
"the current simplified form is not strictly conformant to the "
1240 "spec, please use the generic format\n";
1247 bool checkErrorIfScatter(
Operation *op) {
1248 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1257 auto const indicesType =
1258 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1259 if (!indicesType || !indicesType.hasRank()) {
1265 op->
emitOpError(
"indices values contain duplicates");
1272 LogicalResult TosaValidation::applyErrorIfCheck(
Operation *op) {
1273 if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1274 !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1275 !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
1276 !checkErrorIfScatter(op))
1281 bool TosaValidation::isValidElementType(
Type type,
const bool allowUnsigned) {
1282 if (isa<FloatType>(type)) {
1283 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1284 Float8E5M2Type>(type);
1285 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1286 if (intTy.isSignless()) {
1287 switch (intTy.getWidth()) {
1296 }
else if (allowUnsigned && intTy.isUnsigned()) {
1297 switch (intTy.getWidth()) {
1304 }
else if (mlir::isa<tosa::shapeType>(type)) {
1310 void TosaValidation::runOnOperation() {
1311 configLevelAndProfile();
1317 getOperation().walk([&](
Operation *op) {
1326 const bool allowUnsigned =
1327 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1329 auto elementTy = getElementTypeOrSelf(operand);
1330 if (!isValidElementType(elementTy, allowUnsigned)) {
1331 op->emitOpError() <<
"is not profile-aligned: element type "
1332 << elementTy <<
" is not legal";
1333 return signalPassFailure();
1337 auto elementTy = getElementTypeOrSelf(resultTy);
1338 if (!isValidElementType(elementTy, allowUnsigned)) {
1339 op->emitOpError() <<
"is not profile-aligned: element type "
1340 << elementTy <<
" is not legal";
1341 return signalPassFailure();
1345 if (strictOpSpecAlignment &&
1346 failed(profileComp.checkProfile(op, targetEnv)))
1347 return signalPassFailure();
1349 if (strictOpSpecAlignment &&
1350 failed(profileComp.checkExtension(op, targetEnv)))
1351 return signalPassFailure();
1353 if (!allowInvalidOpDatatypeCombinations &&
1354 failed(profileComp.checkInvalid(op)))
1355 return signalPassFailure();
1359 if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
1360 signalPassFailure();
1363 if (failed(applyLevelCheck(op)))
1364 signalPassFailure();
1367 if (failed(applyAttributeCheck(op)))
1368 signalPassFailure();
1371 if (failed(applyVariableCheck(op)))
1372 signalPassFailure();
1375 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1376 signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeComponents.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile, extension, and level.
bool allows(Profile prof) const
NestedPattern If(const NestedPattern &child)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.