29#include "llvm/ADT/StringExtras.h"
30#include "llvm/Support/FormatVariadic.h"
34#define GEN_PASS_DEF_TOSAVALIDATION
35#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
46 for (
const auto index : operandIndices) {
49 return op->
emitOpError(
"expected compile time resolvable constant, but "
50 "got variable value for operand #")
57static LogicalResult checkConstantOperandMul(
Operation *op,
59 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
61 return checkConstantOperands(op, {2});
66static LogicalResult checkConstantOperandTable(
Operation *op,
68 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
70 return checkConstantOperands(op, {1});
75static LogicalResult checkConstantOperandPad(
Operation *op,
77 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
79 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
82 return checkConstantOperands(op, {2});
87static LogicalResult checkConstantOperandRescale(
Operation *op,
89 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
91 return checkConstantOperands(op, {1, 2, 3, 4});
97static LogicalResult checkConstantOperandConvOps(
Operation *op,
99 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
101 return checkConstantOperands(op, {3, 4});
106static LogicalResult checkConstantOperandMatMul(
Operation *op,
108 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
110 return checkConstantOperands(op, {2, 3});
115static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
117 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
119 return checkConstantOperands(op, {1, 2});
124static LogicalResult checkConstantOperandNegate(
Operation *op,
126 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
128 return checkConstantOperands(op, {1, 2});
133static LogicalResult checkConstantOperandSilceShape(
Operation *op,
135 if (!env.
allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
137 return checkConstantOperands(op, {1, 2});
148 explicit TosaValidation() { populateConstantOperandChecks(); }
150 explicit TosaValidation(
const TosaValidationOptions &
options)
152 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
153 this->allowInvalidOpDatatypeCombinations =
154 options.allowInvalidOpDatatypeCombinations;
156 void runOnOperation() final;
158 LogicalResult applyConstantOperandCheck(Operation *op) {
159 for (
auto &checker : constCheckers) {
160 if (
failed(checker(op, targetEnv)))
166 LogicalResult applyLevelCheck(Operation *op);
167 LogicalResult applyAttributeCheck(Operation *op);
170 LogicalResult applyVariableCheck(Operation *op);
173 LogicalResult applyErrorIfCheck(Operation *op);
176 void populateConstantOperandChecks() {
177 constCheckers.emplace_back(checkConstantOperandMul);
178 constCheckers.emplace_back(checkConstantOperandTable);
179 constCheckers.emplace_back(checkConstantOperandPad);
180 constCheckers.emplace_back(checkConstantOperandRescale);
181 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
182 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
183 constCheckers.emplace_back(
184 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
185 constCheckers.emplace_back(
186 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
187 constCheckers.emplace_back(checkConstantOperandMatMul);
188 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
189 constCheckers.emplace_back(checkConstantOperandNegate);
190 constCheckers.emplace_back(checkConstantOperandSilceShape);
193 LogicalResult levelCheck(Operation *op,
const int32_t calculatedValue,
194 const int32_t maxLevel,
const StringRef inputName,
195 const StringRef levelName) {
196 if (calculatedValue > maxLevel)
198 <<
"failed level check: " << inputName <<
" <= " << levelName
199 <<
" (" << maxLevel <<
"), got " << calculatedValue;
203 LogicalResult levelCheckKernel(Operation *op, int32_t v,
204 const StringRef inputName) {
205 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
209 LogicalResult levelCheckStride(Operation *op, int32_t v,
210 const StringRef inputName) {
211 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
215 LogicalResult levelCheckScale(Operation *op, int32_t v,
216 const StringRef inputName) {
217 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
221 LogicalResult levelCheckListSize(Operation *op, int32_t v,
222 const StringRef inputName) {
223 const std::string inputDesc =
224 llvm::formatv(
"length(tensor_list_shape({0}))", inputName);
225 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
226 inputDesc,
"MAX_TENSOR_LIST_SIZE");
230 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
231 const StringRef operandOrResult,
232 int32_t highest_rank) {
233 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
235 return op->
emitOpError() <<
"failed level check: unranked tensor";
236 if (type.getRank() > highest_rank)
237 return op->
emitOpError() <<
"failed level check: " << operandOrResult
238 <<
" rank(shape) <= MAX_RANK";
244 LogicalResult levelCheckRank(Operation *op,
const Value &v,
245 const StringRef operandOrResult,
246 int32_t highest_rank) {
247 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
251 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
252 const StringRef operandOrResult);
255 LogicalResult levelCheckSize(Operation *op,
const Value &v,
256 const StringRef operandOrResult) {
257 return levelCheckSize(op, v.
getType(), operandOrResult);
261 LogicalResult levelCheckShapeLength(Operation *op,
const Type typeToCheck,
262 const StringRef operandOrResult) {
263 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
264 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
266 <<
"failed shape type level check: " << typeToCheck
267 <<
" exceeds MAX_SHAPE_LEN";
273 template <
typename T>
274 LogicalResult levelCheckSizes(T tosaOp) {
275 auto op = tosaOp.getOperation();
277 if (
failed(levelCheckSize(op, v,
"operand")))
282 if (
failed(levelCheckSize(op, v,
"result")))
289 template <
typename T>
290 LogicalResult levelCheckRanks(T tosaOp) {
291 auto op = tosaOp.getOperation();
292 const TosaLevel tosaLevel = targetEnv.getLevel();
306 template <
typename T>
307 LogicalResult levelCheckShapeLengths(T tosaOp) {
308 for (
const auto &v : tosaOp->getOperands()) {
309 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"operand")))
312 for (
const auto &v : tosaOp->getResults()) {
313 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"result")))
321 LogicalResult levelCheckRanksAndSizes(Operation *op);
324 template <
typename T>
325 LogicalResult levelCheckPool(Operation *op) {
326 if (
auto poolOp = dyn_cast<T>(op)) {
327 for (
auto k : poolOp.getKernel()) {
328 if (
failed(levelCheckKernel(op, k,
"kernel"))) {
332 for (
auto s : poolOp.getStride()) {
333 if (
failed(levelCheckStride(op, s,
"stride"))) {
337 for (
auto p : poolOp.getPad()) {
338 if (
failed(levelCheckKernel(op, p,
"pad"))) {
347 template <
typename T>
348 LogicalResult levelCheckConv(Operation *op) {
349 if (
auto convOp = dyn_cast<T>(op)) {
351 for (
auto k : convOp.getDilation()) {
352 if (
failed(levelCheckKernel(op, k,
"dilation"))) {
356 for (
auto p : convOp.getPad()) {
357 if (
failed(levelCheckKernel(op, p,
"pad"))) {
361 for (
auto s : convOp.getStride()) {
362 if (
failed(levelCheckStride(op, s,
"stride"))) {
366 auto dilation = convOp.getDilation();
367 if (ShapedType weightType =
369 auto shape = weightType.getShape();
370 if (isa<tosa::Conv2DOp>(op)) {
371 assert(shape.size() == 4);
372 assert(dilation.size() == 2);
373 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
374 "dilation_y * KH")) ||
375 failed(levelCheckKernel(op, dilation[1] * shape[2],
378 }
else if (isa<tosa::Conv3DOp>(op)) {
379 assert(shape.size() == 5);
380 assert(dilation.size() == 3);
381 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
382 "dilation_d * KD")) ||
383 failed(levelCheckKernel(op, dilation[1] * shape[2],
384 "dilation_y * KH")) ||
385 failed(levelCheckKernel(op, dilation[2] * shape[3],
388 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
389 assert(shape.size() == 4);
390 assert(dilation.size() == 2);
391 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
392 "dilation_y * KH")) ||
393 failed(levelCheckKernel(op, dilation[1] * shape[1],
402 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
403 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
407 SmallVector<int64_t> padValues;
409 for (
const auto p : padValues)
410 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL")))
414 SmallVector<int64_t> strideValues;
417 for (
const auto s : strideValues)
418 if (
failed(levelCheckKernel(op, s,
"stride <= MAX_KERNEL")))
422 SmallVector<int64_t> dilationValues;
425 int64_t KH = ShapedType::kDynamic;
426 int64_t KW = ShapedType::kDynamic;
427 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
428 KH = weightDataShape.getDimSize(1);
429 KW = weightDataShape.getDimSize(2);
430 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
431 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
432 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
434 if (!ShapedType::isDynamic(KH) &&
435 failed(levelCheckKernel(op, dilationValues[0] * KH,
436 "dilation_y * KH <= MAX_KERNEL)")))
439 if (!ShapedType::isDynamic(KW) &&
440 failed(levelCheckKernel(op, dilationValues[1] * KW,
441 "dilation_x * KW <= MAX_KERNEL)")))
449 template <
typename T>
450 LogicalResult levelCheckFFT(Operation *op) {
453 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
454 auto shape = type.getShape();
455 assert(shape.size() == 3);
456 if (
failed(levelCheckKernel(op, shape[1],
"H")) ||
457 failed(levelCheckKernel(op, shape[2],
"W"))) {
467 LogicalResult levelCheckTransposeConv2d(Operation *op) {
468 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
469 if (ShapedType filterType =
470 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
471 auto shape = filterType.getShape();
472 assert(shape.size() == 4);
474 if (
failed(levelCheckKernel(op, shape[1],
"KH")) ||
475 failed(levelCheckKernel(op, shape[2],
"KW"))) {
479 for (
auto p : transpose.getOutPad()) {
480 if (
failed(levelCheckKernel(op, p,
"pad"))) {
484 for (
auto s : transpose.getStride()) {
485 if (
failed(levelCheckStride(op, s,
"stride"))) {
494 LogicalResult levelCheckResize(Operation *op) {
495 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
496 SmallVector<int64_t> scale;
501 const int64_t scaleYN = scale[0];
502 const int64_t scaleYD = scale[1];
503 const int64_t scaleXN = scale[2];
504 const int64_t scaleXD = scale[3];
506 levelCheckScale(op, scaleYN / scaleYD,
"scale_y_n/scale_y_d")) ||
508 levelCheckScale(op, scaleXN / scaleXD,
"scale_x_n/scale_x_d"))) {
519 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
520 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
528 getMaxNestedDepth(op, depth);
531 LogicalResult levelCheckMaxNesting(Operation *op) {
532 int32_t maxNestedDepth = 0;
533 getMaxNestedDepth(op, maxNestedDepth);
535 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
536 if (maxNestedDepth >= maxNestingLevel)
538 <<
"failed level check: tosa_nesting_depth < MAX_NESTING" <<
" ("
539 << maxNestingLevel <<
"), got " << maxNestedDepth;
543 LogicalResult levelCheckListSize(Operation *op) {
544 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
545 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
547 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
548 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
550 failed(levelCheckListSize(op, custom.getOutputList().size(),
555 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
557 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
558 failed(levelCheckListSize(op, condIf.getOutputList().size(),
563 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
564 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
565 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
569 if (
auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
570 return levelCheckListSize(op, concat_shape.getInput().size(),
"input");
574 LogicalResult attributeCheckRescale(Operation *op) {
575 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
576 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
577 !targetEnv.allows(Extension::doubleround)) {
579 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
580 <<
"requires extension [doubleround]";
583 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
584 !targetEnv.allows(Extension::inexactround)) {
586 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
587 <<
"requires extension [inexactround]";
594 LogicalResult CheckVariable(Operation *op);
595 LogicalResult CheckVariableReadOrWrite(Operation *op);
596 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
599 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
602 TosaProfileCompliance profileComp;
603 tosa::TargetEnv targetEnv;
607LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
608 auto *op = tosaOp.getOperation();
609 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
614 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
622LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
623 auto *op = tosaOp.getOperation();
626 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
634LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
635 auto *op = tosaOp.getOperation();
637 if (
failed(levelCheckRank(op, variableType,
"variable type",
645LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
646 auto *op = tosaOp.getOperation();
648 if (
failed(levelCheckSize(op, variableType,
"variable type")))
654LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
655#define CHECK_RANKS_AND_SIZES(tosaOp) \
656 if (isa<tosa::tosaOp##Op>(op)) { \
657 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
659 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
663#define CHECK_SIZES(tosaOp) \
664 if (isa<tosa::tosaOp##Op>(op)) { \
665 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
669#define CHECK_SHAPE_LEN(tosaOp) \
670 if (isa<tosa::tosaOp##Op>(op)) { \
671 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
798#undef CHECK_RANKS_AND_SIZES
800#undef CHECK_SHAPE_LEN
805LogicalResult TosaValidation::levelCheckSize(Operation *op,
806 const Type &typeToCheck,
807 const StringRef operandOrResult) {
808 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
810 return op->
emitOpError() <<
"failed level check: unranked tensor";
811 auto shape = type.getShape();
812 for (
auto dim : shape) {
813 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
814 const TosaSpecificationVersion targetVersion = targetEnv.
getSpecVersion();
815 const TosaSpecificationVersion minRequiredVersion(1, 1);
825 return op->
emitOpError() <<
"failed level check: " << operandOrResult
826 <<
" shape dimension cannot be dynamic when"
827 <<
" targeting TOSA specification version 1.0"
832 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
833 int64_t size = element_bytes * type.getNumElements();
840 const int64_t max_size =
844 <<
"failed level check: " << operandOrResult
845 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
850LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
857 if (
failed(levelCheckRanksAndSizes(op)))
860 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
861 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
862 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
863 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
864 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
865 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
866 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
867 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op)) ||
868 failed(levelCheckConv2DBlockScaled(op))) {
873 if (
failed(levelCheckListSize(op))) {
877 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
878 if (
failed(levelCheckMaxNesting(op))) {
886LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
887 if (
failed(attributeCheckRescale(op)))
892inline bool CompatibleTypes(
const mlir::Type &type,
893 const mlir::Type &declaredType) {
895 return type == declaredType;
898LogicalResult TosaValidation::CheckVariable(Operation *op) {
899 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
900 mlir::StringAttr nameAttr = variableOp.getNameAttr();
902 if (variablesMap.count(nameAttr))
903 return op->
emitOpError() <<
"name has already been declared";
905 auto elementType = variableOp.getType();
906 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
907 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
908 RankedTensorType variableType =
909 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
911 variablesMap[nameAttr] = variableType;
917LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
918 if (isa<mlir::tosa::VariableReadOp>(op) ||
919 isa<mlir::tosa::VariableWriteOp>(op)) {
920 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
921 if (!variablesMap.count(nameAttr))
922 return op->
emitOpError() <<
"name has not been declared";
924 auto varType = variablesMap[nameAttr];
927 auto type = v.getType();
928 if (!CompatibleTypes(type, varType))
929 return op->
emitOpError() <<
"operand type does not equal variable type";
933 auto type = v.getType();
934 if (!CompatibleTypes(type, varType))
935 return op->
emitOpError() <<
"result type does not equal variable type";
942LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
943 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
948LogicalResult checkErrorIfResize(Operation *op) {
949 auto resize = dyn_cast<tosa::ResizeOp>(op);
953 const Value input = resize.getInput();
954 const Value output = resize.getOutput();
955 const RankedTensorType inputType =
956 llvm::dyn_cast<RankedTensorType>(input.
getType());
957 const RankedTensorType outputType =
958 llvm::dyn_cast<RankedTensorType>(output.
getType());
960 if (!inputType || !outputType)
961 return op->
emitOpError(
"expect ranked input/output tensor");
965 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
966 const SmallVector<int64_t, 4> sizes = {
967 outputType.getDimSize(1), outputType.getDimSize(2),
968 inputType.getDimSize(1), inputType.getDimSize(2)};
969 const int64_t *maxDim = llvm::max_element(sizes);
970 if (maxDim != sizes.end() && *maxDim >= 16384)
972 "expect input/output height/width dims to be < 16384, ")
973 <<
"got [OH, OW, IH, IW] = " << sizes;
986 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
988 "expect all scale numerator values to be <= (1 << 11), "
990 << scaleYN <<
", scale_x_n=" << scaleXN;
992 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
993 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
994 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
1002 const int64_t offsetY = offset[0];
1003 const int64_t offsetX = offset[1];
1006 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1008 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1009 << offsetY <<
"/" << scaleYN;
1010 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1012 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1013 << offsetX <<
"/" << scaleXN;
1015 const int64_t borderY = border[0];
1016 const int64_t borderX = border[1];
1017 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1019 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1020 << borderY <<
"/" << scaleYN;
1021 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1023 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1024 << borderX <<
"/" << scaleXN;
1037 const int64_t rhs) -> std::optional<int64_t> {
1039 return std::nullopt;
1043 const int64_t oh = outputType.getDimSize(1);
1044 const int64_t ow = outputType.getDimSize(2);
1045 const int64_t ih = inputType.getDimSize(1);
1046 const int64_t iw = inputType.getDimSize(2);
1048 if (ih != ShapedType::kDynamic) {
1049 const std::optional<int64_t> calculatedOutHeightMinusOne =
1050 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1051 if (!calculatedOutHeightMinusOne.has_value())
1053 "expected (input_height - 1) * scale_y_n - offset_y + "
1055 <<
"to be wholly divisible by scale_y_d, got ((" << ih
1056 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
1057 <<
") / " << scaleYD;
1058 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1059 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1061 "calculated output height did not match expected: ")
1062 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1065 if (iw != ShapedType::kDynamic) {
1066 const std::optional<int64_t> calculatedOutWidthMinusOne =
1067 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1068 if (!calculatedOutWidthMinusOne.has_value())
1070 "expected (input_width - 1) * scale_x_n - offset_x + "
1072 <<
"to be wholly divisible by scale_x_d, got ((" << iw
1073 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
1074 <<
") / " << scaleXD;
1075 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1076 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1077 return op->
emitOpError(
"calculated output width did not match expected: ")
1078 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
1084LogicalResult checkErrorIfMul(Operation *op) {
1085 auto mul = dyn_cast<tosa::MulOp>(op);
1091 ElementsAttr shift_elem;
1094 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1096 if (inputElemType.isInteger(32)) {
1098 if (shift < 0 || shift > 63)
1100 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
1105 <<
"requires shift = 0 for all input data types that "
1106 "are not int32_t, but got: "
1113LogicalResult checkErrorIfTable(Operation *op) {
1114 auto table = dyn_cast<tosa::TableOp>(op);
1120 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1122 const ShapeAdaptor tableShape(table.getTable().getType());
1123 if (tableShape.hasStaticShape()) {
1124 const auto numElements = tableShape.getNumElements();
1125 if (numElements != tableSize)
1126 return op->
emitOpError() <<
"requires table size of " << tableSize
1127 <<
", got " << numElements;
1133LogicalResult checkErrorIfRescale(Operation *op) {
1134 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1138 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1139 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1140 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1141 !outputType.getElementType().isInteger())
1144 auto inElemType = inputType.getElementType();
1145 auto outElemType = outputType.getElementType();
1146 auto inWidth = inElemType.getIntOrFloatBitWidth();
1147 auto outWidth = outElemType.getIntOrFloatBitWidth();
1149 bool inputUnsigned = rescale.getInputUnsigned();
1150 bool outputUnsigned = rescale.getOutputUnsigned();
1152 bool scale32 = rescale.getScale32();
1153 auto roundingMode = rescale.getRoundingMode();
1156 if (scale32 && inWidth == 48)
1157 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1160 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1162 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1165 if (inputUnsigned && outputUnsigned)
1166 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1169 if (outWidth == 32 && inputUnsigned)
1171 <<
"i32 output type is not allowed with unsigned input.";
1174 if (inWidth == 32 && outputUnsigned)
1176 <<
"i32 input type is not allowed with unsigned output.";
1179 if (inWidth == 48 && outputUnsigned)
1181 <<
"i48 input type is not allowed with unsigned output.";
1184 if (inWidth == 48 && inputUnsigned)
1185 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1188 if (inWidth == 32 && inputUnsigned)
1189 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1192 if (outWidth == 32 && outputUnsigned)
1193 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1198LogicalResult checkErrorIfPad(Operation *op) {
1199 auto pad = dyn_cast<tosa::PadOp>(op);
1203 DenseIntElementsAttr paddingAttr;
1208 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1209 if (val.getSExtValue() < 0)
1210 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1211 << val.getSExtValue();
1217static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1218 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1219 Region *operandRegion = operand.getParentRegion();
1220 return operandRegion && region->isAncestor(operandRegion);
1224static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) {
1225 bool noLiveInValue =
true;
1226 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](Operation *op) {
1227 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1228 noLiveInValue =
false;
1233 return noLiveInValue ?
success() : failure();
1236LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
1237 StringRef regionName) {
1238 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1241 <<
"is not conformant to the TOSA specification. It requires the '"
1242 << regionName <<
"' region is isolated from above.\n";
1245LogicalResult checkErrorIfCondIf(Operation *op) {
1246 auto ifOp = dyn_cast<tosa::IfOp>(op);
1279 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1280 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1285LogicalResult checkErrorIfWhileLoop(Operation *op) {
1286 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1290 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1291 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1296LogicalResult checkErrorIfScatter(Operation *op) {
1297 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1302 DenseIntElementsAttr indicesAttr;
1306 auto const indicesType =
1307 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1308 if (!indicesType || !indicesType.hasRank()) {
1314 op->
emitOpError(
"indices values contain duplicates");
1321LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1322 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1323 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1324 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfCondIf(op)) ||
1325 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1330bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1331 if (isa<FloatType>(type)) {
1332 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1333 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1334 Float6E3M2FNType, Float8E8M0FNUType>(type);
1335 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1336 if (intTy.isSignless()) {
1337 switch (intTy.getWidth()) {
1347 }
else if (allowUnsigned && intTy.isUnsigned()) {
1348 switch (intTy.getWidth()) {
1355 }
else if (isa<tosa::shapeType>(type))
1357 else if (isa<tosa::mxint8Type>(type))
1362void TosaValidation::runOnOperation() {
1363 ModuleOp modOp = getOperation();
1364 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1369 const auto maybeTargetEnv =
1371 if (
failed(maybeTargetEnv))
1372 return signalPassFailure();
1373 targetEnv = *maybeTargetEnv;
1375 modOp.walk([&](Operation *op) {
1384 const bool allowUnsigned =
1385 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1388 if (!isValidElementType(elementTy, allowUnsigned)) {
1389 op->
emitOpError() <<
"is not profile-aligned: element type "
1390 << elementTy <<
" is not legal";
1391 return signalPassFailure();
1396 if (!isValidElementType(elementTy, allowUnsigned)) {
1397 op->
emitOpError() <<
"is not profile-aligned: element type "
1398 << elementTy <<
" is not legal";
1399 return signalPassFailure();
1403 if (strictOpSpecAlignment &&
1405 return signalPassFailure();
1407 if (strictOpSpecAlignment &&
1409 return signalPassFailure();
1411 if (!allowInvalidOpDatatypeCombinations &&
1413 return signalPassFailure();
1417 if (
failed(applyConstantOperandCheck(op)))
1418 signalPassFailure();
1421 if (
failed(applyLevelCheck(op)))
1422 signalPassFailure();
1425 if (
failed(applyAttributeCheck(op)))
1426 signalPassFailure();
1429 if (
failed(applyVariableCheck(op)))
1430 signalPassFailure();
1433 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1434 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
TosaSpecificationVersion getSpecVersion() const
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.