30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/Support/FormatVariadic.h"
36#define GEN_PASS_DEF_TOSAVALIDATION
37#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
48 for (
const auto index : operandIndices) {
51 return op->
emitOpError(
"expected compile time resolvable constant, but "
52 "got variable value for operand #")
59static LogicalResult checkConstantOperandMul(
Operation *op,
61 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
63 return checkConstantOperands(op, {2});
68static LogicalResult checkConstantOperandTable(
Operation *op,
70 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
72 return checkConstantOperands(op, {1});
77static LogicalResult checkConstantOperandPad(
Operation *op,
79 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
81 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
84 return checkConstantOperands(op, {2});
89static LogicalResult checkConstantOperandRescale(
Operation *op,
91 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
93 return checkConstantOperands(op, {1, 2, 3, 4});
99static LogicalResult checkConstantOperandConvOps(
Operation *op,
101 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
103 return checkConstantOperands(op, {3, 4});
108static LogicalResult checkConstantOperandMatMul(
Operation *op,
110 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
112 return checkConstantOperands(op, {2, 3});
119 if (!env.
allows(Extension::dynamic) &&
120 isa<tosa::RowGatherBlockScaledOp>(op)) {
121 auto rowGatherOp = cast<tosa::RowGatherBlockScaledOp>(op);
122 const unsigned rowCountIndex = rowGatherOp.getValues().size() + 1;
123 return checkConstantOperands(op, {rowCountIndex});
128static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
130 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
132 return checkConstantOperands(op, {1, 2});
139 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dAdaptiveOp>(op)) {
143 return checkConstantOperands(op, {1, 2});
148static LogicalResult checkConstantOperandNegate(
Operation *op,
150 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
152 return checkConstantOperands(op, {1, 2});
157static LogicalResult checkConstantOperandSilceShape(
Operation *op,
159 if (!env.
allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
161 return checkConstantOperands(op, {1, 2});
172 explicit TosaValidation() { populateConstantOperandChecks(); }
174 explicit TosaValidation(
const TosaValidationOptions &
options)
176 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
177 this->allowInvalidOpDatatypeCombinations =
178 options.allowInvalidOpDatatypeCombinations;
180 void runOnOperation() final;
182 LogicalResult applyConstantOperandCheck(Operation *op) {
183 for (
auto &checker : constCheckers) {
184 if (
failed(checker(op, targetEnv)))
190 LogicalResult applyFunctionSignatureCheck(func::FuncOp op);
191 LogicalResult applyLevelCheck(Operation *op);
192 LogicalResult applyAttributeCheck(Operation *op);
195 LogicalResult applyVariableCheck(Operation *op);
198 LogicalResult applyErrorIfCheck(Operation *op);
201 void populateConstantOperandChecks() {
202 constCheckers.emplace_back(checkConstantOperandMul);
203 constCheckers.emplace_back(checkConstantOperandTable);
204 constCheckers.emplace_back(checkConstantOperandPad);
205 constCheckers.emplace_back(checkConstantOperandRescale);
206 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
207 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
208 constCheckers.emplace_back(
209 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
210 constCheckers.emplace_back(
211 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
212 constCheckers.emplace_back(checkConstantOperandMatMul);
213 constCheckers.emplace_back(checkConstantOperandRowGatherBlockScaled);
214 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
215 constCheckers.emplace_back(checkConstantOperandAvgPool2dAdaptive);
216 constCheckers.emplace_back(checkConstantOperandNegate);
217 constCheckers.emplace_back(checkConstantOperandSilceShape);
220 LogicalResult levelCheck(Operation *op,
const int32_t calculatedValue,
221 const int32_t maxLevel,
const StringRef inputName,
222 const StringRef levelName) {
223 if (calculatedValue > maxLevel)
225 <<
"failed level check: " << inputName <<
" <= " << levelName
226 <<
" (" << maxLevel <<
"), got " << calculatedValue;
230 LogicalResult levelCheckKernel(Operation *op, int32_t v,
231 const StringRef inputName) {
232 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
236 LogicalResult levelCheckStride(Operation *op, int32_t v,
237 const StringRef inputName) {
238 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
242 LogicalResult levelCheckScale(Operation *op, int32_t v,
243 const StringRef inputName) {
244 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
248 LogicalResult levelCheckListSize(Operation *op, int32_t v,
249 const StringRef inputName) {
250 const std::string inputDesc =
251 llvm::formatv(
"length(tensor_list_shape({0}))", inputName);
252 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
253 inputDesc,
"MAX_TENSOR_LIST_SIZE");
257 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
258 const StringRef operandOrResult,
259 int32_t highest_rank) {
260 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
262 return op->
emitOpError() <<
"failed level check: unranked tensor";
263 if (type.getRank() > highest_rank)
264 return op->
emitOpError() <<
"failed level check: " << operandOrResult
265 <<
" rank(shape) <= MAX_RANK";
271 LogicalResult levelCheckRank(Operation *op,
const Value &v,
272 const StringRef operandOrResult,
273 int32_t highest_rank) {
274 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
278 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
279 const StringRef operandOrResult);
282 LogicalResult levelCheckSize(Operation *op,
const Value &v,
283 const StringRef operandOrResult) {
284 return levelCheckSize(op, v.
getType(), operandOrResult);
288 LogicalResult levelCheckShapeLength(Operation *op,
const Type typeToCheck,
289 const StringRef operandOrResult) {
290 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
291 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
293 <<
"failed shape type level check: " << typeToCheck
294 <<
" exceeds MAX_SHAPE_LEN";
300 template <
typename T>
301 LogicalResult levelCheckSizes(T tosaOp) {
302 auto op = tosaOp.getOperation();
304 if (
failed(levelCheckSize(op, v,
"operand")))
309 if (
failed(levelCheckSize(op, v,
"result")))
316 template <
typename T>
317 LogicalResult levelCheckRanks(T tosaOp) {
318 auto op = tosaOp.getOperation();
319 const TosaLevel tosaLevel = targetEnv.getLevel();
333 template <
typename T>
334 LogicalResult levelCheckShapeLengths(T tosaOp) {
335 for (
const auto &v : tosaOp->getOperands()) {
336 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"operand")))
339 for (
const auto &v : tosaOp->getResults()) {
340 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"result")))
348 LogicalResult levelCheckRanksAndSizes(Operation *op);
351 template <
typename T>
352 LogicalResult levelCheckPool(Operation *op) {
353 if (
auto poolOp = dyn_cast<T>(op)) {
354 for (
auto k : poolOp.getKernel()) {
355 if (
failed(levelCheckKernel(op, k,
"kernel"))) {
359 for (
auto s : poolOp.getStride()) {
360 if (
failed(levelCheckStride(op, s,
"stride"))) {
364 for (
auto p : poolOp.getPad()) {
365 if (
failed(levelCheckKernel(op, p,
"pad"))) {
373 template <
typename T>
374 static constexpr bool IsSupportedAdaptivePoolOp =
375 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
376 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
378 template <
typename T,
typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
380 LogicalResult levelCheckAdaptivePool(Operation *op) {
381 auto poolOp = dyn_cast<T>(op);
385 SmallVector<int64_t> kernelValues;
388 for (
const auto k : kernelValues)
389 if (
failed(levelCheckKernel(op, k,
"kernel")))
393 SmallVector<int64_t> strideValues;
396 for (
const auto s : strideValues)
397 if (
failed(levelCheckStride(op, s,
"stride")))
401 SmallVector<int64_t> padValues;
403 for (
const auto p : padValues)
404 if (
failed(levelCheckKernel(op, p,
"pad")))
412 template <
typename T>
413 LogicalResult levelCheckConv(Operation *op) {
414 if (
auto convOp = dyn_cast<T>(op)) {
416 for (
auto k : convOp.getDilation()) {
417 if (
failed(levelCheckKernel(op, k,
"dilation"))) {
421 for (
auto p : convOp.getPad()) {
422 if (
failed(levelCheckKernel(op, p,
"pad"))) {
426 for (
auto s : convOp.getStride()) {
427 if (
failed(levelCheckStride(op, s,
"stride"))) {
431 auto dilation = convOp.getDilation();
432 if (ShapedType weightType =
434 auto shape = weightType.getShape();
435 if (isa<tosa::Conv2DOp>(op)) {
436 assert(shape.size() == 4);
437 assert(dilation.size() == 2);
438 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
439 "dilation_y * KH")) ||
440 failed(levelCheckKernel(op, dilation[1] * shape[2],
443 }
else if (isa<tosa::Conv3DOp>(op)) {
444 assert(shape.size() == 5);
445 assert(dilation.size() == 3);
446 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
447 "dilation_d * KD")) ||
448 failed(levelCheckKernel(op, dilation[1] * shape[2],
449 "dilation_y * KH")) ||
450 failed(levelCheckKernel(op, dilation[2] * shape[3],
453 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
454 assert(shape.size() == 4);
455 assert(dilation.size() == 2);
456 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
457 "dilation_y * KH")) ||
458 failed(levelCheckKernel(op, dilation[1] * shape[1],
467 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
468 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
472 SmallVector<int64_t> padValues;
474 for (
const auto p : padValues)
475 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL")))
479 SmallVector<int64_t> strideValues;
482 for (
const auto s : strideValues)
483 if (
failed(levelCheckKernel(op, s,
"stride <= MAX_KERNEL")))
487 SmallVector<int64_t> dilationValues;
490 int64_t KH = ShapedType::kDynamic;
491 int64_t KW = ShapedType::kDynamic;
492 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
493 KH = weightDataShape.getDimSize(1);
494 KW = weightDataShape.getDimSize(2);
495 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
496 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
497 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
499 if (!ShapedType::isDynamic(KH) &&
500 failed(levelCheckKernel(op, dilationValues[0] * KH,
501 "dilation_y * KH <= MAX_KERNEL)")))
504 if (!ShapedType::isDynamic(KW) &&
505 failed(levelCheckKernel(op, dilationValues[1] * KW,
506 "dilation_x * KW <= MAX_KERNEL)")))
514 template <
typename T>
515 LogicalResult levelCheckFFT(Operation *op) {
518 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
519 auto shape = type.getShape();
520 assert(shape.size() == 3);
521 if (
failed(levelCheckKernel(op, shape[1],
"H")) ||
522 failed(levelCheckKernel(op, shape[2],
"W"))) {
532 LogicalResult levelCheckTransposeConv2d(Operation *op) {
533 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
534 if (ShapedType filterType =
535 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
536 auto shape = filterType.getShape();
537 assert(shape.size() == 4);
539 if (
failed(levelCheckKernel(op, shape[1],
"KH")) ||
540 failed(levelCheckKernel(op, shape[2],
"KW"))) {
544 for (
auto p : transpose.getOutPad()) {
545 if (
failed(levelCheckKernel(op, p,
"pad"))) {
549 for (
auto s : transpose.getStride()) {
550 if (
failed(levelCheckStride(op, s,
"stride"))) {
559 LogicalResult levelCheckResize(Operation *op) {
560 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
561 SmallVector<int64_t> scale;
566 const int64_t scaleYN = scale[0];
567 const int64_t scaleYD = scale[1];
568 const int64_t scaleXN = scale[2];
569 const int64_t scaleXD = scale[3];
571 levelCheckScale(op, scaleYN / scaleYD,
"scale_y_n/scale_y_d")) ||
573 levelCheckScale(op, scaleXN / scaleXD,
"scale_x_n/scale_x_d"))) {
584 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
585 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
593 getMaxNestedDepth(op, depth);
596 LogicalResult levelCheckMaxNesting(Operation *op) {
597 int32_t maxNestedDepth = 0;
598 getMaxNestedDepth(op, maxNestedDepth);
600 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
601 if (maxNestedDepth >= maxNestingLevel)
603 <<
"failed level check: tosa_nesting_depth < MAX_NESTING" <<
" ("
604 << maxNestingLevel <<
"), got " << maxNestedDepth;
608 LogicalResult levelCheckListSize(Operation *op) {
609 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
610 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
612 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
613 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
615 failed(levelCheckListSize(op, custom.getOutputList().size(),
620 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
622 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
623 failed(levelCheckListSize(op, condIf.getOutputList().size(),
628 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
629 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
630 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
634 if (
auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
635 return levelCheckListSize(op, concat_shape.getInput().size(),
"input");
639 LogicalResult attributeCheckRescale(Operation *op) {
640 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
641 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
642 !targetEnv.allows(Extension::doubleround)) {
644 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
645 <<
"requires extension [doubleround]";
648 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
649 !targetEnv.allows(Extension::inexactround)) {
651 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
652 <<
"requires extension [inexactround]";
659 LogicalResult CheckVariable(Operation *op);
660 LogicalResult CheckVariableReadOrWrite(Operation *op);
661 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
664 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
667 TosaProfileCompliance profileComp;
668 tosa::TargetEnv targetEnv;
672LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
673 auto *op = tosaOp.getOperation();
674 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
679 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
687LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
688 auto *op = tosaOp.getOperation();
691 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
699LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
700 auto *op = tosaOp.getOperation();
702 if (
failed(levelCheckRank(op, variableType,
"variable type",
710LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
711 auto *op = tosaOp.getOperation();
713 if (
failed(levelCheckSize(op, variableType,
"variable type")))
719LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
720#define CHECK_RANKS_AND_SIZES(tosaOp) \
721 if (isa<tosa::tosaOp##Op>(op)) { \
722 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
724 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
728#define CHECK_SIZES(tosaOp) \
729 if (isa<tosa::tosaOp##Op>(op)) { \
730 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
734#define CHECK_SHAPE_LEN(tosaOp) \
735 if (isa<tosa::tosaOp##Op>(op)) { \
736 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
866#undef CHECK_RANKS_AND_SIZES
868#undef CHECK_SHAPE_LEN
873LogicalResult TosaValidation::levelCheckSize(Operation *op,
874 const Type &typeToCheck,
875 const StringRef operandOrResult) {
876 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
878 return op->
emitOpError() <<
"failed level check: unranked tensor";
879 auto shape = type.getShape();
880 for (
auto dim : shape) {
881 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
882 const TosaSpecificationVersion targetVersion = targetEnv.
getSpecVersion();
883 const TosaSpecificationVersion minRequiredVersion(1, 1,
true);
893 return op->
emitOpError() <<
"failed level check: " << operandOrResult
894 <<
" shape dimension cannot be dynamic when"
895 <<
" targeting TOSA specification version 1.0"
900 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
901 int64_t size = element_bytes * type.getNumElements();
908 const int64_t max_size =
912 <<
"failed level check: " << operandOrResult
913 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
918LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
925 if (
failed(levelCheckRanksAndSizes(op)))
928 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
929 failed(levelCheckAdaptivePool<tosa::AvgPool2dAdaptiveOp>(op)) ||
930 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
931 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
932 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
933 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
934 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
935 failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
936 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
937 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op)) ||
938 failed(levelCheckConv2DBlockScaled(op))) {
943 if (
failed(levelCheckListSize(op))) {
947 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
948 if (
failed(levelCheckMaxNesting(op))) {
956LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
957 if (
failed(attributeCheckRescale(op)))
962inline bool CompatibleTypes(
const mlir::Type &type,
963 const mlir::Type &declaredType) {
965 return type == declaredType;
968LogicalResult TosaValidation::CheckVariable(Operation *op) {
969 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
970 mlir::StringAttr nameAttr = variableOp.getNameAttr();
972 if (variablesMap.count(nameAttr))
973 return op->
emitOpError() <<
"name has already been declared";
975 auto elementType = variableOp.getType();
976 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
977 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
978 RankedTensorType variableType =
979 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
981 variablesMap[nameAttr] = variableType;
987LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
988 if (isa<mlir::tosa::VariableReadOp>(op) ||
989 isa<mlir::tosa::VariableWriteOp>(op)) {
990 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
991 if (!variablesMap.count(nameAttr))
992 return op->
emitOpError() <<
"name has not been declared";
994 auto varType = variablesMap[nameAttr];
997 auto type = v.getType();
998 if (!CompatibleTypes(type, varType))
999 return op->
emitOpError() <<
"operand type does not equal variable type";
1003 auto type = v.getType();
1004 if (!CompatibleTypes(type, varType))
1005 return op->
emitOpError() <<
"result type does not equal variable type";
1012LogicalResult TosaValidation::applyVariableCheck(
Operation *op) {
1013 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
1018LogicalResult checkErrorIfResize(
Operation *op) {
1019 auto resize = dyn_cast<tosa::ResizeOp>(op);
1024 const Value output = resize.getOutput();
1025 const RankedTensorType inputType =
1026 llvm::dyn_cast<RankedTensorType>(input.
getType());
1027 const RankedTensorType outputType =
1028 llvm::dyn_cast<RankedTensorType>(output.
getType());
1030 if (!inputType || !outputType)
1031 return op->
emitOpError(
"expect ranked input/output tensor");
1035 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1037 outputType.getDimSize(1), outputType.getDimSize(2),
1038 inputType.getDimSize(1), inputType.getDimSize(2)};
1039 const int64_t *maxDim = llvm::max_element(sizes);
1040 if (maxDim != sizes.end() && *maxDim >= 16384)
1042 "expect input/output height/width dims to be < 16384, ")
1043 <<
"got [OH, OW, IH, IW] = " << sizes;
1050 const int64_t scaleYN = scale[0];
1052 const int64_t scaleXN = scale[2];
1053 const int64_t scaleXD = scale[3];
1056 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
1058 "expect all scale numerator values to be <= (1 << 11), "
1060 << scaleYN <<
", scale_x_n=" << scaleXN;
1062 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
1063 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
1064 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
1072 const int64_t offsetY = offset[0];
1073 const int64_t offsetX = offset[1];
1076 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1078 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1079 << offsetY <<
"/" << scaleYN;
1080 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1082 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1083 << offsetX <<
"/" << scaleXN;
1085 const int64_t borderY = border[0];
1086 const int64_t borderX = border[1];
1087 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1089 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1090 << borderY <<
"/" << scaleYN;
1091 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1093 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1094 << borderX <<
"/" << scaleXN;
1107 const int64_t rhs) -> std::optional<int64_t> {
1109 return std::nullopt;
1113 const int64_t oh = outputType.getDimSize(1);
1114 const int64_t ow = outputType.getDimSize(2);
1115 const int64_t ih = inputType.getDimSize(1);
1116 const int64_t iw = inputType.getDimSize(2);
1118 if (ih != ShapedType::kDynamic) {
1119 const std::optional<int64_t> calculatedOutHeightMinusOne =
1120 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1121 if (!calculatedOutHeightMinusOne.has_value())
1123 "expected (input_height - 1) * scale_y_n - offset_y + "
1125 <<
"to be wholly divisible by scale_y_d, got ((" << ih
1126 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
1127 <<
") / " << scaleYD;
1128 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1129 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1131 "calculated output height did not match expected: ")
1132 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1135 if (iw != ShapedType::kDynamic) {
1136 const std::optional<int64_t> calculatedOutWidthMinusOne =
1137 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1138 if (!calculatedOutWidthMinusOne.has_value())
1140 "expected (input_width - 1) * scale_x_n - offset_x + "
1142 <<
"to be wholly divisible by scale_x_d, got ((" << iw
1143 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
1144 <<
") / " << scaleXD;
1145 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1146 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1147 return op->
emitOpError(
"calculated output width did not match expected: ")
1148 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
1154LogicalResult checkErrorIfMul(Operation *op) {
1155 auto mul = dyn_cast<tosa::MulOp>(op);
1161 ElementsAttr shift_elem;
1164 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1166 if (inputElemType.isInteger(32)) {
1168 if (shift < 0 || shift > 63)
1170 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
1175 <<
"requires shift = 0 for all input data types that "
1176 "are not int32_t, but got: "
1183LogicalResult checkErrorIfTable(Operation *op) {
1184 auto table = dyn_cast<tosa::TableOp>(op);
1190 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1192 const ShapeAdaptor tableShape(table.getTable().getType());
1193 if (tableShape.hasStaticShape()) {
1194 const auto numElements = tableShape.getNumElements();
1195 if (numElements != tableSize)
1196 return op->
emitOpError() <<
"requires table size of " << tableSize
1197 <<
", got " << numElements;
1203LogicalResult checkErrorIfRescale(Operation *op) {
1204 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1208 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1209 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1210 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1211 !outputType.getElementType().isInteger())
1214 auto inElemType = inputType.getElementType();
1215 auto outElemType = outputType.getElementType();
1216 auto inWidth = inElemType.getIntOrFloatBitWidth();
1217 auto outWidth = outElemType.getIntOrFloatBitWidth();
1219 bool inputUnsigned = rescale.getInputUnsigned();
1220 bool outputUnsigned = rescale.getOutputUnsigned();
1222 bool scale32 = rescale.getScale32();
1223 auto roundingMode = rescale.getRoundingMode();
1226 if (scale32 && inWidth == 48)
1227 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1230 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1232 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1235 if (inputUnsigned && outputUnsigned)
1236 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1239 if (outWidth == 32 && inputUnsigned)
1241 <<
"i32 output type is not allowed with unsigned input.";
1244 if (inWidth == 32 && outputUnsigned)
1246 <<
"i32 input type is not allowed with unsigned output.";
1249 if (inWidth == 48 && outputUnsigned)
1251 <<
"i48 input type is not allowed with unsigned output.";
1254 if (inWidth == 48 && inputUnsigned)
1255 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1258 if (inWidth == 32 && inputUnsigned)
1259 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1262 if (outWidth == 32 && outputUnsigned)
1263 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1268LogicalResult checkErrorIfPad(Operation *op) {
1269 auto pad = dyn_cast<tosa::PadOp>(op);
1273 DenseIntElementsAttr paddingAttr;
1278 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1279 if (val.getSExtValue() < 0)
1280 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1281 << val.getSExtValue();
1287LogicalResult checkErrorIfReshape(Operation *op) {
1288 auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
1292 SmallVector<int64_t> shapeValues;
1298 return op->
emitOpError(
"shape input contains inferable dimension (")
1301 "which does not conform to the TOSA specification";
1306LogicalResult checkErrorIfSlice(Operation *op) {
1307 auto sliceOp = dyn_cast<tosa::SliceOp>(op);
1311 SmallVector<int64_t> startValues;
1312 SmallVector<int64_t> sizeValues;
1314 sliceOp.getStart().getDefiningOp(), startValues);
1315 const bool hasSizeValues =
1319 return op->
emitOpError(
"start input contains inferable dimension (")
1321 <<
") which does not conform to the TOSA specification";
1323 return op->
emitOpError(
"size input contains inferable dimension (")
1326 "does not conform to the TOSA specification";
1331static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1332 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1333 Region *operandRegion = operand.getParentRegion();
1334 return operandRegion && region->isAncestor(operandRegion);
1338static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) {
1339 bool noLiveInValue =
true;
1340 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](Operation *op) {
1341 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1342 noLiveInValue =
false;
1347 return noLiveInValue ?
success() : failure();
1350LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
1351 StringRef regionName) {
1352 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1355 <<
"is not conformant to the TOSA specification. It requires the '"
1356 << regionName <<
"' region is isolated from above.\n";
1359LogicalResult checkErrorIfCondIf(Operation *op) {
1360 auto ifOp = dyn_cast<tosa::IfOp>(op);
1393 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1394 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1399LogicalResult checkErrorIfWhileLoop(Operation *op) {
1400 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1404 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1405 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1410LogicalResult checkErrorIfScatter(Operation *op) {
1411 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1416 DenseIntElementsAttr indicesAttr;
1420 auto const indicesType =
1421 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1422 if (!indicesType || !indicesType.hasRank()) {
1428 op->
emitOpError(
"indices values contain duplicates");
1435LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1436 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1437 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1438 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfReshape(op)) ||
1439 failed(checkErrorIfSlice(op)) ||
failed(checkErrorIfCondIf(op)) ||
1440 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1445LogicalResult TosaValidation::applyFunctionSignatureCheck(func::FuncOp op) {
1446 const auto isShapeType = [](Type type) {
return isa<tosa::shapeType>(type); };
1447 if (llvm::any_of(op.getArgumentTypes(), isShapeType))
1448 return op.emitOpError()
1449 <<
"Function argument types must be a tensor type to be TOSA "
1450 "compliant, got !tosa.shape type";
1451 if (llvm::any_of(op.getResultTypes(), isShapeType))
1452 return op.emitOpError()
1453 <<
"Function return types must be a tensor type to be TOSA "
1454 "compliant, got !tosa.shape type";
1458bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1459 if (isa<FloatType>(type)) {
1460 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1461 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1462 Float6E3M2FNType, Float8E8M0FNUType>(type);
1463 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1464 if (intTy.isSignless()) {
1465 switch (intTy.getWidth()) {
1475 }
else if (allowUnsigned && intTy.isUnsigned()) {
1476 switch (intTy.getWidth()) {
1483 }
else if (isa<tosa::shapeType>(type))
1485 else if (isa<tosa::mxint8Type>(type))
1490void TosaValidation::runOnOperation() {
1491 ModuleOp modOp = getOperation();
1492 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1497 const auto maybeTargetEnv =
1499 if (
failed(maybeTargetEnv))
1500 return signalPassFailure();
1501 targetEnv = *maybeTargetEnv;
1503 const auto functions = modOp.getOps<func::FuncOp>();
1504 if (llvm::any_of(functions, [&](func::FuncOp func) {
1505 return failed(applyFunctionSignatureCheck(func));
1507 return signalPassFailure();
1509 modOp.walk([&](Operation *op) {
1518 const bool allowUnsigned =
1519 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1522 if (!isValidElementType(elementTy, allowUnsigned)) {
1523 op->
emitOpError() <<
"is not profile-aligned: element type "
1524 << elementTy <<
" is not legal";
1525 return signalPassFailure();
1530 if (!isValidElementType(elementTy, allowUnsigned)) {
1531 op->
emitOpError() <<
"is not profile-aligned: element type "
1532 << elementTy <<
" is not legal";
1533 return signalPassFailure();
1537 if (strictOpSpecAlignment &&
1539 return signalPassFailure();
1541 if (strictOpSpecAlignment &&
1543 return signalPassFailure();
1545 if (!allowInvalidOpDatatypeCombinations &&
1547 return signalPassFailure();
1551 if (
failed(applyConstantOperandCheck(op)))
1552 signalPassFailure();
1555 if (
failed(applyLevelCheck(op)))
1556 signalPassFailure();
1559 if (
failed(applyAttributeCheck(op)))
1560 signalPassFailure();
1563 if (
failed(applyVariableCheck(op)))
1564 signalPassFailure();
1567 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1568 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator range over the underlying Values.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capabilities enabled in the target implementation, such as profiles, extensions, and levels.
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
TosaSpecificationVersion getSpecVersion() const
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dimensions.
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op, or falls back to a default target environment if none is attached.
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.