17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
30 #include "llvm/ADT/StringExtras.h"
34 #define GEN_PASS_DEF_TOSAVALIDATION
35 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
46 for (
const auto index : operandIndices) {
49 return op->
emitOpError(
"expected compile time resolvable constant, but "
50 "got variable value for operand #")
57 static LogicalResult checkConstantOperandMul(
Operation *op,
59 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
61 return checkConstantOperands(op, {2});
66 static LogicalResult checkConstantOperandTable(
Operation *op,
68 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
70 return checkConstantOperands(op, {1});
75 static LogicalResult checkConstantOperandPad(
Operation *op,
77 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
79 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
82 return checkConstantOperands(op, {2});
87 static LogicalResult checkConstantOperandRescale(
Operation *op,
89 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
91 return checkConstantOperands(op, {1, 2, 3, 4});
97 static LogicalResult checkConstantOperandConvOps(
Operation *op,
99 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
101 return checkConstantOperands(op, {3, 4});
106 static LogicalResult checkConstantOperandMatMul(
Operation *op,
108 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
110 return checkConstantOperands(op, {2, 3});
115 static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
117 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
119 return checkConstantOperands(op, {1, 2});
124 static LogicalResult checkConstantOperandNegate(
Operation *op,
126 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
128 return checkConstantOperands(op, {1, 2});
// Numeric limits defined by a TOSA "level" (see the TOSA specification's
// levels section). A level of NONE effectively disables the limits.
struct TosaLevel {
  int32_t MAX_RANK = 0;             // max tensor rank
  int32_t MAX_KERNEL = 0;           // max kernel/pad/dilation extent
  int32_t MAX_STRIDE = 0;           // max stride
  int32_t MAX_SCALE = 0;            // max resize scale ratio
  int32_t MAX_LOG2_SIZE = 0;        // log2 bound on tensor size in bytes
  int32_t MAX_NESTING = 0;          // max cond/loop nesting depth
  int32_t MAX_TENSOR_LIST_SIZE = 0; // max length of tensor lists

  bool operator==(const TosaLevel &rhs) {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }
};

// The "8K" level limits from the TOSA specification.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
// Sentinel "no level" limits — effectively unbounded.
// NOTE(review): the tail of this initializer (original line 153) is missing
// from the extraction; values below follow the upstream file — confirm.
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
                                              63, 256, 256};
159 struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
161 explicit TosaValidation() { populateConstantOperandChecks(); }
163 explicit TosaValidation(
const TosaValidationOptions &
options)
165 this->profile =
options.profile;
166 this->extension =
options.extension;
167 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
168 this->allowInvalidOpDatatypeCombinations =
169 options.allowInvalidOpDatatypeCombinations;
172 void runOnOperation() final;
174 LogicalResult applyConstantOperandCheck(
Operation *op) {
175 for (
auto &checker : constCheckers) {
176 if (
failed(checker(op, targetEnv)))
182 LogicalResult applyLevelCheck(
Operation *op);
183 LogicalResult applyAttributeCheck(
Operation *op);
186 LogicalResult applyVariableCheck(
Operation *op);
189 LogicalResult applyErrorIfCheck(
Operation *op);
192 void populateConstantOperandChecks() {
193 constCheckers.emplace_back(checkConstantOperandMul);
194 constCheckers.emplace_back(checkConstantOperandTable);
195 constCheckers.emplace_back(checkConstantOperandPad);
196 constCheckers.emplace_back(checkConstantOperandRescale);
197 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
198 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
199 constCheckers.emplace_back(
200 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
201 constCheckers.emplace_back(
202 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
203 constCheckers.emplace_back(checkConstantOperandMatMul);
204 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
205 constCheckers.emplace_back(checkConstantOperandNegate);
208 bool levelCheckKernel(
Operation *op, int32_t v,
const StringRef checkDesc) {
209 if (v > tosaLevel.MAX_KERNEL) {
210 op->
emitOpError() <<
"failed level check: " << checkDesc;
216 bool levelCheckStride(
Operation *op, int32_t v,
const StringRef checkDesc) {
217 if (v > tosaLevel.MAX_STRIDE) {
218 op->
emitOpError() <<
"failed level check: " << checkDesc;
224 bool levelCheckScale(
Operation *op, int32_t v,
const StringRef checkDesc) {
225 if (v > tosaLevel.MAX_SCALE) {
226 op->
emitOpError() <<
"failed level check: " << checkDesc;
232 bool levelCheckListSize(
Operation *op, int32_t v,
const StringRef checkDesc) {
233 if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
234 op->
emitOpError() <<
"failed level check for MAX_TENSOR_LIST_SIZE: "
243 const StringRef operandOrResult, int32_t highest_rank) {
244 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
245 if (!type.hasRank()) {
246 op->
emitOpError() <<
"failed level check: unranked tensor";
249 if (type.getRank() > highest_rank) {
250 op->
emitOpError() <<
"failed level check: " << operandOrResult
251 <<
" rank(shape) <= MAX_RANK";
260 const StringRef operandOrResult, int32_t highest_rank) {
261 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
266 const StringRef operandOrResult);
270 const StringRef operandOrResult) {
271 return levelCheckSize(op, v.
getType(), operandOrResult);
275 template <
typename T>
276 bool levelCheckSizes(T tosaOp) {
277 auto op = tosaOp.getOperation();
279 if (!levelCheckSize(op, v,
"operand"))
284 if (!levelCheckSize(op, v,
"result"))
291 template <
typename T>
292 bool levelCheckRanks(T tosaOp) {
293 auto op = tosaOp.getOperation();
295 if (!levelCheckRank(op, v,
"operand", tosaLevel.MAX_RANK))
300 if (!levelCheckRank(op, v,
"result", tosaLevel.MAX_RANK))
307 bool levelCheckRanksAndSizes(
Operation *op);
310 template <
typename T>
312 if (
auto poolOp = dyn_cast<T>(op)) {
313 for (
auto k : poolOp.getKernel()) {
314 if (!levelCheckKernel(op, k,
"kernel <= MAX_KERNEL")) {
318 for (
auto s : poolOp.getStride()) {
319 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
323 for (
auto p : poolOp.getPad()) {
324 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
333 template <
typename T>
335 if (
auto convOp = dyn_cast<T>(op)) {
337 for (
auto k : convOp.getDilation()) {
338 if (!levelCheckKernel(op, k,
"dilation <= MAX_KERNEL")) {
342 for (
auto p : convOp.getPad()) {
343 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
347 for (
auto s : convOp.getStride()) {
348 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
352 auto dilation = convOp.getDilation();
353 if (ShapedType weightType =
355 auto shape = weightType.getShape();
356 if (isa<tosa::Conv2DOp>(op)) {
357 assert(shape.size() == 4);
358 assert(dilation.size() == 2);
359 if (!levelCheckKernel(op, dilation[0] * shape[1],
360 "dilation_y * KH <= MAX_KERNEL)") ||
361 !levelCheckKernel(op, dilation[1] * shape[2],
362 "dilation_x * KW <= MAX_KERNEL)"))
364 }
else if (isa<tosa::Conv3DOp>(op)) {
365 assert(shape.size() == 5);
366 assert(dilation.size() == 3);
367 if (!levelCheckKernel(op, dilation[0] * shape[1],
368 "dilation_d * KD <= MAX_KERNEL)") ||
369 !levelCheckKernel(op, dilation[1] * shape[2],
370 "dilation_y * KH <= MAX_KERNEL)") ||
371 !levelCheckKernel(op, dilation[2] * shape[3],
372 "dilation_x * KW <= MAX_KERNEL)"))
374 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
375 assert(shape.size() == 4);
376 assert(dilation.size() == 2);
377 if (!levelCheckKernel(op, dilation[0] * shape[0],
378 "dilation_y * KH <= MAX_KERNEL)") ||
379 !levelCheckKernel(op, dilation[1] * shape[1],
380 "dilation_x * KW <= MAX_KERNEL)"))
389 template <
typename T>
393 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
394 auto shape = type.getShape();
395 assert(shape.size() == 3);
396 if (!levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL") ||
397 !levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL")) {
407 bool levelCheckTransposeConv2d(
Operation *op) {
408 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
409 if (ShapedType filterType =
410 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
411 auto shape = filterType.getShape();
412 assert(shape.size() == 4);
414 if (!levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL") ||
415 !levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL")) {
419 for (
auto p : transpose.getOutPad()) {
420 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
424 for (
auto s : transpose.getStride()) {
425 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
435 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
441 const int64_t scaleYN = scale[0];
442 const int64_t scaleYD = scale[1];
443 const int64_t scaleXN = scale[2];
444 const int64_t scaleXD = scale[3];
445 if (!levelCheckScale(op, scaleYN / scaleYD,
446 "scale_y_n/scale_y_d <= MAX_SCALE") ||
447 !levelCheckScale(op, scaleXN / scaleXD,
448 "scale_x_n/scale_x_d <= MAX_SCALE")) {
459 static void getMaxNestedDepth(
Operation *op, int32_t &depth) {
460 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
468 getMaxNestedDepth(op, depth);
471 bool levelCheckMaxNesting(
Operation *op) {
472 int32_t maxNestedDepth = 0;
473 getMaxNestedDepth(op, maxNestedDepth);
475 if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
476 op->
emitOpError() <<
"failed level check: " << maxNestedDepth
477 <<
" >= MAX_NESTING";
484 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
485 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
487 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
488 if (!levelCheckListSize(op, custom.getInputList().size(),
"input_list") ||
489 !levelCheckListSize(op, custom.getOutputList().size(),
494 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
495 if (!levelCheckListSize(op, condIf.getInputList().size(),
"inputs") ||
496 !levelCheckListSize(op, condIf.getOutputList().size(),
"outputs")) {
500 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
501 if (!levelCheckListSize(op, w.getInputList().size(),
"inputs") ||
502 !levelCheckListSize(op, w.getOutputList().size(),
"outputs")) {
509 bool attributeCheckRescale(
Operation *op) {
510 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
511 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
512 !targetEnv.allows(Extension::doubleround)) {
514 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
515 <<
"requires extension [doubleround]";
518 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
519 !targetEnv.allows(Extension::inexactround)) {
521 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
522 <<
"requires extension [inexactround]";
531 void configLevelAndProfile() {
532 tosaLevel = TOSA_LEVEL_NONE;
533 if (level == TosaLevelEnum::EightK) {
534 tosaLevel = TOSA_LEVEL_EIGHTK;
537 if (!profile.empty()) {
538 for (std::string &prof : profile) {
539 auto profSymbol = symbolizeProfile(prof);
541 targetEnv.addProfile(profSymbol.value());
543 llvm::errs() <<
"unknown TOSA profile name passed in: " << prof
544 <<
", supported profiles are `pro_int` and `pro_fp`\n";
545 return signalPassFailure();
550 if (!extension.empty()) {
551 for (std::string &ext : extension) {
552 auto extSymbol = symbolizeExtension(ext);
554 targetEnv.addExtension(extSymbol.value());
556 llvm::errs() <<
"unknown TOSA extension name passed in: " << ext
557 <<
", supported extension are int16, int4, bf16, "
558 <<
"fp8e4m3, fp8e5m2, fft, variable, controlflow, "
559 <<
"doubleround, inexactround and dynamic\n";
560 return signalPassFailure();
567 bool CheckVariableReadOrWrite(
Operation *op);
568 bool isValidElementType(
Type type,
const bool allowUnsigned =
false);
580 bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
581 auto *op = tosaOp.getOperation();
582 if (!levelCheckRank(op, tosaOp.getInput(),
"operand", tosaLevel.MAX_RANK))
586 if (!levelCheckRank(op, tosaOp.getOutput(),
"result", tosaLevel.MAX_RANK - 1))
593 bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
594 auto *op = tosaOp.getOperation();
597 if (!levelCheckRank(op, tosaOp.getCondition(),
"operand", tosaLevel.MAX_RANK))
604 bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
605 auto *op = tosaOp.getOperation();
607 if (!levelCheckRank(op, variableType,
"variable type", tosaLevel.MAX_RANK))
614 bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
615 auto *op = tosaOp.getOperation();
617 if (!levelCheckSize(op, variableType,
"variable type"))
// NOTE(review): extraction-garbled fragment. This function dispatches the
// per-op rank/size level checks through the CHECK_RANKS_AND_SIZES and
// CHECK_SIZES macros; the long list of macro invocations (original lines
// 627-735, one per TOSA op) is missing from this view, so the surviving
// lines are left byte-identical rather than reconstructed.
623 bool TosaValidation::levelCheckRanksAndSizes(
Operation *op) {
624 #define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
625   if (isa<tosa::tosaOp##Op>(op)) {                                             \
626     if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op)))                          \
628     if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
632 #define CHECK_SIZES(tosaOp)                                                    \
633   if (isa<tosa::tosaOp##Op>(op)) {                                             \
634     if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
736 #undef CHECK_RANKS_AND_SIZES
742 bool TosaValidation::levelCheckSize(
Operation *op,
const Type &typeToCheck,
743 const StringRef operandOrResult) {
744 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
745 if (!type.hasRank()) {
746 op->
emitOpError() <<
"failed level check: unranked tensor";
749 auto shape = type.getShape();
750 for (
auto dim : shape) {
751 if (mlir::ShapedType::isDynamic(dim)) {
752 op->
emitOpError() <<
"failed level check: " << operandOrResult
753 <<
" shape dimension cannot be dynamic";
758 int64_t element_bits = type.getElementTypeBitWidth();
759 int64_t element_bytes =
std::max(INT64_C(1), element_bits / 8);
760 int64_t size = element_bytes * type.getNumElements();
767 const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
768 if (size > max_size) {
770 <<
"failed level check: " << operandOrResult
771 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
778 LogicalResult TosaValidation::applyLevelCheck(
Operation *op) {
779 if (tosaLevel == TOSA_LEVEL_NONE) {
785 if (!levelCheckRanksAndSizes(op))
789 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
790 !levelCheckConv<tosa::Conv2DOp>(op) ||
791 !levelCheckConv<tosa::Conv3DOp>(op) ||
792 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
793 !levelCheckFFT<tosa::FFT2dOp>(op) ||
794 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
795 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
796 !levelCheckResize(op)) {
801 if (!levelCheckListSize(op)) {
805 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
806 if (!levelCheckMaxNesting(op)) {
814 LogicalResult TosaValidation::applyAttributeCheck(
Operation *op) {
815 if (!attributeCheckRescale(op))
820 inline bool CompatibleTypes(
const mlir::Type &type,
823 return type == declaredType;
826 bool TosaValidation::CheckVariable(
Operation *op) {
827 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
828 mlir::StringAttr nameAttr = variableOp.getNameAttr();
830 if (variablesMap.count(nameAttr)) {
831 op->
emitOpError() <<
"name has already been declared";
835 auto elementType = variableOp.getType();
838 RankedTensorType variableType =
841 variablesMap[nameAttr] = variableType;
847 bool TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
848 if (isa<mlir::tosa::VariableReadOp>(op) ||
849 isa<mlir::tosa::VariableWriteOp>(op)) {
850 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
851 if (!variablesMap.count(nameAttr)) {
856 auto varType = variablesMap[nameAttr];
859 auto type = v.getType();
860 if (!CompatibleTypes(type, varType)) {
861 op->
emitOpError() <<
"operand type does not equal variable type";
867 auto type = v.getType();
868 if (!CompatibleTypes(type, varType)) {
869 op->
emitOpError() <<
"result type does not equal variable type";
878 LogicalResult TosaValidation::applyVariableCheck(
Operation *op) {
879 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
886 auto resize = dyn_cast<tosa::ResizeOp>(op);
890 const Value input = resize.getInput();
891 const Value output = resize.getOutput();
892 const RankedTensorType inputType =
893 llvm::dyn_cast<RankedTensorType>(input.
getType());
894 const RankedTensorType outputType =
895 llvm::dyn_cast<RankedTensorType>(output.
getType());
897 if (!inputType || !outputType) {
898 op->
emitOpError(
"expect ranked input/output tensor");
904 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
906 outputType.getDimSize(1), outputType.getDimSize(2),
907 inputType.getDimSize(1), inputType.getDimSize(2)};
908 const int64_t *maxDim = llvm::max_element(sizes);
909 if (maxDim != sizes.end() && *maxDim >= 16384) {
910 op->
emitOpError(
"expect input/output height/width dims to be < 16384, ")
911 <<
"got [OH, OW, IH, IW] = " << sizes;
921 const int64_t scaleYN = scale[0];
922 const int64_t scaleYD = scale[1];
923 const int64_t scaleXN = scale[2];
924 const int64_t scaleXD = scale[3];
927 if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
928 op->
emitOpError(
"expect all scale numerator values to be <= (1 << 11), "
930 << scaleYN <<
", scale_x_n=" << scaleXN;
934 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
935 op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
936 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
947 const int64_t offsetY = offset[0];
948 const int64_t offsetX = offset[1];
951 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
953 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
954 << offsetY <<
"/" << scaleYN;
957 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
959 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
960 << offsetX <<
"/" << scaleXN;
964 const int64_t borderY = border[0];
965 const int64_t borderX = border[1];
966 if (borderY < -16 * scaleYN || borderY >= scaleYN) {
968 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
969 << borderY <<
"/" << scaleYN;
972 if (borderX < -16 * scaleXN || borderX >= scaleXN) {
974 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
975 << borderX <<
"/" << scaleXN;
990 const int64_t rhs) -> std::optional<int64_t> {
996 const int64_t oh = outputType.getDimSize(1);
997 const int64_t ow = outputType.getDimSize(2);
998 const int64_t ih = inputType.getDimSize(1);
999 const int64_t iw = inputType.getDimSize(2);
1001 if (ih != ShapedType::kDynamic) {
1002 const std::optional<int64_t> calculatedOutHeightMinusOne =
1003 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1004 if (!calculatedOutHeightMinusOne.has_value()) {
1005 op->
emitOpError(
"expected (input_height - 1) * scale_y_n - offset_y + "
1007 <<
"to be wholly divisible by scale_y_d, got ((" << ih <<
" - 1) * "
1008 << scaleYN <<
" - " << offsetY <<
" + " << borderY <<
") / "
1012 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1013 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
1014 op->
emitOpError(
"calculated output height did not match expected: ")
1015 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1020 if (iw != ShapedType::kDynamic) {
1021 const std::optional<int64_t> calculatedOutWidthMinusOne =
1022 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1023 if (!calculatedOutWidthMinusOne.has_value()) {
1024 op->
emitOpError(
"expected (input_width - 1) * scale_x_n - offset_x + "
1026 <<
"to be wholly divisible by scale_x_d, got ((" << iw <<
" - 1) * "
1027 << scaleXN <<
" - " << offsetX <<
" + " << borderX <<
") / "
1031 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1032 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
1033 op->
emitOpError(
"calculated output width did not match expected: ")
1034 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
1043 auto mul = dyn_cast<tosa::MulOp>(op);
1049 ElementsAttr shift_elem;
1053 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1055 if (inputElemType.isInteger(32)) {
1057 if (shift < 0 || shift > 63) {
1058 op->
emitOpError() <<
"requires 0 <= shift && shift <= 63, but got: "
1065 op->
emitOpError() <<
"requires shift = 0 for all input data types that "
1066 "are not int32_t, but got: "
1076 auto table = dyn_cast<tosa::TableOp>(op);
1082 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1085 if (tableShape.hasStaticShape()) {
1086 const auto numElements = tableShape.getNumElements();
1087 if (numElements != tableSize) {
1088 op->
emitOpError() <<
"requires table size of " << tableSize <<
", got "
1097 bool checkErrorIfRescale(
Operation *op) {
1098 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1102 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1103 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1104 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1105 !outputType.getElementType().isInteger())
1108 auto inElemType = inputType.getElementType();
1109 auto outElemType = outputType.getElementType();
1110 auto inWidth = inElemType.getIntOrFloatBitWidth();
1111 auto outWidth = outElemType.getIntOrFloatBitWidth();
1113 bool inputUnsigned = rescale.getInputUnsigned();
1114 bool outputUnsigned = rescale.getOutputUnsigned();
1116 bool scale32 = rescale.getScale32();
1117 auto roundingMode = rescale.getRoundingMode();
1120 if (scale32 && inWidth == 48) {
1121 op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1126 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) {
1127 op->
emitOpError() <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1132 if (inputUnsigned && outputUnsigned) {
1133 op->
emitOpError() <<
"input and output cannot be both unsigned.";
1138 if (outWidth == 32 && inputUnsigned) {
1139 op->
emitOpError() <<
"i32 output type is not allowed with unsigned input.";
1144 if (inWidth == 32 && outputUnsigned) {
1145 op->
emitOpError() <<
"i32 input type is not allowed with unsigned output.";
1150 if (inWidth == 48 && outputUnsigned) {
1151 op->
emitOpError() <<
"i48 input type is not allowed with unsigned output.";
1156 if (inWidth == 48 && inputUnsigned) {
1157 op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1162 if (inWidth == 32 && inputUnsigned) {
1163 op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1168 if (outWidth == 32 && outputUnsigned) {
1169 op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1177 auto pad = dyn_cast<tosa::PadOp>(op);
1186 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1187 if (val.getSExtValue() < 0) {
1188 op->
emitOpError() <<
"padding value must all be non-negative, got "
1189 << val.getSExtValue();
1198 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1199 Region *operandRegion = operand.getParentRegion();
1200 return operandRegion && region->isAncestor(operandRegion);
1204 static bool isRegionIsolatedFromAbove(
Region ®ionToCheck) {
1205 bool noLiveInValue =
true;
1206 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](
Operation *op) {
1207 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1208 noLiveInValue =
false;
1213 return noLiveInValue;
1216 LogicalResult checkIsolatedRegion(
Operation *op,
Region ®ionToCheck,
1217 StringRef regionName) {
1218 if (isRegionIsolatedFromAbove(regionToCheck))
1221 <<
"is not conformant to the TOSA specification. It requires the '"
1222 << regionName <<
"' region is isolated from above.\n";
1226 bool checkErrorIfCondIf(
Operation *op) {
1227 auto ifOp = dyn_cast<tosa::IfOp>(op);
1260 return failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1261 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else"));
1264 bool checkErrorIfWhileLoop(
Operation *op) {
1265 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1269 return failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1270 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body"));
1273 bool checkErrorIfScatter(
Operation *op) {
1274 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1283 auto const indicesType =
1284 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1285 if (!indicesType || !indicesType.hasRank()) {
1291 op->
emitOpError(
"indices values contain duplicates");
1298 LogicalResult TosaValidation::applyErrorIfCheck(
Operation *op) {
1299 if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1300 !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1301 !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
1302 !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
1307 bool TosaValidation::isValidElementType(
Type type,
const bool allowUnsigned) {
1308 if (isa<FloatType>(type)) {
1309 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1310 Float8E5M2Type>(type);
1312 if (
auto intTy = dyn_cast<IntegerType>(type)) {
1313 if (intTy.isSignless()) {
1314 switch (intTy.getWidth()) {
1323 }
else if (allowUnsigned && intTy.isUnsigned()) {
1324 switch (intTy.getWidth()) {
1331 }
else if (mlir::isa<tosa::shapeType>(type)) {
1337 void TosaValidation::runOnOperation() {
1338 configLevelAndProfile();
1344 getOperation().walk([&](
Operation *op) {
1353 const bool allowUnsigned =
1354 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1356 auto elementTy = getElementTypeOrSelf(operand);
1357 if (!isValidElementType(elementTy, allowUnsigned)) {
1358 op->emitOpError() <<
"is not profile-aligned: element type "
1359 << elementTy <<
" is not legal";
1360 return signalPassFailure();
1364 auto elementTy = getElementTypeOrSelf(resultTy);
1365 if (!isValidElementType(elementTy, allowUnsigned)) {
1366 op->emitOpError() <<
"is not profile-aligned: element type "
1367 << elementTy <<
" is not legal";
1368 return signalPassFailure();
1372 if (strictOpSpecAlignment &&
1373 failed(profileComp.checkProfile(op, targetEnv)))
1374 return signalPassFailure();
1376 if (strictOpSpecAlignment &&
1377 failed(profileComp.checkExtension(op, targetEnv)))
1378 return signalPassFailure();
1380 if (!allowInvalidOpDatatypeCombinations &&
1381 failed(profileComp.checkInvalid(op)))
1382 return signalPassFailure();
1386 if (
failed(applyConstantOperandCheck(op)))
1387 signalPassFailure();
1390 if (
failed(applyLevelCheck(op)))
1391 signalPassFailure();
1394 if (
failed(applyAttributeCheck(op)))
1395 signalPassFailure();
1398 if (
failed(applyVariableCheck(op)))
1399 signalPassFailure();
1402 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1403 signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback provided.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
bool allows(Profile prof) const
NestedPattern If(const NestedPattern &child)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from being selected.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.