#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"

namespace mlir {
namespace tosa {

#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;
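// Note: the GEN_PASS_DEF_TOSAVALIDATION define above asks the generated
// Passes.h.inc to emit tosa::impl::TosaValidationBase, the CRTP base class
// (carrying the pass options declared in Passes.td) that the TosaValidation
// pass below derives from. This is standard MLIR tablegen pass plumbing,
// not code specific to TOSA.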
namespace {

// Operands at the given indices must resolve to compile-time constants;
// otherwise validation fails.
static LogicalResult checkConstantOperands(Operation *op,
                                           ArrayRef<unsigned> operandIndices) {
  for (const auto index : operandIndices) {
    if (!matchPattern(op->getOperand(index), m_Constant()))
      return op->emitOpError("expected compile time resolvable constant, but "
                             "got variable value for operand #")
             << index;
  }
  return success();
}
static LogicalResult checkConstantOperandMul(Operation *op,
                                             const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
    // Check 'shift' (operand #2).
    return checkConstantOperands(op, {2});
  }
  return success();
}
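// For illustration, a hedged sketch of IR this checker rejects (operand
// names and types here are assumptions, not taken from a real test case):
//
//   %shift = "test.variable"() : () -> tensor<1xi8>   // not a constant
//   %r = tosa.mul %a, %b, %shift : (...) -> tensor<?xi32>
//
// emits: 'tosa.mul' op expected compile time resolvable constant, but got
// variable value for operand #2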
static LogicalResult checkConstantOperandTable(Operation *op,
                                               const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
    // Check 'table' (operand #1).
    return checkConstantOperands(op, {1});
  }
  return success();
}
static LogicalResult checkConstantOperandPad(Operation *op,
                                             const TargetEnv &env) {
  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
    // Assume the op performs zero-padding when pad_const is absent.
    if (!env.allows(Extension::dynamic) && padOp.getPadConst())
      // Check 'pad_const' (operand #2).
      return checkConstantOperands(op, {2});
  }
  return success();
}
static LogicalResult checkConstantOperandRescale(Operation *op,
                                                 const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
    // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'.
    return checkConstantOperands(op, {1, 2, 3, 4});
  }
  return success();
}
template <typename T>
static LogicalResult checkConstantOperandConvOps(Operation *op,
                                                 const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
    // Check 'input_zp' and 'weight_zp'.
    return checkConstantOperands(op, {3, 4});
  }
  return success();
}
static LogicalResult checkConstantOperandMatMul(Operation *op,
                                                const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
    // Check 'A_zp' and 'B_zp'.
    return checkConstantOperands(op, {2, 3});
  }
  return success();
}
static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
                                                   const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
    // Check 'input_zp' and 'output_zp'.
    return checkConstantOperands(op, {1, 2});
  }
  return success();
}
static LogicalResult checkConstantOperandNegate(Operation *op,
                                                const TargetEnv &env) {
  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
    // Check 'input1_zp' and 'output_zp'.
    return checkConstantOperands(op, {1, 2});
  }
  return success();
}
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;
  int32_t MAX_LOG2_SIZE = 0;
  int32_t MAX_NESTING = 0;
  int32_t MAX_TENSOR_LIST_SIZE = 0;

  bool operator==(const TosaLevel &rhs) {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }
};
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
                                              63, 256, 256};
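// Reading the initializers against the field order {MAX_RANK, MAX_KERNEL,
// MAX_STRIDE, MAX_SCALE, MAX_LOG2_SIZE, MAX_NESTING, MAX_TENSOR_LIST_SIZE}:
// the 8K level caps rank at 6, kernel and stride extents at 8192, scale
// ratios at 256, tensor sizes at 2^31 - 1 bytes, control-flow nesting depth
// below 6, and tensor lists at 64 entries.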
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }

  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->extension = options.extension;
    this->strictOpSpecAlignment = options.strictOpSpecAlignment;
    this->allowInvalidOpDatatypeCombinations =
        options.allowInvalidOpDatatypeCombinations;
  }

  void runOnOperation() final;
  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op, targetEnv)))
        return failure();
    }
    return success();
  }
  LogicalResult applyLevelCheck(Operation *op);
  LogicalResult applyAttributeCheck(Operation *op);
  LogicalResult applyVariableCheck(Operation *op);
  LogicalResult applyErrorIfCheck(Operation *op);
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandMul);
    constCheckers.emplace_back(checkConstantOperandTable);
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandRescale);
    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
    constCheckers.emplace_back(
        checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
    constCheckers.emplace_back(
        checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
    constCheckers.emplace_back(checkConstantOperandMatMul);
    constCheckers.emplace_back(checkConstantOperandAvgPool2d);
    constCheckers.emplace_back(checkConstantOperandNegate);
  }
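  // Design note: all checkers share one signature,
  // LogicalResult(Operation *, const TargetEnv &), so they can live in a
  // single constCheckers vector and applyConstantOperandCheck can run them
  // uniformly over every op the pass visits.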
  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_KERNEL) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }
  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_STRIDE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }
  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_SCALE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }
  bool levelCheckListSize(Operation *op, int32_t v,
                          const StringRef checkDesc) {
    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
                        << checkDesc;
      return false;
    }
    return true;
  }
  // Perform the rank level check on a value (or an ElementsAttr), given the
  // highest rank the current level allows.
  template <typename T>
  bool levelCheckRank(Operation *op, const T &v,
                      const StringRef operandOrResult, int32_t highest_rank) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > highest_rank) {
        op->emitOpError() << "failed level check: " << operandOrResult
                          << " rank(shape) <= MAX_RANK";
        return false;
      }
    }
    return true;
  }
  // Level check the byte size of a single operand or result (defined below).
  bool levelCheckSize(Operation *op, const Value &v,
                      const StringRef operandOrResult);
  // Level check the sizes of all operands and results of the operation.
  template <typename T>
  bool levelCheckSizes(T tosaOp) {
    auto op = tosaOp.getOperation();
    for (auto v : op->getOperands()) {
      if (!levelCheckSize(op, v, "operand"))
        return false;
    }
    for (auto v : op->getResults()) {
      if (!levelCheckSize(op, v, "result"))
        return false;
    }
    return true;
  }
  // Level check the ranks of all operands, attributes and results of the
  // operation.
  template <typename T>
  bool levelCheckRanks(T tosaOp) {
    auto op = tosaOp.getOperation();
    for (auto v : op->getOperands()) {
      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
        return false;
    }
    for (auto attr : op->getAttrs()) {
      if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
        if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
          return false;
      }
    }
    for (auto v : op->getResults()) {
      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
        return false;
    }
    return true;
  }
  // Level check ranks and sizes, dispatched per concrete op type below.
  bool levelCheckRanksAndSizes(Operation *op);
  // Pool op: level check kernel/stride/pad values.
  template <typename T>
  bool levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : poolOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      for (auto p : poolOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
    }
    return true;
  }
  // Conv op: level check dilation/stride/pad values, plus the dilated kernel
  // extents against the weight shape.
  template <typename T>
  bool levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {
      for (auto k : convOp.getDilation()) {
        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : convOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : convOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(convOp.getWeight().getType())) {
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_d * KD <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[2] * shape[3],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[0],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[1],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        }
      }
    }
    return true;
  }
  // FFT op: level check H, W in the input shape [N, H, W].
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
            return false;
          }
        }
      }
    }
    return true;
  }
  // TransposeConv2d op: level check KH/KW, outpad, and stride.
  bool levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getWeight().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // Level check the kernel sizes KH and KW.
        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : transpose.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
    }
    return true;
  }
  // Resize op: level check the scale ratios.
  bool levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      SmallVector<int64_t> scale;
      if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                     scale)) {
        return false;
      }
      const int64_t scaleYN = scale[0];
      const int64_t scaleYD = scale[1];
      const int64_t scaleXN = scale[2];
      const int64_t scaleXD = scale[3];
      if (!levelCheckScale(op, scaleYN / scaleYD,
                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
          !levelCheckScale(op, scaleXN / scaleXD,
                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
        return false;
      }
    }
    return true;
  }
  // Walk up the parent chain, counting nesting levels until a FuncOp or
  // ModuleOp is reached.
  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
    if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
      return;

    op = op->getParentOp();
    if (!op)
      return;

    depth++;
    getMaxNestedDepth(op, depth);
  }
  bool levelCheckMaxNesting(Operation *op) {
    int32_t maxNestedDepth = 0;
    getMaxNestedDepth(op, maxNestedDepth);

    if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
      op->emitOpError() << "failed level check: " << maxNestedDepth
                        << " >= MAX_NESTING";
      return false;
    }
    return true;
  }
  // Tensor list size checks for the ops that consume or produce op lists.
  bool levelCheckListSize(Operation *op) {
    if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
      return levelCheckListSize(op, concat.getInput1().size(), "input1");
    }
    if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
      if (!levelCheckListSize(op, custom.getInputList().size(),
                              "input_list") ||
          !levelCheckListSize(op, custom.getOutputList().size(),
                              "output_list")) {
        return false;
      }
    }
    if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
        return false;
      }
    }
    if (auto w = dyn_cast<tosa::WhileOp>(op)) {
      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
        return false;
      }
    }
    return true;
  }
  bool attributeCheckRescale(Operation *op) {
    if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
      if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
          !targetEnv.allows(Extension::doubleround)) {
        op->emitOpError()
            << "failed attribute check: rounding_mode = DOUBLE_ROUND "
            << "requires extension [doubleround]";
        return false;
      } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
                 !targetEnv.allows(Extension::inexactround)) {
        op->emitOpError()
            << "failed attribute check: rounding_mode = INEXACT_ROUND "
            << "requires extension [inexactround]";
        return false;
      }
    }
    return true;
  }
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }

    if (!profile.empty()) {
      for (std::string &prof : profile) {
        auto profSymbol = symbolizeProfile(prof);
        if (profSymbol) {
          targetEnv.addProfile(profSymbol.value());
        } else {
          llvm::errs() << "unknown TOSA profile name passed in: " << prof
                       << ", supported profiles are `pro_int` and `pro_fp`\n";
          return signalPassFailure();
        }
      }
    }

    if (!extension.empty()) {
      for (std::string &ext : extension) {
        auto extSymbol = symbolizeExtension(ext);
        if (extSymbol) {
          targetEnv.addExtension(extSymbol.value());
        } else {
          llvm::errs() << "unknown TOSA extension name passed in: " << ext
                       << ", supported extensions are int16, int4, bf16, "
                       << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
                       << "doubleround, inexactround and dynamic\n";
          return signalPassFailure();
        }
      }
    }
  }
  bool CheckVariable(Operation *op);
  bool CheckVariableReadOrWrite(Operation *op);
  bool isValidElementType(Type type);

  SmallVector<
      std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
      constCheckers;
  TosaLevel tosaLevel;
  DenseMap<StringAttr, mlir::Type> variablesMap;
  TosaProfileCompliance profileComp;
  tosa::TargetEnv targetEnv;
};
template <>
bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
  auto op = tosaOp.getOperation();
  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
    return false;

  // rank(output) = rank(input) - 1
  if (!levelCheckRank(op, tosaOp.getOutput(), "result",
                      tosaLevel.MAX_RANK - 1))
    return false;

  return true;
}
template <>
bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
  auto op = tosaOp.getOperation();

  // Only the condition input carries a rank limitation.
  if (!levelCheckRank(op, tosaOp.getCondition(), "operand",
                      tosaLevel.MAX_RANK))
    return false;

  return true;
}
bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
#define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op)))                          \
      return false;                                                            \
    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
      return false;                                                            \
  }

#define CHECK_SIZES(tosaOp)                                                    \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
      return false;                                                            \
  }

  // ... one CHECK_RANKS_AND_SIZES / CHECK_SIZES invocation per TOSA op
  // (elided here) ...

#undef CHECK_RANKS_AND_SIZES
#undef CHECK_SIZES

  return true;
}
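// For instance, CHECK_RANKS_AND_SIZES(Add) expands to a guarded block that
// runs both the rank walk and the size walk on tosa::AddOp:
//
//   if (isa<tosa::AddOp>(op)) {
//     if (!levelCheckRanks(cast<tosa::AddOp>(op)))
//       return false;
//     if (!levelCheckSizes(cast<tosa::AddOp>(op)))
//       return false;
//   }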
bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
                                    const StringRef operandOrResult) {
  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
    if (!type.hasRank()) {
      op->emitOpError() << "failed level check: unranked tensor";
      return false;
    }
    auto shape = type.getShape();
    for (auto dim : shape) {
      if (mlir::ShapedType::isDynamic(dim)) {
        op->emitOpError() << "failed level check: " << operandOrResult
                          << " shape dimension cannot be dynamic";
        return false;
      }
    }

    int64_t element_bits = type.getElementTypeBitWidth();
    int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
    int64_t size = element_bytes * type.getNumElements();

    // The size of a tensor in bytes must fit within MAX_LOG2_SIZE bits.
    const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
    if (size > max_size) {
      op->emitOpError()
          << "failed level check: " << operandOrResult
          << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
      return false;
    }
  }
  return true;
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
  if (tosaLevel == TOSA_LEVEL_NONE) {
    // Level "none": no level checking.
    return success();
  }

  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
      !levelCheckConv<tosa::Conv2DOp>(op) ||
      !levelCheckConv<tosa::Conv3DOp>(op) ||
      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
      !levelCheckFFT<tosa::FFT2dOp>(op) ||
      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
      !levelCheckResize(op)) {
    return failure();
  }

  if (!levelCheckRanksAndSizes(op)) {
    return failure();
  }

  if (!levelCheckListSize(op)) {
    return failure();
  }

  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
    if (!levelCheckMaxNesting(op)) {
      return failure();
    }
  }

  return success();
}
LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
  if (!attributeCheckRescale(op))
    return failure();
  return success();
}
inline bool CompatibleTypes(const mlir::Type &type,
                            const mlir::Type &declaredType) {
  // For now, simply use type equality comparison.
  return type == declaredType;
}
bool TosaValidation::CheckVariable(Operation *op) {
  if (isa<mlir::tosa::VariableOp>(op)) {
    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has already been declared";
      return false;
    }

    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
    mlir::Type type = typeAttr.getValue();

    variablesMap[nameAttr] = type;
  }

  return true;
}
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
  if (isa<mlir::tosa::VariableReadOp>(op) ||
      isa<mlir::tosa::VariableWriteOp>(op)) {
    mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
    if (!variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has not been declared";
      return false;
    }

    auto varType = variablesMap[nameAttr];

    for (auto v : op->getOperands()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "operand type does not equal variable type";
        return false;
      }
    }

    for (auto v : op->getResults()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "result type does not equal variable type";
        return false;
      }
    }
  }

  return true;
}
LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
    return failure();
  }
  return success();
}
bool checkErrorIfResize(Operation *op) {
  auto resize = dyn_cast<tosa::ResizeOp>(op);
  if (!resize)
    return true;

  const Value input = resize.getInput();
  const Value output = resize.getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  if (!inputType || !outputType) {
    op->emitOpError("expect ranked input/output tensor");
    return false;
  }

  // Ensure the image size is supported by GPU APIs and that for integer
  // implementations, position * stride does not overflow int32_t.
  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
    const SmallVector<int64_t, 4> sizes = {
        outputType.getDimSize(1), outputType.getDimSize(2),
        inputType.getDimSize(1), inputType.getDimSize(2)};
    const int64_t *maxDim = llvm::max_element(sizes);
    if (maxDim != sizes.end() && *maxDim >= 16384) {
      op->emitOpError("expect input/output height/width dims to be < 16384, ")
          << "got [OH, OW, IH, IW] = " << sizes;
      return false;
    }
  }
  SmallVector<int64_t> scale;
  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
    return false;
  }
  const int64_t scaleYN = scale[0];
  const int64_t scaleYD = scale[1];
  const int64_t scaleXN = scale[2];
  const int64_t scaleXD = scale[3];

  // Ensure the scale numerator values stay within the int32 accumulator.
  if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
    op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
                    "got scale_y_n=")
        << scaleYN << ", scale_x_n=" << scaleXN;
    return false;
  }

  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
    op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
        << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
    return false;
  }
  SmallVector<int64_t> offset;
  SmallVector<int64_t> border;
  if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
      !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
    return false;
  }
  const int64_t offsetY = offset[0];
  const int64_t offsetX = offset[1];

  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
    op->emitOpError(
        "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
        << offsetY << "/" << scaleYN;
    return false;
  }
  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
    op->emitOpError(
        "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
        << offsetX << "/" << scaleXN;
    return false;
  }

  const int64_t borderY = border[0];
  const int64_t borderX = border[1];
  if (borderY < -16 * scaleYN || borderY >= scaleYN) {
    op->emitOpError(
        "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
        << borderY << "/" << scaleYN;
    return false;
  }
  if (borderX < -16 * scaleXN || borderX >= scaleXN) {
    op->emitOpError(
        "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
        << borderX << "/" << scaleXN;
    return false;
  }
  // Check the output shape is consistent with scale/offset/border, i.e.
  // out = ((in - 1) * scale_n - offset + border) / scale_d + 1, where the
  // division must be exact.
  const auto idivCheck = [](const int64_t lhs,
                            const int64_t rhs) -> std::optional<int64_t> {
    if (lhs % rhs != 0)
      return std::nullopt;
    return lhs / rhs;
  };

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);
  if (ih != ShapedType::kDynamic) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value()) {
      op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                      "border_y ")
          << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
          << scaleYN << " - " << offsetY << " + " << borderY << ") / "
          << scaleYD;
      return false;
    }
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
      op->emitOpError("calculated output height did not match expected: ")
          << "calculated=" << calculatedOutHeight << ", expected=" << oh;
      return false;
    }
  }
  if (iw != ShapedType::kDynamic) {
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value()) {
      op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                      "border_x ")
          << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
          << scaleXN << " - " << offsetX << " + " << borderX << ") / "
          << scaleXD;
      return false;
    }
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
      op->emitOpError("calculated output width did not match expected: ")
          << "calculated=" << calculatedOutWidth << ", expected=" << ow;
      return false;
    }
  }

  return true;
}
bool checkErrorIfMul(Operation *op) {
  auto mul = dyn_cast<tosa::MulOp>(op);
  if (!mul)
    return true;

  // REQUIRE(0 <= shift && shift <= 63);
  // REQUIRE(is_same<in_t, int32_t>() || shift == 0);
  ElementsAttr shift_elem;
  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
    return true;
  }

  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
  if (inputElemType.isInteger(32)) {
    // 0 <= shift <= 63 for int32 inputs.
    if (shift < 0 || shift > 63) {
      op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
                        << shift;
      return false;
    }
  } else {
    // shift must be 0 for all other input data types.
    if (shift != 0) {
      op->emitOpError() << "requires shift = 0 for all input data types that "
                           "are not int32_t, but got: "
                        << shift;
      return false;
    }
  }

  return true;
}
bool checkErrorIfTable(Operation *op) {
  auto table = dyn_cast<tosa::TableOp>(op);
  if (!table)
    return true;

  // REQUIRE(length(table) == TABLE_SIZE), where TABLE_SIZE is 256 for int8
  // inputs and 513 for int16 inputs.
  const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
  const int tableSize = inputElemType.isInteger(8) ? 256 : 513;

  const ShapedType tableShape = cast<ShapedType>(table.getTable().getType());
  if (tableShape.hasStaticShape()) {
    const auto numElements = tableShape.getNumElements();
    if (numElements != tableSize) {
      op->emitOpError() << "requires table size of " << tableSize << ", got "
                        << numElements;
      return false;
    }
  }

  return true;
}
bool checkErrorIfRescale(Operation *op) {
  auto rescale = dyn_cast<tosa::RescaleOp>(op);
  if (!rescale)
    return true;

  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
      !outputType.getElementType().isInteger())
    return true;

  auto inElemType = inputType.getElementType();
  auto outElemType = outputType.getElementType();
  auto inWidth = inElemType.getIntOrFloatBitWidth();
  auto outWidth = outElemType.getIntOrFloatBitWidth();

  bool inputUnsigned = rescale.getInputUnsigned();
  bool outputUnsigned = rescale.getOutputUnsigned();

  bool scale32 = rescale.getScale32();
  auto roundingMode = rescale.getRoundingMode();

  // ERROR_IF(scale32 && is_same<in_t, i48_t>())
  if (scale32 && inWidth == 48) {
    op->emitOpError() << "scale32 is not allowed with 48-bit input.";
    return false;
  }

  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
    return false;
  }

  // ERROR_IF(input_unsigned && output_unsigned)
  if (inputUnsigned && outputUnsigned) {
    op->emitOpError() << "input and output cannot be both unsigned.";
    return false;
  }

  // ERROR_IF(is_same<out_t, i32_t>() && input_unsigned)
  if (outWidth == 32 && inputUnsigned) {
    op->emitOpError() << "i32 output type is not allowed with unsigned input.";
    return false;
  }

  // ERROR_IF(is_same<in_t, i32_t>() && output_unsigned)
  if (inWidth == 32 && outputUnsigned) {
    op->emitOpError() << "i32 input type is not allowed with unsigned output.";
    return false;
  }

  // ERROR_IF(is_same<in_t, i48_t>() && output_unsigned)
  if (inWidth == 48 && outputUnsigned) {
    op->emitOpError() << "i48 input type is not allowed with unsigned output.";
    return false;
  }

  if (inWidth == 48 && inputUnsigned) {
    op->emitOpError() << "i48 input type cannot be unsigned.";
    return false;
  }

  if (inWidth == 32 && inputUnsigned) {
    op->emitOpError() << "i32 input type cannot be unsigned.";
    return false;
  }

  if (outWidth == 32 && outputUnsigned) {
    op->emitOpError() << "i32 output type cannot be unsigned.";
    return false;
  }

  return true;
}
bool checkErrorIfPad(Operation *op) {
  auto pad = dyn_cast<tosa::PadOp>(op);
  if (!pad)
    return true;

  DenseIntElementsAttr paddingAttr;
  if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
    // Non-constant padding is diagnosed elsewhere; nothing to check here.
    return true;

  for (const APInt &val : paddingAttr.getValues<APInt>()) {
    if (val.getSExtValue() < 0) {
      op->emitOpError() << "padding value must all be non-negative, got "
                        << val.getSExtValue();
      return false;
    }
  }

  return true;
}
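// For example, a constant padding of [0, 0, 1, 1, 2, 2] passes, while any
// entry such as -1 triggers "padding value must all be non-negative".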
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
      !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
      !checkErrorIfPad(op))
    return failure();
  return success();
}
bool TosaValidation::isValidElementType(Type type) {
  if (isa<FloatType>(type)) {
    return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
               Float8E5M2Type>(type);
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    if (intTy.isSignless()) {
      switch (intTy.getWidth()) {
      case 1:
      case 4:
      case 8:
      case 16:
      case 32:
      case 48:
        return true;
      }
    }
    return false;
  } else if (mlir::isa<tosa::shapeType>(type)) {
    return true;
  }
  return false;
}
void TosaValidation::runOnOperation() {
  configLevelAndProfile();

  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
  if (!tosaDialect)
    return;

  getOperation().walk([&](Operation *op) {
    if (op->getDialect() != tosaDialect)
      return;

    // Perform the valid element type check first to protect the remaining
    // checks from quantized or otherwise unsupported element types.
    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    if (strictOpSpecAlignment &&
        failed(profileComp.checkProfile(op, targetEnv)))
      return signalPassFailure();

    if (strictOpSpecAlignment &&
        failed(profileComp.checkExtension(op, targetEnv)))
      return signalPassFailure();

    if (!allowInvalidOpDatatypeCombinations &&
        failed(profileComp.checkInvalid(op))) {
      op->emitOpError("illegal: operand/result data types not supported");
      return signalPassFailure();
    }

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // Do level checks.
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // Check additional attribute restrictions.
    if (failed(applyAttributeCheck(op)))
      signalPassFailure();

    // Do variable type checks.
    if (failed(applyVariableCheck(op)))
      signalPassFailure();

    // Do ERROR_IF checks.
    if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
      signalPassFailure();
  });
}

} // namespace
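// Typical invocation from the command line (a sketch, not verbatim from the
// pass docs: the pass name and option spellings below are assumptions based
// on the usual GEN_PASS_DEF registration and should be checked against
// Passes.td):
//
//   mlir-opt --tosa-validate="profile=pro_int extension=int16 \
//       strict-op-spec-alignment" input.mlir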