29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/StringExtras.h"
31#include "llvm/Support/FormatVariadic.h"
35#define GEN_PASS_DEF_TOSAVALIDATION
36#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
47 for (
const auto index : operandIndices) {
50 return op->
emitOpError(
"expected compile time resolvable constant, but "
51 "got variable value for operand #")
58static LogicalResult checkConstantOperandMul(
Operation *op,
60 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
62 return checkConstantOperands(op, {2});
67static LogicalResult checkConstantOperandTable(
Operation *op,
69 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
71 return checkConstantOperands(op, {1});
76static LogicalResult checkConstantOperandPad(
Operation *op,
78 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
80 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
83 return checkConstantOperands(op, {2});
88static LogicalResult checkConstantOperandRescale(
Operation *op,
90 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
92 return checkConstantOperands(op, {1, 2, 3, 4});
98static LogicalResult checkConstantOperandConvOps(
Operation *op,
100 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
102 return checkConstantOperands(op, {3, 4});
107static LogicalResult checkConstantOperandMatMul(
Operation *op,
109 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
111 return checkConstantOperands(op, {2, 3});
116static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
118 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
120 return checkConstantOperands(op, {1, 2});
125static LogicalResult checkConstantOperandNegate(
Operation *op,
127 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
129 return checkConstantOperands(op, {1, 2});
134static LogicalResult checkConstantOperandSilceShape(
Operation *op,
136 if (!env.
allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
138 return checkConstantOperands(op, {1, 2});
147struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
149 explicit TosaValidation() { populateConstantOperandChecks(); }
151 explicit TosaValidation(
const TosaValidationOptions &
options)
153 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
154 this->allowInvalidOpDatatypeCombinations =
155 options.allowInvalidOpDatatypeCombinations;
157 void runOnOperation() final;
159 LogicalResult applyConstantOperandCheck(Operation *op) {
160 for (
auto &checker : constCheckers) {
161 if (
failed(checker(op, targetEnv)))
167 LogicalResult applyLevelCheck(Operation *op);
168 LogicalResult applyAttributeCheck(Operation *op);
171 LogicalResult applyVariableCheck(Operation *op);
174 LogicalResult applyErrorIfCheck(Operation *op);
177 void populateConstantOperandChecks() {
178 constCheckers.emplace_back(checkConstantOperandMul);
179 constCheckers.emplace_back(checkConstantOperandTable);
180 constCheckers.emplace_back(checkConstantOperandPad);
181 constCheckers.emplace_back(checkConstantOperandRescale);
182 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
183 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
184 constCheckers.emplace_back(
185 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
186 constCheckers.emplace_back(
187 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
188 constCheckers.emplace_back(checkConstantOperandMatMul);
189 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
190 constCheckers.emplace_back(checkConstantOperandNegate);
191 constCheckers.emplace_back(checkConstantOperandSilceShape);
194 LogicalResult levelCheck(Operation *op,
const int32_t calculatedValue,
195 const int32_t maxLevel,
const StringRef inputName,
196 const StringRef levelName) {
197 if (calculatedValue > maxLevel)
199 <<
"failed level check: " << inputName <<
" <= " << levelName
200 <<
" (" << maxLevel <<
"), got " << calculatedValue;
204 LogicalResult levelCheckKernel(Operation *op, int32_t v,
205 const StringRef inputName) {
206 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
210 LogicalResult levelCheckStride(Operation *op, int32_t v,
211 const StringRef inputName) {
212 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
216 LogicalResult levelCheckScale(Operation *op, int32_t v,
217 const StringRef inputName) {
218 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
222 LogicalResult levelCheckListSize(Operation *op, int32_t v,
223 const StringRef inputName) {
224 const std::string inputDesc =
225 llvm::formatv(
"length(tensor_list_shape({0}))", inputName);
226 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
227 inputDesc,
"MAX_TENSOR_LIST_SIZE");
231 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
232 const StringRef operandOrResult,
233 int32_t highest_rank) {
234 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
236 return op->
emitOpError() <<
"failed level check: unranked tensor";
237 if (type.getRank() > highest_rank)
238 return op->
emitOpError() <<
"failed level check: " << operandOrResult
239 <<
" rank(shape) <= MAX_RANK";
245 LogicalResult levelCheckRank(Operation *op,
const Value &v,
246 const StringRef operandOrResult,
247 int32_t highest_rank) {
248 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
252 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
253 const StringRef operandOrResult);
256 LogicalResult levelCheckSize(Operation *op,
const Value &v,
257 const StringRef operandOrResult) {
258 return levelCheckSize(op, v.
getType(), operandOrResult);
262 LogicalResult levelCheckShapeLength(Operation *op,
const Type typeToCheck,
263 const StringRef operandOrResult) {
264 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
265 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
267 <<
"failed shape type level check: " << typeToCheck
268 <<
" exceeds MAX_SHAPE_LEN";
274 template <
typename T>
275 LogicalResult levelCheckSizes(T tosaOp) {
276 auto op = tosaOp.getOperation();
278 if (
failed(levelCheckSize(op, v,
"operand")))
283 if (
failed(levelCheckSize(op, v,
"result")))
290 template <
typename T>
291 LogicalResult levelCheckRanks(T tosaOp) {
292 auto op = tosaOp.getOperation();
293 const TosaLevel tosaLevel = targetEnv.getLevel();
307 template <
typename T>
308 LogicalResult levelCheckShapeLengths(T tosaOp) {
309 for (
const auto &v : tosaOp->getOperands()) {
310 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"operand")))
313 for (
const auto &v : tosaOp->getResults()) {
314 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"result")))
322 LogicalResult levelCheckRanksAndSizes(Operation *op);
325 template <
typename T>
326 LogicalResult levelCheckPool(Operation *op) {
327 if (
auto poolOp = dyn_cast<T>(op)) {
328 for (
auto k : poolOp.getKernel()) {
329 if (
failed(levelCheckKernel(op, k,
"kernel"))) {
333 for (
auto s : poolOp.getStride()) {
334 if (
failed(levelCheckStride(op, s,
"stride"))) {
338 for (
auto p : poolOp.getPad()) {
339 if (
failed(levelCheckKernel(op, p,
"pad"))) {
348 template <
typename T>
349 LogicalResult levelCheckConv(Operation *op) {
350 if (
auto convOp = dyn_cast<T>(op)) {
352 for (
auto k : convOp.getDilation()) {
353 if (
failed(levelCheckKernel(op, k,
"dilation"))) {
357 for (
auto p : convOp.getPad()) {
358 if (
failed(levelCheckKernel(op, p,
"pad"))) {
362 for (
auto s : convOp.getStride()) {
363 if (
failed(levelCheckStride(op, s,
"stride"))) {
367 auto dilation = convOp.getDilation();
368 if (ShapedType weightType =
370 auto shape = weightType.getShape();
371 if (isa<tosa::Conv2DOp>(op)) {
372 assert(shape.size() == 4);
373 assert(dilation.size() == 2);
374 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
375 "dilation_y * KH")) ||
376 failed(levelCheckKernel(op, dilation[1] * shape[2],
379 }
else if (isa<tosa::Conv3DOp>(op)) {
380 assert(shape.size() == 5);
381 assert(dilation.size() == 3);
382 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
383 "dilation_d * KD")) ||
384 failed(levelCheckKernel(op, dilation[1] * shape[2],
385 "dilation_y * KH")) ||
386 failed(levelCheckKernel(op, dilation[2] * shape[3],
389 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
390 assert(shape.size() == 4);
391 assert(dilation.size() == 2);
392 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
393 "dilation_y * KH")) ||
394 failed(levelCheckKernel(op, dilation[1] * shape[1],
403 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
404 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
408 SmallVector<int64_t> padValues;
410 for (
const auto p : padValues)
411 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL")))
415 SmallVector<int64_t> strideValues;
418 for (
const auto s : strideValues)
419 if (
failed(levelCheckKernel(op, s,
"stride <= MAX_KERNEL")))
423 SmallVector<int64_t> dilationValues;
426 int64_t KH = ShapedType::kDynamic;
427 int64_t KW = ShapedType::kDynamic;
428 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
429 KH = weightDataShape.getDimSize(1);
430 KW = weightDataShape.getDimSize(2);
431 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
432 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
433 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
435 if (!ShapedType::isDynamic(KH) &&
436 failed(levelCheckKernel(op, dilationValues[0] * KH,
437 "dilation_y * KH <= MAX_KERNEL)")))
440 if (!ShapedType::isDynamic(KW) &&
441 failed(levelCheckKernel(op, dilationValues[1] * KW,
442 "dilation_x * KW <= MAX_KERNEL)")))
450 template <
typename T>
451 LogicalResult levelCheckFFT(Operation *op) {
454 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
455 auto shape = type.getShape();
456 assert(shape.size() == 3);
457 if (
failed(levelCheckKernel(op, shape[1],
"H")) ||
458 failed(levelCheckKernel(op, shape[2],
"W"))) {
468 LogicalResult levelCheckTransposeConv2d(Operation *op) {
469 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
470 if (ShapedType filterType =
471 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
472 auto shape = filterType.getShape();
473 assert(shape.size() == 4);
475 if (
failed(levelCheckKernel(op, shape[1],
"KH")) ||
476 failed(levelCheckKernel(op, shape[2],
"KW"))) {
480 for (
auto p : transpose.getOutPad()) {
481 if (
failed(levelCheckKernel(op, p,
"pad"))) {
485 for (
auto s : transpose.getStride()) {
486 if (
failed(levelCheckStride(op, s,
"stride"))) {
495 LogicalResult levelCheckResize(Operation *op) {
496 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
497 SmallVector<int64_t> scale;
502 const int64_t scaleYN = scale[0];
503 const int64_t scaleYD = scale[1];
504 const int64_t scaleXN = scale[2];
505 const int64_t scaleXD = scale[3];
507 levelCheckScale(op, scaleYN / scaleYD,
"scale_y_n/scale_y_d")) ||
509 levelCheckScale(op, scaleXN / scaleXD,
"scale_x_n/scale_x_d"))) {
520 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
521 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
529 getMaxNestedDepth(op, depth);
532 LogicalResult levelCheckMaxNesting(Operation *op) {
533 int32_t maxNestedDepth = 0;
534 getMaxNestedDepth(op, maxNestedDepth);
536 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
537 if (maxNestedDepth >= maxNestingLevel)
539 <<
"failed level check: tosa_nesting_depth < MAX_NESTING" <<
" ("
540 << maxNestingLevel <<
"), got " << maxNestedDepth;
544 LogicalResult levelCheckListSize(Operation *op) {
545 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
546 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
548 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
549 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
551 failed(levelCheckListSize(op, custom.getOutputList().size(),
556 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
558 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
559 failed(levelCheckListSize(op, condIf.getOutputList().size(),
564 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
565 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
566 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
570 if (
auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
571 return levelCheckListSize(op, concat_shape.getInput().size(),
"input");
575 LogicalResult attributeCheckRescale(Operation *op) {
576 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
577 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
578 !targetEnv.allows(Extension::doubleround)) {
580 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
581 <<
"requires extension [doubleround]";
584 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
585 !targetEnv.allows(Extension::inexactround)) {
587 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
588 <<
"requires extension [inexactround]";
595 LogicalResult CheckVariable(Operation *op);
596 LogicalResult CheckVariableReadOrWrite(Operation *op);
597 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
600 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
603 TosaProfileCompliance profileComp;
604 tosa::TargetEnv targetEnv;
608LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
609 auto *op = tosaOp.getOperation();
610 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
615 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
623LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
624 auto *op = tosaOp.getOperation();
627 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
635LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
636 auto *op = tosaOp.getOperation();
638 if (
failed(levelCheckRank(op, variableType,
"variable type",
646LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
647 auto *op = tosaOp.getOperation();
649 if (
failed(levelCheckSize(op, variableType,
"variable type")))
655LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
656#define CHECK_RANKS_AND_SIZES(tosaOp) \
657 if (isa<tosa::tosaOp##Op>(op)) { \
658 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
660 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
664#define CHECK_SIZES(tosaOp) \
665 if (isa<tosa::tosaOp##Op>(op)) { \
666 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
670#define CHECK_SHAPE_LEN(tosaOp) \
671 if (isa<tosa::tosaOp##Op>(op)) { \
672 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
799#undef CHECK_RANKS_AND_SIZES
801#undef CHECK_SHAPE_LEN
806LogicalResult TosaValidation::levelCheckSize(Operation *op,
807 const Type &typeToCheck,
808 const StringRef operandOrResult) {
809 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
811 return op->
emitOpError() <<
"failed level check: unranked tensor";
812 auto shape = type.getShape();
813 for (
auto dim : shape) {
814 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
815 const TosaSpecificationVersion targetVersion = targetEnv.
getSpecVersion();
816 const TosaSpecificationVersion minRequiredVersion(1, 1);
826 return op->
emitOpError() <<
"failed level check: " << operandOrResult
827 <<
" shape dimension cannot be dynamic when"
828 <<
" targeting TOSA specification version 1.0"
833 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
834 int64_t size = element_bytes * type.getNumElements();
841 const int64_t max_size =
845 <<
"failed level check: " << operandOrResult
846 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
851LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
858 if (
failed(levelCheckRanksAndSizes(op)))
861 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
862 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
863 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
864 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
865 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
866 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
867 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
868 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op)) ||
869 failed(levelCheckConv2DBlockScaled(op))) {
874 if (
failed(levelCheckListSize(op))) {
878 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
879 if (
failed(levelCheckMaxNesting(op))) {
887LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
888 if (
failed(attributeCheckRescale(op)))
893inline bool CompatibleTypes(
const mlir::Type &type,
894 const mlir::Type &declaredType) {
896 return type == declaredType;
899LogicalResult TosaValidation::CheckVariable(Operation *op) {
900 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
901 mlir::StringAttr nameAttr = variableOp.getNameAttr();
903 if (variablesMap.count(nameAttr))
904 return op->
emitOpError() <<
"name has already been declared";
906 auto elementType = variableOp.getType();
907 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
908 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
909 RankedTensorType variableType =
910 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
912 variablesMap[nameAttr] = variableType;
918LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
919 if (isa<mlir::tosa::VariableReadOp>(op) ||
920 isa<mlir::tosa::VariableWriteOp>(op)) {
921 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
922 if (!variablesMap.count(nameAttr))
923 return op->
emitOpError() <<
"name has not been declared";
925 auto varType = variablesMap[nameAttr];
928 auto type = v.getType();
929 if (!CompatibleTypes(type, varType))
930 return op->
emitOpError() <<
"operand type does not equal variable type";
934 auto type = v.getType();
935 if (!CompatibleTypes(type, varType))
936 return op->
emitOpError() <<
"result type does not equal variable type";
943LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
944 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
949LogicalResult checkErrorIfResize(Operation *op) {
950 auto resize = dyn_cast<tosa::ResizeOp>(op);
954 const Value input = resize.getInput();
955 const Value output = resize.getOutput();
956 const RankedTensorType inputType =
957 llvm::dyn_cast<RankedTensorType>(input.
getType());
958 const RankedTensorType outputType =
959 llvm::dyn_cast<RankedTensorType>(output.
getType());
961 if (!inputType || !outputType)
962 return op->
emitOpError(
"expect ranked input/output tensor");
966 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
967 const SmallVector<int64_t, 4> sizes = {
968 outputType.getDimSize(1), outputType.getDimSize(2),
969 inputType.getDimSize(1), inputType.getDimSize(2)};
970 const int64_t *maxDim = llvm::max_element(sizes);
971 if (maxDim != sizes.end() && *maxDim >= 16384)
973 "expect input/output height/width dims to be < 16384, ")
974 <<
"got [OH, OW, IH, IW] = " << sizes;
977 SmallVector<int64_t> scale;
981 const int64_t scaleYN = scale[0];
982 const int64_t scaleYD = scale[1];
983 const int64_t scaleXN = scale[2];
984 const int64_t scaleXD = scale[3];
987 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
989 "expect all scale numerator values to be <= (1 << 11), "
991 << scaleYN <<
", scale_x_n=" << scaleXN;
993 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
994 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
995 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
997 SmallVector<int64_t> offset;
998 SmallVector<int64_t> border;
1003 const int64_t offsetY = offset[0];
1004 const int64_t offsetX = offset[1];
1007 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1009 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1010 << offsetY <<
"/" << scaleYN;
1011 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1013 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1014 << offsetX <<
"/" << scaleXN;
1016 const int64_t borderY = border[0];
1017 const int64_t borderX = border[1];
1018 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1020 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1021 << borderY <<
"/" << scaleYN;
1022 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1024 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1025 << borderX <<
"/" << scaleXN;
1038 const int64_t
rhs) -> std::optional<int64_t> {
1040 return std::nullopt;
1044 const int64_t oh = outputType.getDimSize(1);
1045 const int64_t ow = outputType.getDimSize(2);
1046 const int64_t ih = inputType.getDimSize(1);
1047 const int64_t iw = inputType.getDimSize(2);
1049 if (ih != ShapedType::kDynamic) {
1050 const std::optional<int64_t> calculatedOutHeightMinusOne =
1051 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1052 if (!calculatedOutHeightMinusOne.has_value())
1054 "expected (input_height - 1) * scale_y_n - offset_y + "
1056 <<
"to be wholly divisible by scale_y_d, got ((" << ih
1057 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
1058 <<
") / " << scaleYD;
1059 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1060 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1062 "calculated output height did not match expected: ")
1063 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1066 if (iw != ShapedType::kDynamic) {
1067 const std::optional<int64_t> calculatedOutWidthMinusOne =
1068 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1069 if (!calculatedOutWidthMinusOne.has_value())
1071 "expected (input_width - 1) * scale_x_n - offset_x + "
1073 <<
"to be wholly divisible by scale_x_d, got ((" << iw
1074 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
1075 <<
") / " << scaleXD;
1076 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1077 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1078 return op->
emitOpError(
"calculated output width did not match expected: ")
1079 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
1085LogicalResult checkErrorIfMul(Operation *op) {
1086 auto mul = dyn_cast<tosa::MulOp>(op);
1092 ElementsAttr shift_elem;
1095 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1097 if (inputElemType.isInteger(32)) {
1099 if (shift < 0 || shift > 63)
1101 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
1106 <<
"requires shift = 0 for all input data types that "
1107 "are not int32_t, but got: "
1114LogicalResult checkErrorIfTable(Operation *op) {
1115 auto table = dyn_cast<tosa::TableOp>(op);
1121 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1123 const ShapeAdaptor tableShape(table.getTable().getType());
1124 if (tableShape.hasStaticShape()) {
1125 const auto numElements = tableShape.getNumElements();
1126 if (numElements != tableSize)
1127 return op->
emitOpError() <<
"requires table size of " << tableSize
1128 <<
", got " << numElements;
1134LogicalResult checkErrorIfRescale(Operation *op) {
1135 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1139 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1140 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1141 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1142 !outputType.getElementType().isInteger())
1145 auto inElemType = inputType.getElementType();
1146 auto outElemType = outputType.getElementType();
1147 auto inWidth = inElemType.getIntOrFloatBitWidth();
1148 auto outWidth = outElemType.getIntOrFloatBitWidth();
1150 bool inputUnsigned = rescale.getInputUnsigned();
1151 bool outputUnsigned = rescale.getOutputUnsigned();
1153 bool scale32 = rescale.getScale32();
1154 auto roundingMode = rescale.getRoundingMode();
1157 if (scale32 && inWidth == 48)
1158 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1161 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1163 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1166 if (inputUnsigned && outputUnsigned)
1167 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1170 if (outWidth == 32 && inputUnsigned)
1172 <<
"i32 output type is not allowed with unsigned input.";
1175 if (inWidth == 32 && outputUnsigned)
1177 <<
"i32 input type is not allowed with unsigned output.";
1180 if (inWidth == 48 && outputUnsigned)
1182 <<
"i48 input type is not allowed with unsigned output.";
1185 if (inWidth == 48 && inputUnsigned)
1186 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1189 if (inWidth == 32 && inputUnsigned)
1190 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1193 if (outWidth == 32 && outputUnsigned)
1194 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1199LogicalResult checkErrorIfPad(Operation *op) {
1200 auto pad = dyn_cast<tosa::PadOp>(op);
1204 DenseIntElementsAttr paddingAttr;
1209 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1210 if (val.getSExtValue() < 0)
1211 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1212 << val.getSExtValue();
1218LogicalResult checkErrorIfReshape(Operation *op) {
1219 auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
1223 SmallVector<int64_t> shapeValues;
1229 return op->
emitOpError(
"shape input contains inferable dimension (")
1232 "which does not conform to the TOSA specification";
1237LogicalResult checkErrorIfSlice(Operation *op) {
1238 auto sliceOp = dyn_cast<tosa::SliceOp>(op);
1242 SmallVector<int64_t> startValues;
1243 SmallVector<int64_t> sizeValues;
1245 sliceOp.getStart().getDefiningOp(), startValues);
1246 const bool hasSizeValues =
1250 return op->
emitOpError(
"start input contains inferable dimension (")
1252 <<
") which does not conform to the TOSA specification";
1254 return op->
emitOpError(
"size input contains inferable dimension (")
1257 "does not conform to the TOSA specification";
1262static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1263 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1264 Region *operandRegion = operand.getParentRegion();
1265 return operandRegion && region->isAncestor(operandRegion);
1269static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) {
1270 bool noLiveInValue =
true;
1271 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](Operation *op) {
1272 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1273 noLiveInValue =
false;
1278 return noLiveInValue ?
success() : failure();
1281LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
1282 StringRef regionName) {
1283 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1286 <<
"is not conformant to the TOSA specification. It requires the '"
1287 << regionName <<
"' region is isolated from above.\n";
1290LogicalResult checkErrorIfCondIf(Operation *op) {
1291 auto ifOp = dyn_cast<tosa::IfOp>(op);
1324 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1325 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1330LogicalResult checkErrorIfWhileLoop(Operation *op) {
1331 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1335 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1336 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1341LogicalResult checkErrorIfScatter(Operation *op) {
1342 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1347 DenseIntElementsAttr indicesAttr;
1351 auto const indicesType =
1352 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1353 if (!indicesType || !indicesType.hasRank()) {
1359 op->
emitOpError(
"indices values contain duplicates");
1366LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1367 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1368 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1369 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfReshape(op)) ||
1370 failed(checkErrorIfSlice(op)) ||
failed(checkErrorIfCondIf(op)) ||
1371 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1376bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1377 if (isa<FloatType>(type)) {
1378 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1379 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1380 Float6E3M2FNType, Float8E8M0FNUType>(type);
1381 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1382 if (intTy.isSignless()) {
1383 switch (intTy.getWidth()) {
1393 }
else if (allowUnsigned && intTy.isUnsigned()) {
1394 switch (intTy.getWidth()) {
1401 }
else if (isa<tosa::shapeType>(type))
1403 else if (isa<tosa::mxint8Type>(type))
1408void TosaValidation::runOnOperation() {
1409 ModuleOp modOp = getOperation();
1410 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1415 const auto maybeTargetEnv =
1417 if (
failed(maybeTargetEnv))
1418 return signalPassFailure();
1419 targetEnv = *maybeTargetEnv;
1421 modOp.walk([&](Operation *op) {
1430 const bool allowUnsigned =
1431 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1434 if (!isValidElementType(elementTy, allowUnsigned)) {
1435 op->
emitOpError() <<
"is not profile-aligned: element type "
1436 << elementTy <<
" is not legal";
1437 return signalPassFailure();
1442 if (!isValidElementType(elementTy, allowUnsigned)) {
1443 op->
emitOpError() <<
"is not profile-aligned: element type "
1444 << elementTy <<
" is not legal";
1445 return signalPassFailure();
1449 if (strictOpSpecAlignment &&
1451 return signalPassFailure();
1453 if (strictOpSpecAlignment &&
1455 return signalPassFailure();
1457 if (!allowInvalidOpDatatypeCombinations &&
1459 return signalPassFailure();
1463 if (
failed(applyConstantOperandCheck(op)))
1464 signalPassFailure();
1467 if (
failed(applyLevelCheck(op)))
1468 signalPassFailure();
1471 if (
failed(applyAttributeCheck(op)))
1472 signalPassFailure();
1475 if (
failed(applyVariableCheck(op)))
1476 signalPassFailure();
1479 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1480 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback provided.
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
TosaSpecificationVersion getSpecVersion() const
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.