29#include "llvm/ADT/StringExtras.h"
33#define GEN_PASS_DEF_TOSAVALIDATION
34#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45 for (
const auto index : operandIndices) {
48 return op->
emitOpError(
"expected compile time resolvable constant, but "
49 "got variable value for operand #")
56static LogicalResult checkConstantOperandMul(
Operation *op,
58 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60 return checkConstantOperands(op, {2});
65static LogicalResult checkConstantOperandTable(
Operation *op,
67 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69 return checkConstantOperands(op, {1});
74static LogicalResult checkConstantOperandPad(
Operation *op,
76 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
78 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
81 return checkConstantOperands(op, {2});
86static LogicalResult checkConstantOperandRescale(
Operation *op,
88 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90 return checkConstantOperands(op, {1, 2, 3, 4});
96static LogicalResult checkConstantOperandConvOps(
Operation *op,
98 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
100 return checkConstantOperands(op, {3, 4});
105static LogicalResult checkConstantOperandMatMul(
Operation *op,
107 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109 return checkConstantOperands(op, {2, 3});
114static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
116 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118 return checkConstantOperands(op, {1, 2});
123static LogicalResult checkConstantOperandNegate(
Operation *op,
125 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127 return checkConstantOperands(op, {1, 2});
132static LogicalResult checkConstantOperandSilceShape(
Operation *op,
134 if (!env.
allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
136 return checkConstantOperands(op, {1, 2});
145struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
147 explicit TosaValidation() { populateConstantOperandChecks(); }
149 explicit TosaValidation(
const TosaValidationOptions &
options)
151 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
152 this->allowInvalidOpDatatypeCombinations =
153 options.allowInvalidOpDatatypeCombinations;
155 void runOnOperation() final;
157 LogicalResult applyConstantOperandCheck(Operation *op) {
158 for (
auto &checker : constCheckers) {
159 if (
failed(checker(op, targetEnv)))
165 LogicalResult applyLevelCheck(Operation *op);
166 LogicalResult applyAttributeCheck(Operation *op);
169 LogicalResult applyVariableCheck(Operation *op);
172 LogicalResult applyErrorIfCheck(Operation *op);
175 void populateConstantOperandChecks() {
176 constCheckers.emplace_back(checkConstantOperandMul);
177 constCheckers.emplace_back(checkConstantOperandTable);
178 constCheckers.emplace_back(checkConstantOperandPad);
179 constCheckers.emplace_back(checkConstantOperandRescale);
180 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
181 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
182 constCheckers.emplace_back(
183 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
184 constCheckers.emplace_back(
185 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
186 constCheckers.emplace_back(checkConstantOperandMatMul);
187 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
188 constCheckers.emplace_back(checkConstantOperandNegate);
189 constCheckers.emplace_back(checkConstantOperandSilceShape);
192 LogicalResult levelCheckKernel(Operation *op, int32_t v,
193 const StringRef checkDesc) {
194 if (v > targetEnv.getLevel().MAX_KERNEL)
195 return op->
emitOpError() <<
"failed level check: " << checkDesc;
199 LogicalResult levelCheckStride(Operation *op, int32_t v,
200 const StringRef checkDesc) {
201 if (v > targetEnv.getLevel().MAX_STRIDE)
202 return op->
emitOpError() <<
"failed level check: " << checkDesc;
206 LogicalResult levelCheckScale(Operation *op, int32_t v,
207 const StringRef checkDesc) {
208 if (v > targetEnv.getLevel().MAX_SCALE)
209 return op->
emitOpError() <<
"failed level check: " << checkDesc;
213 LogicalResult levelCheckListSize(Operation *op, int32_t v,
214 const StringRef checkDesc) {
215 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
217 <<
"failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
222 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
223 const StringRef operandOrResult,
224 int32_t highest_rank) {
225 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
227 return op->
emitOpError() <<
"failed level check: unranked tensor";
228 if (type.getRank() > highest_rank)
229 return op->
emitOpError() <<
"failed level check: " << operandOrResult
230 <<
" rank(shape) <= MAX_RANK";
236 LogicalResult levelCheckRank(Operation *op,
const Value &v,
237 const StringRef operandOrResult,
238 int32_t highest_rank) {
239 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
243 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
244 const StringRef operandOrResult);
247 LogicalResult levelCheckSize(Operation *op,
const Value &v,
248 const StringRef operandOrResult) {
249 return levelCheckSize(op, v.
getType(), operandOrResult);
253 LogicalResult levelCheckShapeLength(Operation *op,
const Type typeToCheck,
254 const StringRef operandOrResult) {
255 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
256 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
258 <<
"failed shape type level check: " << typeToCheck
259 <<
" exceeds MAX_SHAPE_LEN";
265 template <
typename T>
266 LogicalResult levelCheckSizes(T tosaOp) {
267 auto op = tosaOp.getOperation();
269 if (
failed(levelCheckSize(op, v,
"operand")))
274 if (
failed(levelCheckSize(op, v,
"result")))
281 template <
typename T>
282 LogicalResult levelCheckRanks(T tosaOp) {
283 auto op = tosaOp.getOperation();
284 const TosaLevel tosaLevel = targetEnv.getLevel();
298 template <
typename T>
299 LogicalResult levelCheckShapeLengths(T tosaOp) {
300 for (
const auto &v : tosaOp->getOperands()) {
301 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"operand")))
304 for (
const auto &v : tosaOp->getResults()) {
305 if (
failed(levelCheckShapeLength(tosaOp, v.getType(),
"result")))
313 LogicalResult levelCheckRanksAndSizes(Operation *op);
316 template <
typename T>
317 LogicalResult levelCheckPool(Operation *op) {
318 if (
auto poolOp = dyn_cast<T>(op)) {
319 for (
auto k : poolOp.getKernel()) {
320 if (
failed(levelCheckKernel(op, k,
"kernel <= MAX_KERNEL"))) {
324 for (
auto s : poolOp.getStride()) {
325 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
329 for (
auto p : poolOp.getPad()) {
330 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
339 template <
typename T>
340 LogicalResult levelCheckConv(Operation *op) {
341 if (
auto convOp = dyn_cast<T>(op)) {
343 for (
auto k : convOp.getDilation()) {
344 if (
failed(levelCheckKernel(op, k,
"dilation <= MAX_KERNEL"))) {
348 for (
auto p : convOp.getPad()) {
349 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
353 for (
auto s : convOp.getStride()) {
354 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
358 auto dilation = convOp.getDilation();
359 if (ShapedType weightType =
361 auto shape = weightType.getShape();
362 if (isa<tosa::Conv2DOp>(op)) {
363 assert(shape.size() == 4);
364 assert(dilation.size() == 2);
365 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
366 "dilation_y * KH <= MAX_KERNEL)")) ||
367 failed(levelCheckKernel(op, dilation[1] * shape[2],
368 "dilation_x * KW <= MAX_KERNEL)")))
370 }
else if (isa<tosa::Conv3DOp>(op)) {
371 assert(shape.size() == 5);
372 assert(dilation.size() == 3);
373 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
374 "dilation_d * KD <= MAX_KERNEL)")) ||
375 failed(levelCheckKernel(op, dilation[1] * shape[2],
376 "dilation_y * KH <= MAX_KERNEL)")) ||
377 failed(levelCheckKernel(op, dilation[2] * shape[3],
378 "dilation_x * KW <= MAX_KERNEL)")))
380 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
381 assert(shape.size() == 4);
382 assert(dilation.size() == 2);
383 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
384 "dilation_y * KH <= MAX_KERNEL)")) ||
385 failed(levelCheckKernel(op, dilation[1] * shape[1],
386 "dilation_x * KW <= MAX_KERNEL)")))
394 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
395 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
399 SmallVector<int64_t> padValues;
401 for (
const auto p : padValues)
402 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL")))
406 SmallVector<int64_t> strideValues;
409 for (
const auto s : strideValues)
410 if (
failed(levelCheckKernel(op, s,
"stride <= MAX_KERNEL")))
414 SmallVector<int64_t> dilationValues;
417 int64_t KH = ShapedType::kDynamic;
418 int64_t KW = ShapedType::kDynamic;
419 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
420 KH = weightDataShape.getDimSize(1);
421 KW = weightDataShape.getDimSize(2);
422 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
423 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
424 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
426 if (!ShapedType::isDynamic(KH) &&
427 failed(levelCheckKernel(op, dilationValues[0] * KH,
428 "dilation_y * KH <= MAX_KERNEL)")))
431 if (!ShapedType::isDynamic(KW) &&
432 failed(levelCheckKernel(op, dilationValues[1] * KW,
433 "dilation_x * KW <= MAX_KERNEL)")))
441 template <
typename T>
442 LogicalResult levelCheckFFT(Operation *op) {
445 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
446 auto shape = type.getShape();
447 assert(shape.size() == 3);
448 if (
failed(levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL")) ||
449 failed(levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL"))) {
459 LogicalResult levelCheckTransposeConv2d(Operation *op) {
460 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
461 if (ShapedType filterType =
462 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
463 auto shape = filterType.getShape();
464 assert(shape.size() == 4);
466 if (
failed(levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL")) ||
467 failed(levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL"))) {
471 for (
auto p : transpose.getOutPad()) {
472 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
476 for (
auto s : transpose.getStride()) {
477 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
486 LogicalResult levelCheckResize(Operation *op) {
487 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
488 SmallVector<int64_t> scale;
493 const int64_t scaleYN = scale[0];
494 const int64_t scaleYD = scale[1];
495 const int64_t scaleXN = scale[2];
496 const int64_t scaleXD = scale[3];
497 if (
failed(levelCheckScale(op, scaleYN / scaleYD,
498 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
499 failed(levelCheckScale(op, scaleXN / scaleXD,
500 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
511 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
512 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
520 getMaxNestedDepth(op, depth);
523 LogicalResult levelCheckMaxNesting(Operation *op) {
524 int32_t maxNestedDepth = 0;
525 getMaxNestedDepth(op, maxNestedDepth);
527 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
528 op->
emitOpError() <<
"failed level check: " << maxNestedDepth
529 <<
" >= MAX_NESTING";
535 LogicalResult levelCheckListSize(Operation *op) {
536 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
537 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
539 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
540 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
542 failed(levelCheckListSize(op, custom.getOutputList().size(),
547 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
549 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
550 failed(levelCheckListSize(op, condIf.getOutputList().size(),
555 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
556 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
557 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
561 if (
auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
562 return levelCheckListSize(op, concat_shape.getInput().size(),
"input");
566 LogicalResult attributeCheckRescale(Operation *op) {
567 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
568 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
569 !targetEnv.allows(Extension::doubleround)) {
571 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
572 <<
"requires extension [doubleround]";
575 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
576 !targetEnv.allows(Extension::inexactround)) {
578 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
579 <<
"requires extension [inexactround]";
586 LogicalResult CheckVariable(Operation *op);
587 LogicalResult CheckVariableReadOrWrite(Operation *op);
588 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
591 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
594 TosaProfileCompliance profileComp;
595 tosa::TargetEnv targetEnv;
599LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
600 auto *op = tosaOp.getOperation();
601 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
606 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
614LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
615 auto *op = tosaOp.getOperation();
618 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
626LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
627 auto *op = tosaOp.getOperation();
629 if (
failed(levelCheckRank(op, variableType,
"variable type",
637LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
638 auto *op = tosaOp.getOperation();
640 if (
failed(levelCheckSize(op, variableType,
"variable type")))
646LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
647#define CHECK_RANKS_AND_SIZES(tosaOp) \
648 if (isa<tosa::tosaOp##Op>(op)) { \
649 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
651 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
655#define CHECK_SIZES(tosaOp) \
656 if (isa<tosa::tosaOp##Op>(op)) { \
657 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
661#define CHECK_SHAPE_LEN(tosaOp) \
662 if (isa<tosa::tosaOp##Op>(op)) { \
663 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
790#undef CHECK_RANKS_AND_SIZES
792#undef CHECK_SHAPE_LEN
797LogicalResult TosaValidation::levelCheckSize(Operation *op,
798 const Type &typeToCheck,
799 const StringRef operandOrResult) {
800 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
802 return op->
emitOpError() <<
"failed level check: unranked tensor";
803 auto shape = type.getShape();
804 for (
auto dim : shape) {
805 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
806 const TosaSpecificationVersion targetVersion = targetEnv.
getSpecVersion();
807 const TosaSpecificationVersion minRequiredVersion(1, 1);
817 return op->
emitOpError() <<
"failed level check: " << operandOrResult
818 <<
" shape dimension cannot be dynamic when"
819 <<
" targeting TOSA specification version 1.0"
824 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
825 int64_t size = element_bytes * type.getNumElements();
832 const int64_t max_size =
836 <<
"failed level check: " << operandOrResult
837 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
842LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
849 if (
failed(levelCheckRanksAndSizes(op)))
852 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
853 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
854 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
855 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
856 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
857 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
858 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
859 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op)) ||
860 failed(levelCheckConv2DBlockScaled(op))) {
865 if (
failed(levelCheckListSize(op))) {
869 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
870 if (
failed(levelCheckMaxNesting(op))) {
878LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
879 if (
failed(attributeCheckRescale(op)))
884inline bool CompatibleTypes(
const mlir::Type &type,
885 const mlir::Type &declaredType) {
887 return type == declaredType;
890LogicalResult TosaValidation::CheckVariable(Operation *op) {
891 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
892 mlir::StringAttr nameAttr = variableOp.getNameAttr();
894 if (variablesMap.count(nameAttr))
895 return op->
emitOpError() <<
"name has already been declared";
897 auto elementType = variableOp.getType();
898 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
899 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
900 RankedTensorType variableType =
901 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
903 variablesMap[nameAttr] = variableType;
909LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
910 if (isa<mlir::tosa::VariableReadOp>(op) ||
911 isa<mlir::tosa::VariableWriteOp>(op)) {
912 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
913 if (!variablesMap.count(nameAttr))
914 return op->
emitOpError() <<
"name has not been declared";
916 auto varType = variablesMap[nameAttr];
919 auto type = v.getType();
920 if (!CompatibleTypes(type, varType))
921 return op->
emitOpError() <<
"operand type does not equal variable type";
925 auto type = v.getType();
926 if (!CompatibleTypes(type, varType))
927 return op->
emitOpError() <<
"result type does not equal variable type";
934LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
935 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
940LogicalResult checkErrorIfResize(Operation *op) {
941 auto resize = dyn_cast<tosa::ResizeOp>(op);
945 const Value input = resize.getInput();
946 const Value output = resize.getOutput();
947 const RankedTensorType inputType =
948 llvm::dyn_cast<RankedTensorType>(input.
getType());
949 const RankedTensorType outputType =
950 llvm::dyn_cast<RankedTensorType>(output.
getType());
952 if (!inputType || !outputType)
953 return op->
emitOpError(
"expect ranked input/output tensor");
957 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
958 const SmallVector<int64_t, 4> sizes = {
959 outputType.getDimSize(1), outputType.getDimSize(2),
960 inputType.getDimSize(1), inputType.getDimSize(2)};
961 const int64_t *maxDim = llvm::max_element(sizes);
962 if (maxDim != sizes.end() && *maxDim >= 16384)
964 "expect input/output height/width dims to be < 16384, ")
965 <<
"got [OH, OW, IH, IW] = " << sizes;
968 SmallVector<int64_t> scale;
972 const int64_t scaleYN = scale[0];
973 const int64_t scaleYD = scale[1];
974 const int64_t scaleXN = scale[2];
975 const int64_t scaleXD = scale[3];
978 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
980 "expect all scale numerator values to be <= (1 << 11), "
982 << scaleYN <<
", scale_x_n=" << scaleXN;
984 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
985 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
986 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
988 SmallVector<int64_t> offset;
989 SmallVector<int64_t> border;
994 const int64_t offsetY = offset[0];
995 const int64_t offsetX = offset[1];
998 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1000 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1001 << offsetY <<
"/" << scaleYN;
1002 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1004 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1005 << offsetX <<
"/" << scaleXN;
1007 const int64_t borderY = border[0];
1008 const int64_t borderX = border[1];
1009 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1011 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1012 << borderY <<
"/" << scaleYN;
1013 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1015 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1016 << borderX <<
"/" << scaleXN;
1029 const int64_t
rhs) -> std::optional<int64_t> {
1031 return std::nullopt;
1035 const int64_t oh = outputType.getDimSize(1);
1036 const int64_t ow = outputType.getDimSize(2);
1037 const int64_t ih = inputType.getDimSize(1);
1038 const int64_t iw = inputType.getDimSize(2);
1040 if (ih != ShapedType::kDynamic) {
1041 const std::optional<int64_t> calculatedOutHeightMinusOne =
1042 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1043 if (!calculatedOutHeightMinusOne.has_value())
1045 "expected (input_height - 1) * scale_y_n - offset_y + "
1047 <<
"to be wholly divisible by scale_y_d, got ((" << ih
1048 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
1049 <<
") / " << scaleYD;
1050 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1051 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1053 "calculated output height did not match expected: ")
1054 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
1057 if (iw != ShapedType::kDynamic) {
1058 const std::optional<int64_t> calculatedOutWidthMinusOne =
1059 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1060 if (!calculatedOutWidthMinusOne.has_value())
1062 "expected (input_width - 1) * scale_x_n - offset_x + "
1064 <<
"to be wholly divisible by scale_x_d, got ((" << iw
1065 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
1066 <<
") / " << scaleXD;
1067 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1068 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1069 return op->
emitOpError(
"calculated output width did not match expected: ")
1070 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
1076LogicalResult checkErrorIfMul(Operation *op) {
1077 auto mul = dyn_cast<tosa::MulOp>(op);
1083 ElementsAttr shift_elem;
1086 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1088 if (inputElemType.isInteger(32)) {
1090 if (shift < 0 || shift > 63)
1092 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
1097 <<
"requires shift = 0 for all input data types that "
1098 "are not int32_t, but got: "
1105LogicalResult checkErrorIfTable(Operation *op) {
1106 auto table = dyn_cast<tosa::TableOp>(op);
1112 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1114 const ShapeAdaptor tableShape(table.getTable().getType());
1115 if (tableShape.hasStaticShape()) {
1116 const auto numElements = tableShape.getNumElements();
1117 if (numElements != tableSize)
1118 return op->
emitOpError() <<
"requires table size of " << tableSize
1119 <<
", got " << numElements;
1125LogicalResult checkErrorIfRescale(Operation *op) {
1126 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1130 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1131 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1132 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1133 !outputType.getElementType().isInteger())
1136 auto inElemType = inputType.getElementType();
1137 auto outElemType = outputType.getElementType();
1138 auto inWidth = inElemType.getIntOrFloatBitWidth();
1139 auto outWidth = outElemType.getIntOrFloatBitWidth();
1141 bool inputUnsigned = rescale.getInputUnsigned();
1142 bool outputUnsigned = rescale.getOutputUnsigned();
1144 bool scale32 = rescale.getScale32();
1145 auto roundingMode = rescale.getRoundingMode();
1148 if (scale32 && inWidth == 48)
1149 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1152 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1154 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1157 if (inputUnsigned && outputUnsigned)
1158 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1161 if (outWidth == 32 && inputUnsigned)
1163 <<
"i32 output type is not allowed with unsigned input.";
1166 if (inWidth == 32 && outputUnsigned)
1168 <<
"i32 input type is not allowed with unsigned output.";
1171 if (inWidth == 48 && outputUnsigned)
1173 <<
"i48 input type is not allowed with unsigned output.";
1176 if (inWidth == 48 && inputUnsigned)
1177 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1180 if (inWidth == 32 && inputUnsigned)
1181 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1184 if (outWidth == 32 && outputUnsigned)
1185 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1190LogicalResult checkErrorIfPad(Operation *op) {
1191 auto pad = dyn_cast<tosa::PadOp>(op);
1195 DenseIntElementsAttr paddingAttr;
1200 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1201 if (val.getSExtValue() < 0)
1202 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1203 << val.getSExtValue();
1209static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1210 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1211 Region *operandRegion = operand.getParentRegion();
1212 return operandRegion && region->isAncestor(operandRegion);
1216static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) {
1217 bool noLiveInValue =
true;
1218 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](Operation *op) {
1219 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1220 noLiveInValue =
false;
1225 return noLiveInValue ?
success() : failure();
1228LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
1229 StringRef regionName) {
1230 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1233 <<
"is not conformant to the TOSA specification. It requires the '"
1234 << regionName <<
"' region is isolated from above.\n";
1237LogicalResult checkErrorIfCondIf(Operation *op) {
1238 auto ifOp = dyn_cast<tosa::IfOp>(op);
1271 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1272 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1277LogicalResult checkErrorIfWhileLoop(Operation *op) {
1278 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1282 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1283 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1288LogicalResult checkErrorIfScatter(Operation *op) {
1289 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1294 DenseIntElementsAttr indicesAttr;
1298 auto const indicesType =
1299 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1300 if (!indicesType || !indicesType.hasRank()) {
1306 op->
emitOpError(
"indices values contain duplicates");
1313LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1314 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1315 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1316 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfCondIf(op)) ||
1317 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1322bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1323 if (isa<FloatType>(type)) {
1324 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1325 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1326 Float6E3M2FNType, Float8E8M0FNUType>(type);
1327 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1328 if (intTy.isSignless()) {
1329 switch (intTy.getWidth()) {
1339 }
else if (allowUnsigned && intTy.isUnsigned()) {
1340 switch (intTy.getWidth()) {
1347 }
else if (isa<tosa::shapeType>(type))
1349 else if (isa<tosa::mxint8Type>(type))
1354void TosaValidation::runOnOperation() {
1355 ModuleOp modOp = getOperation();
1356 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1361 const auto maybeTargetEnv =
1363 if (
failed(maybeTargetEnv))
1364 return signalPassFailure();
1365 targetEnv = *maybeTargetEnv;
1367 modOp.walk([&](Operation *op) {
1376 const bool allowUnsigned =
1377 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1380 if (!isValidElementType(elementTy, allowUnsigned)) {
1381 op->
emitOpError() <<
"is not profile-aligned: element type "
1382 << elementTy <<
" is not legal";
1383 return signalPassFailure();
1388 if (!isValidElementType(elementTy, allowUnsigned)) {
1389 op->
emitOpError() <<
"is not profile-aligned: element type "
1390 << elementTy <<
" is not legal";
1391 return signalPassFailure();
1395 if (strictOpSpecAlignment &&
1397 return signalPassFailure();
1399 if (strictOpSpecAlignment &&
1401 return signalPassFailure();
1403 if (!allowInvalidOpDatatypeCombinations &&
1405 return signalPassFailure();
1409 if (
failed(applyConstantOperandCheck(op)))
1410 signalPassFailure();
1413 if (
failed(applyLevelCheck(op)))
1414 signalPassFailure();
1417 if (
failed(applyAttributeCheck(op)))
1418 signalPassFailure();
1421 if (
failed(applyVariableCheck(op)))
1422 signalPassFailure();
1425 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1426 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
TosaSpecificationVersion getSpecVersion() const
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.