30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/Support/FormatVariadic.h"
36#define GEN_PASS_DEF_TOSAVALIDATION
37#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
// Fragment of checkConstantOperands. Its signature and the actual constness
// test are on lines not shown here. It walks the requested operand indices and
// reports "expected compile time resolvable constant, but got variable value
// for operand #..." for an operand that is not compile-time resolvable.
48 for (
const auto index : operandIndices) {
51 return op->
emitOpError(
"expected compile time resolvable constant, but "
52 "got variable value for operand #")
// Constant-operand checker for tosa.mul. Unless the target environment
// enables the `dynamic` extension, operand #2 must be a compile-time constant.
// NOTE(review): operand #2 is presumably the shift operand; confirm against
// the op definition.
59static LogicalResult checkConstantOperandMul(
Operation *op,
61 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
63 return checkConstantOperands(op, {2});
// Constant-operand checker for tosa.table. Without the `dynamic` extension,
// operand #1 must be a compile-time constant.
68static LogicalResult checkConstantOperandTable(
Operation *op,
70 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
72 return checkConstantOperands(op, {1});
// Constant-operand checker for tosa.pad. When the op carries a pad_const
// operand and the `dynamic` extension is not enabled, operand #2 is required
// to be a compile-time constant.
// NOTE(review): the lines linking the pad_const condition to the return are
// not shown; confirm the return is guarded by it.
77static LogicalResult checkConstantOperandPad(
Operation *op,
79 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
81 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
84 return checkConstantOperands(op, {2});
// Constant-operand checker for tosa.rescale. Without the `dynamic` extension,
// operands #1 through #4 must all be compile-time constants.
89static LogicalResult checkConstantOperandRescale(
Operation *op,
91 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
93 return checkConstantOperands(op, {1, 2, 3, 4});
// Shared constant-operand checker for convolution-like ops of type T. It is
// instantiated for Conv2D, Conv3D, DepthwiseConv2D and TransposeConv2D in
// populateConstantOperandChecks. Without the `dynamic` extension, operands #3
// and #4 must be compile-time constants.
// NOTE(review): presumably these are the input/weight zero points; confirm.
99static LogicalResult checkConstantOperandConvOps(
Operation *op,
101 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
103 return checkConstantOperands(op, {3, 4});
// Constant-operand checker for tosa.matmul and tosa.matmul_t. Without the
// `dynamic` extension, operands #2 and #3 must be compile-time constants.
108static LogicalResult checkConstantOperandMatMul(
Operation *op,
110 if (!env.
allows(Extension::dynamic) &&
111 isa<tosa::MatMulOp, tosa::MatMulTOp>(op)) {
113 return checkConstantOperands(op, {2, 3});
// Body of checkConstantOperandRowGatherBlockScaled (signature not shown). The
// op has a variadic `values` operand list, so the index of the row-count
// operand that must be constant is computed as values().size() + 1 rather
// than being a fixed number.
120 if (!env.
allows(Extension::dynamic) &&
121 isa<tosa::RowGatherBlockScaledOp>(op)) {
122 auto rowGatherOp = cast<tosa::RowGatherBlockScaledOp>(op);
123 const unsigned rowCountIndex = rowGatherOp.getValues().size() + 1;
124 return checkConstantOperands(op, {rowCountIndex});
// Constant-operand checker for tosa.row_gather. Without the `dynamic`
// extension, operand #2 must be a compile-time constant.
129static LogicalResult checkConstantOperandRowGather(
Operation *op,
131 if (!env.
allows(Extension::dynamic) && isa<tosa::RowGatherOp>(op)) {
133 return checkConstantOperands(op, {2});
// Constant-operand checker for tosa.avg_pool2d. Without the `dynamic`
// extension, operands #1 and #2 must be compile-time constants.
138static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
140 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
142 return checkConstantOperands(op, {1, 2});
// Body of checkConstantOperandAvgPool2dAdaptive (signature not shown). Without
// the `dynamic` extension, operands #1 and #2 must be compile-time constants.
149 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dAdaptiveOp>(op)) {
153 return checkConstantOperands(op, {1, 2});
// Constant-operand checker for tosa.negate. Without the `dynamic` extension,
// operands #1 and #2 must be compile-time constants.
158static LogicalResult checkConstantOperandNegate(
Operation *op,
160 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
162 return checkConstantOperands(op, {1, 2});
// Constant-operand checker for tosa.slice_shape. Without the `dynamic`
// extension, operands #1 and #2 must be compile-time constants.
// NOTE(review): "Silce" is a typo for "Slice". Renaming also requires updating
// its registration in populateConstantOperandChecks.
167static LogicalResult checkConstantOperandSilceShape(
Operation *op,
169 if (!env.
allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
171 return checkConstantOperands(op, {1, 2});
// Default constructor: registers the constant-operand checkers.
182 explicit TosaValidation() { populateConstantOperandChecks(); }
// Options constructor: copies strictOpSpecAlignment and
// allowInvalidOpDatatypeCombinations from `options`.
// NOTE(review): the lines of this overload that are not shown here should
// also call populateConstantOperandChecks(). Confirm this; otherwise no
// constant-operand checks run when the pass is built from options.
184 explicit TosaValidation(
const TosaValidationOptions &
options)
186 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
187 this->allowInvalidOpDatatypeCombinations =
188 options.allowInvalidOpDatatypeCombinations;
// Pass entry point; defined out of line.
190 void runOnOperation() final;
// Runs every registered constant-operand checker against `op` with the current
// target environment, in registration order.
192 LogicalResult applyConstantOperandCheck(Operation *op) {
193 for (
auto &checker : constCheckers) {
194 if (
failed(checker(op, targetEnv)))
// Per-category validation entry points, each defined out of line below.
200 LogicalResult applyFunctionSignatureCheck(func::FuncOp op);
201 LogicalResult applyLevelCheck(Operation *op);
202 LogicalResult applyAttributeCheck(Operation *op);
205 LogicalResult applyVariableCheck(Operation *op);
208 LogicalResult applyErrorIfCheck(Operation *op);
// Registers one constant-operand checker per op family in constCheckers.
// applyConstantOperandCheck evaluates the checkers in this order. Each checker
// is responsible for ignoring ops it does not handle.
211 void populateConstantOperandChecks() {
212 constCheckers.emplace_back(checkConstantOperandMul);
213 constCheckers.emplace_back(checkConstantOperandTable);
214 constCheckers.emplace_back(checkConstantOperandPad);
215 constCheckers.emplace_back(checkConstantOperandRescale);
216 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
217 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
218 constCheckers.emplace_back(
219 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
220 constCheckers.emplace_back(
221 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
222 constCheckers.emplace_back(checkConstantOperandMatMul);
223 constCheckers.emplace_back(checkConstantOperandRowGather);
224 constCheckers.emplace_back(checkConstantOperandRowGatherBlockScaled);
225 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
226 constCheckers.emplace_back(checkConstantOperandAvgPool2dAdaptive);
227 constCheckers.emplace_back(checkConstantOperandNegate);
228 constCheckers.emplace_back(checkConstantOperandSilceShape);
// Generic level check. When calculatedValue exceeds maxLevel, it emits
//   "failed level check: <inputName> <= <levelName> (<maxLevel>), got <value>".
// calculatedValue is int64_t because callers pass int64_t attribute values,
// shape dimensions, products such as dilation * kernel extent, and container
// sizes. Narrowing those to int32_t could wrap an out-of-range value back into
// range and let an illegal op pass the check.
231 LogicalResult levelCheck(Operation *op,
const int64_t calculatedValue,
232 const int32_t maxLevel,
const StringRef inputName,
233 const StringRef levelName) {
234 if (calculatedValue > maxLevel)
236 <<
"failed level check: " << inputName <<
" <= " << levelName
237 <<
" (" << maxLevel <<
"), got " << calculatedValue;
// Checks `v` against the target level's MAX_KERNEL. It is also used for
// pad/dilation bounds.
241 LogicalResult levelCheckKernel(Operation *op, int64_t v,
242 const StringRef inputName) {
243 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
// Checks `v` against the target level's MAX_STRIDE.
247 LogicalResult levelCheckStride(Operation *op, int64_t v,
248 const StringRef inputName) {
249 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
// Checks `v` against the target level's MAX_SCALE.
253 LogicalResult levelCheckScale(Operation *op, int64_t v,
254 const StringRef inputName) {
255 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
// Checks a tensor-list length against MAX_TENSOR_LIST_SIZE. The diagnostic
// names it as "length(tensor_list_shape(<inputName>))".
259 LogicalResult levelCheckListSize(Operation *op, int64_t v,
260 const StringRef inputName) {
261 const std::string inputDesc =
262 llvm::formatv(
"length(tensor_list_shape({0}))", inputName);
263 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
264 inputDesc,
"MAX_TENSOR_LIST_SIZE");
// Rank level check for a single type. The visible code inspects only
// ShapedType. It fails on an unranked tensor (the guarding condition is on a
// line not shown) and when the rank exceeds `highest_rank`.
// NOTE(review): the message always names MAX_RANK, even though the bound is
// the caller-supplied `highest_rank`.
268 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
269 const StringRef operandOrResult,
270 int32_t highest_rank) {
271 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
273 return op->
emitOpError() <<
"failed level check: unranked tensor";
274 if (type.getRank() > highest_rank)
275 return op->
emitOpError() <<
"failed level check: " << operandOrResult
276 <<
" rank(shape) <= MAX_RANK";
// Value overload: forwards the value's type to the Type overload.
282 LogicalResult levelCheckRank(Operation *op,
const Value &v,
283 const StringRef operandOrResult,
284 int32_t highest_rank) {
285 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
// Size level check for a type; defined out of line (TosaValidation::levelCheckSize).
289 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
290 const StringRef operandOrResult);
// Value overload: forwards the value's type to the Type overload.
293 LogicalResult levelCheckSize(Operation *op,
const Value &v,
294 const StringRef operandOrResult) {
295 return levelCheckSize(op, v.
getType(), operandOrResult);
// For a !tosa.shape type, fails when its rank (the shape length) exceeds the
// target level's MAX_SHAPE_LEN.
299 LogicalResult levelCheckShapeLength(Operation *op,
const Type typeToCheck,
300 const StringRef operandOrResult) {
301 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
302 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
304 <<
"failed shape type level check: " << typeToCheck
305 <<
" exceeds MAX_SHAPE_LEN";
// Generic size check for a TOSA op. levelCheckSize is applied to its operands
// (labelled "operand") and results (labelled "result"). The loop headers over
// operands/results are on lines not shown.
311 template <
typename T>
312 LogicalResult levelCheckSizes(T tosaOp) {
313 auto op = tosaOp.getOperation();
315 if (
failed(levelCheckSize(op, v,
"operand")))
320 if (
failed(levelCheckSize(op, v,
"result")))
// Generic rank check for a TOSA op. It reads the target level; the per-value
// checks are on lines not shown.
327 template <
typename T>
328 LogicalResult levelCheckRanks(T tosaOp) {
329 auto op = tosaOp.getOperation();
330 const TosaLevel tosaLevel = targetEnv.getLevel();
// Applies levelCheckShapeLength to every operand type and result type of the op.
344 template <
typename T>
345 LogicalResult levelCheckShapeLengths(T tosaOp) {
346 for (
const auto &v : tosaOp->getOperands()) {
347 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"operand")))
350 for (
const auto &v : tosaOp->getResults()) {
351 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"result")))
// Dispatches rank/size/shape-length checks by op type; defined out of line.
359 LogicalResult levelCheckRanksAndSizes(Operation *op);
// Level checks for pooling ops with kernel/stride/pad attributes (instantiated
// for AvgPool2d and MaxPool2d in applyLevelCheck):
//   kernel <= MAX_KERNEL, stride <= MAX_STRIDE, pad <= MAX_KERNEL.
362 template <
typename T>
363 LogicalResult levelCheckPool(Operation *op) {
364 if (
auto poolOp = dyn_cast<T>(op)) {
365 for (
auto k : poolOp.getKernel()) {
366 if (
failed(levelCheckKernel(op, k,
"kernel"))) {
370 for (
auto s : poolOp.getStride()) {
371 if (
failed(levelCheckStride(op, s,
"stride"))) {
375 for (
auto p : poolOp.getPad()) {
376 if (
failed(levelCheckKernel(op, p,
"pad"))) {
// Compile-time predicate restricting levelCheckAdaptivePool to the two
// adaptive pooling ops.
384 template <
typename T>
385 static constexpr bool IsSupportedAdaptivePoolOp =
386 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
387 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
// Level checks for adaptive pooling ops (enabled only for types satisfying
// IsSupportedAdaptivePoolOp). Kernel, stride and pad values are collected into
// SmallVectors on lines not shown, presumably from constant shape operands
// (confirm). The checks are: kernel <= MAX_KERNEL, stride <= MAX_STRIDE,
// pad <= MAX_KERNEL.
389 template <
typename T,
typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
391 LogicalResult levelCheckAdaptivePool(Operation *op) {
392 auto poolOp = dyn_cast<T>(op);
396 SmallVector<int64_t> kernelValues;
399 for (
const auto k : kernelValues)
400 if (
failed(levelCheckKernel(op, k,
"kernel")))
404 SmallVector<int64_t> strideValues;
407 for (
const auto s : strideValues)
408 if (
failed(levelCheckStride(op, s,
"stride")))
412 SmallVector<int64_t> padValues;
414 for (
const auto p : padValues)
415 if (
failed(levelCheckKernel(op, p,
"pad")))
// Level checks for Conv2D / Conv3D / DepthwiseConv2D:
//   dilation, pad <= MAX_KERNEL; stride <= MAX_STRIDE; and
//   dilation * kernel extent <= MAX_KERNEL per spatial axis.
// The kernel extents come from the weight shape, whose layout differs per op:
// Conv2D uses dims 1,2; Conv3D uses dims 1,2,3; DepthwiseConv2D uses dims 0,1.
// NOTE(review): if a weight dim is dynamic (ShapedType::kDynamic), the product
// is formed with that sentinel. Confirm the lines not shown exclude dynamic
// shapes (levelCheckConv2DBlockScaled guards this explicitly).
423 template <
typename T>
424 LogicalResult levelCheckConv(Operation *op) {
425 if (
auto convOp = dyn_cast<T>(op)) {
427 for (
auto k : convOp.getDilation()) {
428 if (
failed(levelCheckKernel(op, k,
"dilation"))) {
432 for (
auto p : convOp.getPad()) {
433 if (
failed(levelCheckKernel(op, p,
"pad"))) {
437 for (
auto s : convOp.getStride()) {
438 if (
failed(levelCheckStride(op, s,
"stride"))) {
442 auto dilation = convOp.getDilation();
443 if (ShapedType weightType =
445 auto shape = weightType.getShape();
446 if (isa<tosa::Conv2DOp>(op)) {
447 assert(shape.size() == 4);
448 assert(dilation.size() == 2);
449 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
450 "dilation_y * KH")) ||
451 failed(levelCheckKernel(op, dilation[1] * shape[2],
454 }
else if (isa<tosa::Conv3DOp>(op)) {
455 assert(shape.size() == 5);
456 assert(dilation.size() == 3);
457 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
458 "dilation_d * KD")) ||
459 failed(levelCheckKernel(op, dilation[1] * shape[2],
460 "dilation_y * KH")) ||
461 failed(levelCheckKernel(op, dilation[2] * shape[3],
464 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
465 assert(shape.size() == 4);
466 assert(dilation.size() == 2);
467 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
468 "dilation_y * KH")) ||
469 failed(levelCheckKernel(op, dilation[1] * shape[1],
// Level checks for tosa.conv2d_block_scaled (Conv2DBlockScaledOp).
// Pad, stride and dilation values are read into SmallVectors on lines not
// shown, presumably from constant shape operands (confirm). The checks are:
//   * each pad value    <= MAX_KERNEL
//   * each stride value <= MAX_STRIDE, via levelCheckStride, exactly as
//     levelCheckConv does for the other convolution ops
//   * dilation_y * KH and dilation_x * KW <= MAX_KERNEL. KH/KW are taken from
//     the weight_data shape, falling back to the weight_scale shape when the
//     weight_data dimension is dynamic. The check is skipped if still dynamic.
// Input labels are plain ("pad", "stride", "dilation_y * KH"), as at the other
// call sites, because levelCheck itself appends " <= " and the level name.
478 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
479 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
483 SmallVector<int64_t> padValues;
485 for (
const auto p : padValues)
486 if (
failed(levelCheckKernel(op, p,
"pad")))
490 SmallVector<int64_t> strideValues;
// Strides are bounded by MAX_STRIDE, not MAX_KERNEL.
493 for (
const auto s : strideValues)
494 if (
failed(levelCheckStride(op, s,
"stride")))
498 SmallVector<int64_t> dilationValues;
501 int64_t KH = ShapedType::kDynamic;
502 int64_t KW = ShapedType::kDynamic;
503 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
504 KH = weightDataShape.getDimSize(1);
505 KW = weightDataShape.getDimSize(2);
506 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
507 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
508 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
510 if (!ShapedType::isDynamic(KH) &&
511 failed(levelCheckKernel(op, dilationValues[0] * KH,
512 "dilation_y * KH")))
515 if (!ShapedType::isDynamic(KW) &&
516 failed(levelCheckKernel(op, dilationValues[1] * KW,
517 "dilation_x * KW")))
// Level checks for FFT2d / RFFT2d. For each shaped value (the loop header is
// on lines not shown), it asserts rank 3 and requires dims 1 (H) and 2 (W) to
// be <= MAX_KERNEL.
525 template <
typename T>
526 LogicalResult levelCheckFFT(Operation *op) {
529 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
530 auto shape = type.getShape();
531 assert(shape.size() == 3);
532 if (
failed(levelCheckKernel(op, shape[1],
"H")) ||
533 failed(levelCheckKernel(op, shape[2],
"W"))) {
// Level checks for tosa.transpose_conv2d. It asserts a rank-4 weight and then
// checks:
//   KH (weight dim 1), KW (weight dim 2), out_pad <= MAX_KERNEL;
//   stride <= MAX_STRIDE.
543 LogicalResult levelCheckTransposeConv2d(Operation *op) {
544 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
545 if (ShapedType filterType =
546 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
547 auto shape = filterType.getShape();
548 assert(shape.size() == 4);
550 if (
failed(levelCheckKernel(op, shape[1],
"KH")) ||
551 failed(levelCheckKernel(op, shape[2],
"KW"))) {
555 for (
auto p : transpose.getOutPad()) {
556 if (
failed(levelCheckKernel(op, p,
"pad"))) {
560 for (
auto s : transpose.getStride()) {
561 if (
failed(levelCheckStride(op, s,
"stride"))) {
// Level check for tosa.resize. The integer quotients scale_y_n / scale_y_d and
// scale_x_n / scale_x_d must be <= MAX_SCALE. The four scale values are read
// into `scale` on lines not shown.
// NOTE(review): no guard against scale_y_d or scale_x_d being 0 is visible
// before the division; confirm the lines not shown reject a zero denominator.
570 LogicalResult levelCheckResize(Operation *op) {
571 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
572 SmallVector<int64_t> scale;
577 const int64_t scaleYN = scale[0];
578 const int64_t scaleYD = scale[1];
579 const int64_t scaleXN = scale[2];
580 const int64_t scaleXD = scale[3];
582 levelCheckScale(op, scaleYN / scaleYD,
"scale_y_n/scale_y_d")) ||
584 levelCheckScale(op, scaleXN / scaleXD,
"scale_x_n/scale_x_d"))) {
// Accumulates the nesting depth of `op` into `depth`. The recursion stops at a
// func.func or module op; the action taken there and the depth update are on
// lines not shown.
// NOTE(review): the visible recursive call passes `op` itself. Presumably `op`
// is reassigned (for example to its parent) on a line not shown; confirm.
595 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
596 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
604 getMaxNestedDepth(op, depth);
// Fails when the nesting depth computed by getMaxNestedDepth reaches the
// target level's MAX_NESTING. The spec bound is strict: depth < MAX_NESTING.
607 LogicalResult levelCheckMaxNesting(Operation *op) {
608 int32_t maxNestedDepth = 0;
609 getMaxNestedDepth(op, maxNestedDepth);
611 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
612 if (maxNestedDepth >= maxNestingLevel)
614 <<
"failed level check: tosa_nesting_depth < MAX_NESTING" <<
" ("
615 << maxNestingLevel <<
"), got " << maxNestedDepth;
// Applies the MAX_TENSOR_LIST_SIZE check to each variadic tensor list:
// concat input1, custom input/output lists, cond_if inputs/outputs,
// while_loop inputs/outputs, and concat_shape input.
619 LogicalResult levelCheckListSize(Operation *op) {
620 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
621 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
623 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
624 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
626 failed(levelCheckListSize(op, custom.getOutputList().size(),
631 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
633 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
634 failed(levelCheckListSize(op, condIf.getOutputList().size(),
639 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
640 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
641 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
645 if (
auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
646 return levelCheckListSize(op, concat_shape.getInput().size(),
"input");
// tosa.rescale rounding-mode attribute check. DOUBLE_ROUND requires the
// `doubleround` extension, and INEXACT_ROUND requires the `inexactround`
// extension, in the target environment.
650 LogicalResult attributeCheckRescale(Operation *op) {
651 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
652 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
653 !targetEnv.allows(Extension::doubleround)) {
655 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
656 <<
"requires extension [doubleround]";
659 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
660 !targetEnv.allows(Extension::inexactround)) {
662 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
663 <<
"requires extension [inexactround]";
// Variable-declaration and variable-access checks; defined out of line.
670 LogicalResult CheckVariable(Operation *op);
671 LogicalResult CheckVariableReadOrWrite(Operation *op);
// Element-type legality. Unsigned integers are accepted only when
// allowUnsigned is true.
672 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
// Element type of constCheckers: callables taking (op, target environment).
// The container declaration starts on a line not shown.
675 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
678 TosaProfileCompliance profileComp;
// Target environment resolved in runOnOperation; read by every check.
679 tosa::TargetEnv targetEnv;
// tosa.argmax specialization: checks the ranks of the input ("operand") and the
// output ("result"). The rank bounds passed are on lines not shown.
683LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
684 auto *op = tosaOp.getOperation();
685 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
690 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
// tosa.cond_if specialization: checks the rank of the condition operand.
698LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
699 auto *op = tosaOp.getOperation();
702 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
// tosa.variable specialization: checks the rank of the declared variable type.
// `variableType` is obtained on a line not shown.
710LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
711 auto *op = tosaOp.getOperation();
713 if (
failed(levelCheckRank(op, variableType,
"variable type",
// tosa.variable specialization: checks the byte size of the declared variable
// type. `variableType` is obtained on a line not shown.
721LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
722 auto *op = tosaOp.getOperation();
724 if (
failed(levelCheckSize(op, variableType,
"variable type")))
// Dispatches per-op rank, size and shape-length checks. The helper macros map
// an op name `X` to `tosa::XOp` and call the matching levelCheck* template.
// The per-op macro invocations are on lines not shown.
730LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
731#define CHECK_RANKS_AND_SIZES(tosaOp) \
732 if (isa<tosa::tosaOp##Op>(op)) { \
733 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
735 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
739#define CHECK_SIZES(tosaOp) \
740 if (isa<tosa::tosaOp##Op>(op)) { \
741 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
745#define CHECK_SHAPE_LEN(tosaOp) \
746 if (isa<tosa::tosaOp##Op>(op)) { \
747 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
879#undef CHECK_RANKS_AND_SIZES
881#undef CHECK_SHAPE_LEN
// Size level check for a shaped type. It does three things:
//   * rejects unranked tensors (the guard is on a line not shown);
//   * rejects dynamic dimensions when the target spec version does not
//     support them (the message refers to version 1.0; the exact condition
//     using targetVersion/minRequiredVersion(1, 1, true) is on lines not
//     shown);
//   * bounds the tensor byte size, with sub-byte element types rounded up to
//     1 byte per element.
// NOTE(review): targetVersion and minRequiredVersion are loop-invariant and
// could be hoisted out of the per-dimension loop.
// NOTE(review): in C++, "1 << MAX_LOG2_SIZE - 1" parses as 1 << (MAX_LOG2_SIZE
// - 1). Confirm the message matches how max_size is computed on the lines not
// shown.
886LogicalResult TosaValidation::levelCheckSize(Operation *op,
887 const Type &typeToCheck,
888 const StringRef operandOrResult) {
889 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
891 return op->
emitOpError() <<
"failed level check: unranked tensor";
892 auto shape = type.getShape();
893 for (
auto dim : shape) {
894 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
895 const TosaSpecificationVersion targetVersion = targetEnv.
getSpecVersion();
896 const TosaSpecificationVersion minRequiredVersion(1, 1,
true);
906 return op->
emitOpError() <<
"failed level check: " << operandOrResult
907 <<
" shape dimension cannot be dynamic when"
908 <<
" targeting TOSA specification version 1.0"
913 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
914 int64_t size = element_bytes * type.getNumElements();
921 const int64_t max_size =
925 <<
"failed level check: " << operandOrResult
926 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
// Runs all level checks for `op`, in this order:
//   1. generic rank/size/shape-length checks;
//   2. op-specific kernel/stride/pad/scale checks, short-circuiting on the
//      first failure through `||`;
//   3. tensor-list sizes;
//   4. nesting depth, for cond_if / while_loop only.
// Lines after the signature that are not shown may return early, for example
// when no level is configured; confirm.
931LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
938 if (
failed(levelCheckRanksAndSizes(op)))
941 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
942 failed(levelCheckAdaptivePool<tosa::AvgPool2dAdaptiveOp>(op)) ||
943 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
944 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
945 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
946 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
947 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
948 failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
949 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
950 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op)) ||
951 failed(levelCheckConv2DBlockScaled(op))) {
956 if (
failed(levelCheckListSize(op))) {
960 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
961 if (
failed(levelCheckMaxNesting(op))) {
// Attribute checks; currently only the tosa.rescale rounding-mode check.
969LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
970 if (
failed(attributeCheckRescale(op)))
// Type compatibility for variable read/write, implemented as exact type
// equality.
975inline bool CompatibleTypes(
const mlir::Type &type,
976 const mlir::Type &declaredType) {
978 return type == declaredType;
// For tosa.variable, this does three things:
//   * rejects a name that is already declared;
//   * builds the variable's RankedTensorType from var_shape and the element
//     type;
//   * records that type in variablesMap under the name, for later read/write
//     checks.
981LogicalResult TosaValidation::CheckVariable(Operation *op) {
982 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
983 mlir::StringAttr nameAttr = variableOp.getNameAttr();
985 if (variablesMap.count(nameAttr))
986 return op->
emitOpError() <<
"name has already been declared";
988 auto elementType = variableOp.getType();
989 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
990 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
991 RankedTensorType variableType =
992 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
994 variablesMap[nameAttr] = variableType;
// For tosa.variable_read / tosa.variable_write, the "name" attribute must refer
// to a previously declared variable. The op's operand and result types (the
// loop headers are on lines not shown) must equal the declared variable type.
// NOTE(review): count() followed by operator[] performs two lookups; a single
// find() would do.
1000LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
1001 if (isa<mlir::tosa::VariableReadOp>(op) ||
1002 isa<mlir::tosa::VariableWriteOp>(op)) {
1003 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
1004 if (!variablesMap.count(nameAttr))
1005 return op->
emitOpError() <<
"name has not been declared";
1007 auto varType = variablesMap[nameAttr];
1010 auto type = v.getType();
1011 if (!CompatibleTypes(type, varType))
1012 return op->
emitOpError() <<
"operand type does not equal variable type";
1016 auto type = v.getType();
1017 if (!CompatibleTypes(type, varType))
1018 return op->
emitOpError() <<
"result type does not equal variable type";
// Runs the variable declaration check, then the read/write check.
1025LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
1026 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
// ERROR_IF checks for tosa.resize:
//   * input and output must be ranked tensors;
//   * with static shapes, OH, OW, IH and IW must all be < 16384;
//   * scale numerators must be <= 1 << 11, and the downscale ratio must be
//     larger than 1/16;
//   * offset_y/x must be in [-scale_n, 16 * scale_n);
//   * border_y/x must be in [-16 * scale_n, scale_n);
//   * for each known input spatial dim, ((in - 1) * scale_n - offset + border)
//     must be exactly divisible by scale_d. idivCheck yields std::nullopt
//     otherwise; the quotient + 1 must equal the output dim when that dim is
//     static.
// NOTE(review): for the fixed 4-element `sizes`, the `maxDim != sizes.end()`
// test is always true; it is harmless.
1031LogicalResult checkErrorIfResize(Operation *op) {
1032 auto resize = dyn_cast<tosa::ResizeOp>(op);
1036 const Value input = resize.getInput();
1037 const Value output = resize.getOutput();
1038 const RankedTensorType inputType =
1039 llvm::dyn_cast<RankedTensorType>(input.
getType());
1040 const RankedTensorType outputType =
1041 llvm::dyn_cast<RankedTensorType>(output.
getType());
1043 if (!inputType || !outputType)
1044 return op->
emitOpError(
"expect ranked input/output tensor");
1048 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1049 const SmallVector<int64_t, 4> sizes = {
1050 outputType.getDimSize(1), outputType.getDimSize(2),
1051 inputType.getDimSize(1), inputType.getDimSize(2)};
1052 const int64_t *maxDim = llvm::max_element(sizes);
1053 if (maxDim != sizes.end() && *maxDim >= 16384)
1055 "expect input/output height/width dims to be < 16384, ")
1056 <<
"got [OH, OW, IH, IW] = " << sizes;
1059 SmallVector<int64_t> scale;
1063 const int64_t scaleYN = scale[0];
1064 const int64_t scaleYD = scale[1];
1065 const int64_t scaleXN = scale[2];
1066 const int64_t scaleXD = scale[3];
1069 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
1071 "expect all scale numerator values to be <= (1 << 11), "
1073 << scaleYN <<
", scale_x_n=" << scaleXN;
1075 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
1076 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
1077 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
1086 const int64_t offsetX = offset[1];
1089 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1091 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1092 << offsetY <<
"/" << scaleYN;
1093 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1095 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1096 << offsetX <<
"/" << scaleXN;
1098 const int64_t borderY = border[0];
1100 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1102 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1103 << borderY <<
"/" << scaleYN;
1104 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1106 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1107 << borderX <<
"/" << scaleXN;
1120 const int64_t rhs) -> std::optional<int64_t> {
1122 return std::nullopt;
1126 const int64_t oh = outputType.getDimSize(1);
1128 const int64_t ih = inputType.getDimSize(1);
1129 const int64_t iw = inputType.getDimSize(2);
1131 if (ih != ShapedType::kDynamic) {
1132 const std::optional<int64_t> calculatedOutHeightMinusOne =
1133 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1134 if (!calculatedOutHeightMinusOne.has_value())
1136 "expected (input_height - 1) * scale_y_n - offset_y + "
1138 <<
"to be wholly divisible by scale_y_d, got ((" << ih
1139 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
1140 <<
") / " << scaleYD;
1141 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1142 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1144 "calculated output height did not match expected: ")
1145 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1148 if (iw != ShapedType::kDynamic) {
1149 const std::optional<int64_t> calculatedOutWidthMinusOne =
1150 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1151 if (!calculatedOutWidthMinusOne.has_value())
1153 "expected (input_width - 1) * scale_x_n - offset_x + "
1155 <<
"to be wholly divisible by scale_x_d, got ((" << iw
1156 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
1157 <<
") / " << scaleXD;
1158 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1159 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1160 return op->
emitOpError(
"calculated output width did not match expected: ")
1161 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
// ERROR_IF checks for tosa.mul on a constant shift (the shift is extracted
// into shift_elem on lines not shown). For i32 inputs, the shift must be in
// [0, 63]; for all other input element types, the shift must be 0.
1167LogicalResult checkErrorIfMul(Operation *op) {
1168 auto mul = dyn_cast<tosa::MulOp>(op);
1174 ElementsAttr shift_elem;
1177 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1179 if (inputElemType.isInteger(32)) {
1181 if (shift < 0 || shift > 63)
1183 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
1188 <<
"requires shift = 0 for all input data types that "
1189 "are not int32_t, but got: "
// ERROR_IF check for tosa.table. When the table shape is static, it must hold
// exactly 256 entries for i8 input and 513 entries otherwise.
1196LogicalResult checkErrorIfTable(Operation *op) {
1197 auto table = dyn_cast<tosa::TableOp>(op);
1203 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1205 const ShapeAdaptor tableShape(table.getTable().getType());
1206 if (tableShape.hasStaticShape()) {
1207 const auto numElements = tableShape.getNumElements();
1208 if (numElements != tableSize)
1209 return op->
emitOpError() <<
"requires table size of " << tableSize
1210 <<
", got " << numElements;
// ERROR_IF checks for tosa.rescale. These apply only when both input and
// output are shaped types with integer element types; otherwise control goes
// to the lines not shown. The rules are:
//   * scale32 is not allowed with 48-bit input;
//   * DOUBLE_ROUND requires scale32;
//   * input and output must not both be unsigned;
//   * i32/i48 widths combine with unsigned flags only as listed below.
1216LogicalResult checkErrorIfRescale(Operation *op) {
1217 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1221 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1222 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1223 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1224 !outputType.getElementType().isInteger())
1227 auto inElemType = inputType.getElementType();
1228 auto outElemType = outputType.getElementType();
1229 auto inWidth = inElemType.getIntOrFloatBitWidth();
1230 auto outWidth = outElemType.getIntOrFloatBitWidth();
1232 bool inputUnsigned = rescale.getInputUnsigned();
1233 bool outputUnsigned = rescale.getOutputUnsigned();
1235 bool scale32 = rescale.getScale32();
1236 auto roundingMode = rescale.getRoundingMode();
1239 if (scale32 && inWidth == 48)
1240 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1243 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1245 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1248 if (inputUnsigned && outputUnsigned)
1249 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1252 if (outWidth == 32 && inputUnsigned)
1254 <<
"i32 output type is not allowed with unsigned input.";
1257 if (inWidth == 32 && outputUnsigned)
1259 <<
"i32 input type is not allowed with unsigned output.";
1262 if (inWidth == 48 && outputUnsigned)
1264 <<
"i48 input type is not allowed with unsigned output.";
1267 if (inWidth == 48 && inputUnsigned)
1268 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1271 if (inWidth == 32 && inputUnsigned)
1272 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1275 if (outWidth == 32 && outputUnsigned)
1276 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
// ERROR_IF check for tosa.pad. Every constant padding value (extracted into
// paddingAttr on lines not shown) must be non-negative.
1281LogicalResult checkErrorIfPad(Operation *op) {
1282 auto pad = dyn_cast<tosa::PadOp>(op);
1286 DenseIntElementsAttr paddingAttr;
1291 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1292 if (val.getSExtValue() < 0)
1293 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1294 << val.getSExtValue();
// ERROR_IF check for tosa.reshape. A constant target shape must not contain the
// "inferable" dimension marker. The marker test is on lines not shown;
// presumably it is kInferableDimSize.
1300LogicalResult checkErrorIfReshape(Operation *op) {
1301 auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
1305 SmallVector<int64_t> shapeValues;
1311 return op->
emitOpError(
"shape input contains inferable dimension (")
1314 "which does not conform to the TOSA specification";
// ERROR_IF check for tosa.slice. Constant start and size shapes must not
// contain the "inferable" dimension marker; the marker test is on lines not
// shown.
1319LogicalResult checkErrorIfSlice(Operation *op) {
1320 auto sliceOp = dyn_cast<tosa::SliceOp>(op);
1324 SmallVector<int64_t> startValues;
1325 SmallVector<int64_t> sizeValues;
1327 sliceOp.getStart().getDefiningOp(), startValues);
1328 const bool hasSizeValues =
1332 return op->
emitOpError(
"start input contains inferable dimension (")
1334 <<
") which does not conform to the TOSA specification";
1336 return op->
emitOpError(
"size input contains inferable dimension (")
1339 "does not conform to the TOSA specification";
// Returns true when every operand of `op` is defined in a region for which
// region->isAncestor() holds, that is, no operand is captured from outside
// `region`. An operand with no parent region fails the test.
1344static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1345 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1346 Region *operandRegion = operand.getParentRegion();
1347 return operandRegion && region->isAncestor(operandRegion);
// Returns success() when no operation nested in `regionToCheck` uses a value
// defined outside that region (tested per op with isOpIsolatedWithinRegion),
// and failure() otherwise. The walk callback clears noLiveInValue on the first
// offending op.
1351static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1352 bool noLiveInValue =
true;
1353 regionToCheck.
walk([&noLiveInValue, &regionToCheck](Operation *op) {
1354 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1355 noLiveInValue =
false;
1360 return noLiveInValue ?
success() : failure();
// Succeeds when `regionToCheck` is isolated from above. Otherwise it emits a
// conformance error naming the region (`regionName`, e.g. "then", "body").
1363LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1364 StringRef regionName) {
1365 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1368 <<
"is not conformant to the TOSA specification. It requires the '"
1369 << regionName <<
"' region is isolated from above.\n";
// ERROR_IF checks for tosa.cond_if. Among other checks on lines not shown,
// both the "then" and "else" regions must be isolated from above.
1372LogicalResult checkErrorIfCondIf(Operation *op) {
1373 auto ifOp = dyn_cast<tosa::IfOp>(op);
1406 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1407 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
// ERROR_IF check for tosa.while_loop: the "cond" and "body" regions must be
// isolated from above.
1412LogicalResult checkErrorIfWhileLoop(Operation *op) {
1413 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1417 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1418 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
// ERROR_IF check for tosa.scatter. With constant indices (indicesAttr) and a
// ranked indices type, the index values must not contain duplicates. The
// uniqueness test is on lines not shown.
1423LogicalResult checkErrorIfScatter(Operation *op) {
1424 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1429 DenseIntElementsAttr indicesAttr;
1433 auto const indicesType =
1434 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1435 if (!indicesType || !indicesType.hasRank()) {
1441 op->
emitOpError(
"indices values contain duplicates");
// Runs every per-op ERROR_IF check. Evaluation stops at the first failing
// check because of the `||` chain.
1448LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1449 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1450 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1451 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfReshape(op)) ||
1452 failed(checkErrorIfSlice(op)) ||
failed(checkErrorIfCondIf(op)) ||
1453 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
// A func.func signature must not use !tosa.shape for any argument or result
// type. Arguments are checked first.
1458LogicalResult TosaValidation::applyFunctionSignatureCheck(func::FuncOp op) {
1459 const auto isShapeType = [](Type type) {
return isa<tosa::shapeType>(type); };
1460 if (llvm::any_of(op.getArgumentTypes(), isShapeType))
1461 return op.emitOpError()
1462 <<
"Function argument types must be a tensor type to be TOSA "
1463 "compliant, got !tosa.shape type";
1464 if (llvm::any_of(op.getResultTypes(), isShapeType))
1465 return op.emitOpError()
1466 <<
"Function return types must be a tensor type to be TOSA "
1467 "compliant, got !tosa.shape type";
// Element-type legality:
//   * floats are legal only for the listed f32/f16/bf16/fp8/fp6/fp4/e8m0
//     types;
//   * signless integers are legal by width (the width list is on lines not
//     shown);
//   * unsigned integers are considered only when allowUnsigned is true;
//   * !tosa.shape and mxint8 are handled by the trailing branches.
1471bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1472 if (isa<FloatType>(type)) {
1473 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1474 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1475 Float6E3M2FNType, Float8E8M0FNUType>(type);
1476 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1477 if (intTy.isSignless()) {
1478 switch (intTy.getWidth()) {
1488 }
else if (allowUnsigned && intTy.isUnsigned()) {
1489 switch (intTy.getWidth()) {
1496 }
else if (isa<tosa::shapeType>(type))
1498 else if (isa<tosa::mxint8Type>(type))
// Pass entry point.
//   1. Resolve the target environment; the pass fails if that is not possible.
//   2. Reject top-level func.func signatures that use !tosa.shape.
//   3. Walk every op:
//      - check operand/result element types (unsigned is allowed only for
//        tosa.rescale when not in strict mode);
//      - run the strict-mode profile/extension checks and the invalid
//        datatype-combination check (their calls are on lines not shown);
//      - run the constant-operand, level, attribute and variable checks, and
//        the ERROR_IF checks in strict mode only.
// The later checks call signalPassFailure() without returning, so every check
// still runs for the op and several diagnostics may be emitted.
1503void TosaValidation::runOnOperation() {
1504 ModuleOp modOp = getOperation();
1505 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1510 const auto maybeTargetEnv =
1512 if (
failed(maybeTargetEnv))
1513 return signalPassFailure();
1514 targetEnv = *maybeTargetEnv;
1516 const auto functions = modOp.getOps<func::FuncOp>();
1517 if (llvm::any_of(functions, [&](func::FuncOp func) {
1518 return failed(applyFunctionSignatureCheck(func));
1520 return signalPassFailure();
1522 modOp.walk([&](Operation *op) {
1531 const bool allowUnsigned =
1532 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1535 if (!isValidElementType(elementTy, allowUnsigned)) {
1536 op->
emitOpError() <<
"is not profile-aligned: element type "
1537 << elementTy <<
" is not legal";
1538 return signalPassFailure();
1543 if (!isValidElementType(elementTy, allowUnsigned)) {
1544 op->
emitOpError() <<
"is not profile-aligned: element type "
1545 << elementTy <<
" is not legal";
1546 return signalPassFailure();
1550 if (strictOpSpecAlignment &&
1552 return signalPassFailure();
1554 if (strictOpSpecAlignment &&
1556 return signalPassFailure();
1558 if (!allowInvalidOpDatatypeCombinations &&
1560 return signalPassFailure();
1564 if (
failed(applyConstantOperandCheck(op)))
1565 signalPassFailure();
1568 if (
failed(applyLevelCheck(op)))
1569 signalPassFailure();
1572 if (
failed(applyAttributeCheck(op)))
1573 signalPassFailure();
1576 if (
failed(applyVariableCheck(op)))
1577 signalPassFailure();
1580 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1581 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
TosaSpecificationVersion getSpecVersion() const
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.