29#include "llvm/ADT/StringExtras.h"
33#define GEN_PASS_DEF_TOSAVALIDATION
34#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45 for (
const auto index : operandIndices) {
48 return op->
emitOpError(
"expected compile time resolvable constant, but "
49 "got variable value for operand #")
56static LogicalResult checkConstantOperandMul(
Operation *op,
58 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60 return checkConstantOperands(op, {2});
65static LogicalResult checkConstantOperandTable(
Operation *op,
67 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69 return checkConstantOperands(op, {1});
74static LogicalResult checkConstantOperandPad(
Operation *op,
76 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
78 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
81 return checkConstantOperands(op, {2});
86static LogicalResult checkConstantOperandRescale(
Operation *op,
88 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90 return checkConstantOperands(op, {1, 2, 3, 4});
96static LogicalResult checkConstantOperandConvOps(
Operation *op,
98 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
100 return checkConstantOperands(op, {3, 4});
105static LogicalResult checkConstantOperandMatMul(
Operation *op,
107 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109 return checkConstantOperands(op, {2, 3});
114static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
116 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118 return checkConstantOperands(op, {1, 2});
123static LogicalResult checkConstantOperandNegate(
Operation *op,
125 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127 return checkConstantOperands(op, {1, 2});
138 explicit TosaValidation() { populateConstantOperandChecks(); }
140 explicit TosaValidation(
const TosaValidationOptions &
options)
142 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
143 this->allowInvalidOpDatatypeCombinations =
144 options.allowInvalidOpDatatypeCombinations;
146 void runOnOperation() final;
148 LogicalResult applyConstantOperandCheck(Operation *op) {
149 for (
auto &checker : constCheckers) {
150 if (
failed(checker(op, targetEnv)))
156 LogicalResult applyLevelCheck(Operation *op);
157 LogicalResult applyAttributeCheck(Operation *op);
160 LogicalResult applyVariableCheck(Operation *op);
163 LogicalResult applyErrorIfCheck(Operation *op);
166 void populateConstantOperandChecks() {
167 constCheckers.emplace_back(checkConstantOperandMul);
168 constCheckers.emplace_back(checkConstantOperandTable);
169 constCheckers.emplace_back(checkConstantOperandPad);
170 constCheckers.emplace_back(checkConstantOperandRescale);
171 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
172 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
173 constCheckers.emplace_back(
174 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
175 constCheckers.emplace_back(
176 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
177 constCheckers.emplace_back(checkConstantOperandMatMul);
178 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
179 constCheckers.emplace_back(checkConstantOperandNegate);
182 LogicalResult levelCheckKernel(Operation *op, int32_t v,
183 const StringRef checkDesc) {
184 if (v > targetEnv.getLevel().MAX_KERNEL)
185 return op->
emitOpError() <<
"failed level check: " << checkDesc;
189 LogicalResult levelCheckStride(Operation *op, int32_t v,
190 const StringRef checkDesc) {
191 if (v > targetEnv.getLevel().MAX_STRIDE)
192 return op->
emitOpError() <<
"failed level check: " << checkDesc;
196 LogicalResult levelCheckScale(Operation *op, int32_t v,
197 const StringRef checkDesc) {
198 if (v > targetEnv.getLevel().MAX_SCALE)
199 return op->
emitOpError() <<
"failed level check: " << checkDesc;
203 LogicalResult levelCheckListSize(Operation *op, int32_t v,
204 const StringRef checkDesc) {
205 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
207 <<
"failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
212 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
213 const StringRef operandOrResult,
214 int32_t highest_rank) {
215 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
217 return op->
emitOpError() <<
"failed level check: unranked tensor";
218 if (type.getRank() > highest_rank)
219 return op->
emitOpError() <<
"failed level check: " << operandOrResult
220 <<
" rank(shape) <= MAX_RANK";
226 LogicalResult levelCheckRank(Operation *op,
const Value &v,
227 const StringRef operandOrResult,
228 int32_t highest_rank) {
229 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
233 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
234 const StringRef operandOrResult);
237 LogicalResult levelCheckSize(Operation *op,
const Value &v,
238 const StringRef operandOrResult) {
239 return levelCheckSize(op, v.
getType(), operandOrResult);
243 template <
typename T>
244 LogicalResult levelCheckSizes(T tosaOp) {
245 auto op = tosaOp.getOperation();
247 if (
failed(levelCheckSize(op, v,
"operand")))
252 if (
failed(levelCheckSize(op, v,
"result")))
259 template <
typename T>
260 LogicalResult levelCheckRanks(T tosaOp) {
261 auto op = tosaOp.getOperation();
262 const TosaLevel tosaLevel = targetEnv.getLevel();
276 LogicalResult levelCheckRanksAndSizes(Operation *op);
279 template <
typename T>
280 LogicalResult levelCheckPool(Operation *op) {
281 if (
auto poolOp = dyn_cast<T>(op)) {
282 for (
auto k : poolOp.getKernel()) {
283 if (
failed(levelCheckKernel(op, k,
"kernel <= MAX_KERNEL"))) {
287 for (
auto s : poolOp.getStride()) {
288 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
292 for (
auto p : poolOp.getPad()) {
293 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
302 template <
typename T>
303 LogicalResult levelCheckConv(Operation *op) {
304 if (
auto convOp = dyn_cast<T>(op)) {
306 for (
auto k : convOp.getDilation()) {
307 if (
failed(levelCheckKernel(op, k,
"dilation <= MAX_KERNEL"))) {
311 for (
auto p : convOp.getPad()) {
312 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
316 for (
auto s : convOp.getStride()) {
317 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
321 auto dilation = convOp.getDilation();
322 if (ShapedType weightType =
324 auto shape = weightType.getShape();
325 if (isa<tosa::Conv2DOp>(op)) {
326 assert(shape.size() == 4);
327 assert(dilation.size() == 2);
328 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
329 "dilation_y * KH <= MAX_KERNEL)")) ||
330 failed(levelCheckKernel(op, dilation[1] * shape[2],
331 "dilation_x * KW <= MAX_KERNEL)")))
333 }
else if (isa<tosa::Conv3DOp>(op)) {
334 assert(shape.size() == 5);
335 assert(dilation.size() == 3);
336 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
337 "dilation_d * KD <= MAX_KERNEL)")) ||
338 failed(levelCheckKernel(op, dilation[1] * shape[2],
339 "dilation_y * KH <= MAX_KERNEL)")) ||
340 failed(levelCheckKernel(op, dilation[2] * shape[3],
341 "dilation_x * KW <= MAX_KERNEL)")))
343 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
344 assert(shape.size() == 4);
345 assert(dilation.size() == 2);
346 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
347 "dilation_y * KH <= MAX_KERNEL)")) ||
348 failed(levelCheckKernel(op, dilation[1] * shape[1],
349 "dilation_x * KW <= MAX_KERNEL)")))
358 template <
typename T>
359 LogicalResult levelCheckFFT(Operation *op) {
362 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
363 auto shape = type.getShape();
364 assert(shape.size() == 3);
365 if (
failed(levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL")) ||
366 failed(levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL"))) {
376 LogicalResult levelCheckTransposeConv2d(Operation *op) {
377 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
378 if (ShapedType filterType =
379 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
380 auto shape = filterType.getShape();
381 assert(shape.size() == 4);
383 if (
failed(levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL")) ||
384 failed(levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL"))) {
388 for (
auto p : transpose.getOutPad()) {
389 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
393 for (
auto s : transpose.getStride()) {
394 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
403 LogicalResult levelCheckResize(Operation *op) {
404 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
405 SmallVector<int64_t> scale;
410 const int64_t scaleYN = scale[0];
411 const int64_t scaleYD = scale[1];
412 const int64_t scaleXN = scale[2];
413 const int64_t scaleXD = scale[3];
414 if (
failed(levelCheckScale(op, scaleYN / scaleYD,
415 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
416 failed(levelCheckScale(op, scaleXN / scaleXD,
417 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
428 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
429 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
437 getMaxNestedDepth(op, depth);
440 LogicalResult levelCheckMaxNesting(Operation *op) {
441 int32_t maxNestedDepth = 0;
442 getMaxNestedDepth(op, maxNestedDepth);
444 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
445 op->
emitOpError() <<
"failed level check: " << maxNestedDepth
446 <<
" >= MAX_NESTING";
452 LogicalResult levelCheckListSize(Operation *op) {
453 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
456 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
457 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
459 failed(levelCheckListSize(op, custom.getOutputList().size(),
464 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
466 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
467 failed(levelCheckListSize(op, condIf.getOutputList().size(),
472 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
473 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
474 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
481 LogicalResult attributeCheckRescale(Operation *op) {
482 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
483 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
484 !targetEnv.allows(Extension::doubleround)) {
486 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
487 <<
"requires extension [doubleround]";
490 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
491 !targetEnv.allows(Extension::inexactround)) {
493 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
494 <<
"requires extension [inexactround]";
501 LogicalResult CheckVariable(Operation *op);
502 LogicalResult CheckVariableReadOrWrite(Operation *op);
503 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
506 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
509 TosaProfileCompliance profileComp;
510 tosa::TargetEnv targetEnv;
514LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
515 auto *op = tosaOp.getOperation();
516 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
521 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
529LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
530 auto *op = tosaOp.getOperation();
533 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
541LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
542 auto *op = tosaOp.getOperation();
544 if (
failed(levelCheckRank(op, variableType,
"variable type",
552LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
553 auto *op = tosaOp.getOperation();
555 if (
failed(levelCheckSize(op, variableType,
"variable type")))
561LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
562#define CHECK_RANKS_AND_SIZES(tosaOp) \
563 if (isa<tosa::tosaOp##Op>(op)) { \
564 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
566 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
570#define CHECK_SIZES(tosaOp) \
571 if (isa<tosa::tosaOp##Op>(op)) { \
572 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
677#undef CHECK_RANKS_AND_SIZES
683LogicalResult TosaValidation::levelCheckSize(Operation *op,
684 const Type &typeToCheck,
685 const StringRef operandOrResult) {
686 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
688 return op->
emitOpError() <<
"failed level check: unranked tensor";
689 auto shape = type.getShape();
690 for (
auto dim : shape) {
691 if (mlir::ShapedType::isDynamic(dim))
692 return op->
emitOpError() <<
"failed level check: " << operandOrResult
693 <<
" shape dimension cannot be dynamic";
697 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
698 int64_t size = element_bytes * type.getNumElements();
705 const int64_t max_size =
709 <<
"failed level check: " << operandOrResult
710 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
715LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
722 if (
failed(levelCheckRanksAndSizes(op)))
726 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
727 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
728 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
729 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
730 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
731 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
732 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
733 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
738 if (failed(levelCheckListSize(op))) {
742 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
743 if (failed(levelCheckMaxNesting(op))) {
751LogicalResult TosaValidation::applyAttributeCheck(
Operation *op) {
752 if (failed(attributeCheckRescale(op)))
757inline bool CompatibleTypes(
const mlir::Type &type,
760 return type == declaredType;
763LogicalResult TosaValidation::CheckVariable(
Operation *op) {
764 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
765 mlir::StringAttr nameAttr = variableOp.getNameAttr();
767 if (variablesMap.count(nameAttr))
768 return op->
emitOpError() <<
"name has already been declared";
770 auto elementType = variableOp.getType();
773 RankedTensorType variableType =
776 variablesMap[nameAttr] = variableType;
782LogicalResult TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
783 if (isa<mlir::tosa::VariableReadOp>(op) ||
784 isa<mlir::tosa::VariableWriteOp>(op)) {
785 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
786 if (!variablesMap.count(nameAttr))
787 return op->
emitOpError() <<
"name has not been declared";
789 auto varType = variablesMap[nameAttr];
792 auto type = v.getType();
793 if (!CompatibleTypes(type, varType))
794 return op->
emitOpError() <<
"operand type does not equal variable type";
798 auto type = v.getType();
799 if (!CompatibleTypes(type, varType))
800 return op->
emitOpError() <<
"result type does not equal variable type";
807LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
808 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
813LogicalResult checkErrorIfResize(Operation *op) {
814 auto resize = dyn_cast<tosa::ResizeOp>(op);
818 const Value input = resize.getInput();
819 const Value output = resize.getOutput();
820 const RankedTensorType inputType =
821 llvm::dyn_cast<RankedTensorType>(input.
getType());
822 const RankedTensorType outputType =
823 llvm::dyn_cast<RankedTensorType>(output.
getType());
825 if (!inputType || !outputType)
826 return op->
emitOpError(
"expect ranked input/output tensor");
830 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
831 const SmallVector<int64_t, 4> sizes = {
832 outputType.getDimSize(1), outputType.getDimSize(2),
833 inputType.getDimSize(1), inputType.getDimSize(2)};
834 const int64_t *maxDim = llvm::max_element(sizes);
835 if (maxDim != sizes.end() && *maxDim >= 16384)
837 "expect input/output height/width dims to be < 16384, ")
838 <<
"got [OH, OW, IH, IW] = " << sizes;
841 SmallVector<int64_t> scale;
845 const int64_t scaleYN = scale[0];
846 const int64_t scaleYD = scale[1];
847 const int64_t scaleXN = scale[2];
848 const int64_t scaleXD = scale[3];
851 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
853 "expect all scale numerator values to be <= (1 << 11), "
855 << scaleYN <<
", scale_x_n=" << scaleXN;
857 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
858 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
859 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
861 SmallVector<int64_t> offset;
862 SmallVector<int64_t> border;
867 const int64_t offsetY = offset[0];
868 const int64_t offsetX = offset[1];
871 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
873 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
874 << offsetY <<
"/" << scaleYN;
875 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
877 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
878 << offsetX <<
"/" << scaleXN;
880 const int64_t borderY = border[0];
881 const int64_t borderX = border[1];
882 if (borderY < -16 * scaleYN || borderY >= scaleYN)
884 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
885 << borderY <<
"/" << scaleYN;
886 if (borderX < -16 * scaleXN || borderX >= scaleXN)
888 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
889 << borderX <<
"/" << scaleXN;
902 const int64_t
rhs) -> std::optional<int64_t> {
908 const int64_t oh = outputType.getDimSize(1);
909 const int64_t ow = outputType.getDimSize(2);
910 const int64_t ih = inputType.getDimSize(1);
911 const int64_t iw = inputType.getDimSize(2);
913 if (ih != ShapedType::kDynamic) {
914 const std::optional<int64_t> calculatedOutHeightMinusOne =
915 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
916 if (!calculatedOutHeightMinusOne.has_value())
918 "expected (input_height - 1) * scale_y_n - offset_y + "
920 <<
"to be wholly divisible by scale_y_d, got ((" << ih
921 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
922 <<
") / " << scaleYD;
923 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
924 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
926 "calculated output height did not match expected: ")
927 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
930 if (iw != ShapedType::kDynamic) {
931 const std::optional<int64_t> calculatedOutWidthMinusOne =
932 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
933 if (!calculatedOutWidthMinusOne.has_value())
935 "expected (input_width - 1) * scale_x_n - offset_x + "
937 <<
"to be wholly divisible by scale_x_d, got ((" << iw
938 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
939 <<
") / " << scaleXD;
940 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
941 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
942 return op->
emitOpError(
"calculated output width did not match expected: ")
943 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
949LogicalResult checkErrorIfMul(Operation *op) {
950 auto mul = dyn_cast<tosa::MulOp>(op);
956 ElementsAttr shift_elem;
959 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
961 if (inputElemType.isInteger(32)) {
963 if (shift < 0 || shift > 63)
965 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
970 <<
"requires shift = 0 for all input data types that "
971 "are not int32_t, but got: "
978LogicalResult checkErrorIfTable(Operation *op) {
979 auto table = dyn_cast<tosa::TableOp>(op);
985 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
987 const ShapeAdaptor tableShape(table.getTable().getType());
988 if (tableShape.hasStaticShape()) {
989 const auto numElements = tableShape.getNumElements();
990 if (numElements != tableSize)
991 return op->
emitOpError() <<
"requires table size of " << tableSize
992 <<
", got " << numElements;
998LogicalResult checkErrorIfRescale(Operation *op) {
999 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1003 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1004 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1005 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1006 !outputType.getElementType().isInteger())
1009 auto inElemType = inputType.getElementType();
1010 auto outElemType = outputType.getElementType();
1011 auto inWidth = inElemType.getIntOrFloatBitWidth();
1012 auto outWidth = outElemType.getIntOrFloatBitWidth();
1014 bool inputUnsigned = rescale.getInputUnsigned();
1015 bool outputUnsigned = rescale.getOutputUnsigned();
1017 bool scale32 = rescale.getScale32();
1018 auto roundingMode = rescale.getRoundingMode();
1021 if (scale32 && inWidth == 48)
1022 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1025 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1027 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1030 if (inputUnsigned && outputUnsigned)
1031 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1034 if (outWidth == 32 && inputUnsigned)
1036 <<
"i32 output type is not allowed with unsigned input.";
1039 if (inWidth == 32 && outputUnsigned)
1041 <<
"i32 input type is not allowed with unsigned output.";
1044 if (inWidth == 48 && outputUnsigned)
1046 <<
"i48 input type is not allowed with unsigned output.";
1049 if (inWidth == 48 && inputUnsigned)
1050 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1053 if (inWidth == 32 && inputUnsigned)
1054 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1057 if (outWidth == 32 && outputUnsigned)
1058 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1063LogicalResult checkErrorIfPad(Operation *op) {
1064 auto pad = dyn_cast<tosa::PadOp>(op);
1068 DenseIntElementsAttr paddingAttr;
1073 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1074 if (val.getSExtValue() < 0)
1075 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1076 << val.getSExtValue();
1082static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1083 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1084 Region *operandRegion = operand.getParentRegion();
1085 return operandRegion && region->isAncestor(operandRegion);
1089static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) {
1090 bool noLiveInValue =
true;
1091 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](Operation *op) {
1092 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1093 noLiveInValue =
false;
1098 return noLiveInValue ?
success() : failure();
1101LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
1102 StringRef regionName) {
1103 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1106 <<
"is not conformant to the TOSA specification. It requires the '"
1107 << regionName <<
"' region is isolated from above.\n";
1110LogicalResult checkErrorIfCondIf(Operation *op) {
1111 auto ifOp = dyn_cast<tosa::IfOp>(op);
1144 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1145 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1150LogicalResult checkErrorIfWhileLoop(Operation *op) {
1151 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1155 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1156 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1161LogicalResult checkErrorIfScatter(Operation *op) {
1162 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1167 DenseIntElementsAttr indicesAttr;
1171 auto const indicesType =
1172 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1173 if (!indicesType || !indicesType.hasRank()) {
1179 op->
emitOpError(
"indices values contain duplicates");
1186LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1187 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1188 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1189 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfCondIf(op)) ||
1190 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1195bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1196 if (isa<FloatType>(type)) {
1197 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1198 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1199 Float6E3M2FNType, Float8E8M0FNUType>(type);
1200 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1201 if (intTy.isSignless()) {
1202 switch (intTy.getWidth()) {
1212 }
else if (allowUnsigned && intTy.isUnsigned()) {
1213 switch (intTy.getWidth()) {
1220 }
else if (isa<tosa::shapeType>(type))
1222 else if (isa<tosa::mxint8Type>(type))
1227void TosaValidation::runOnOperation() {
1228 ModuleOp modOp = getOperation();
1230 const auto maybeTargetEnv =
1232 if (
failed(maybeTargetEnv))
1233 return signalPassFailure();
1234 targetEnv = *maybeTargetEnv;
1236 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1240 modOp.walk([&](Operation *op) {
1249 const bool allowUnsigned =
1250 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1253 if (!isValidElementType(elementTy, allowUnsigned)) {
1254 op->
emitOpError() <<
"is not profile-aligned: element type "
1255 << elementTy <<
" is not legal";
1256 return signalPassFailure();
1261 if (!isValidElementType(elementTy, allowUnsigned)) {
1262 op->
emitOpError() <<
"is not profile-aligned: element type "
1263 << elementTy <<
" is not legal";
1264 return signalPassFailure();
1268 if (strictOpSpecAlignment &&
1270 return signalPassFailure();
1272 if (strictOpSpecAlignment &&
1274 return signalPassFailure();
1276 if (!allowInvalidOpDatatypeCombinations &&
1278 return signalPassFailure();
1282 if (
failed(applyConstantOperandCheck(op)))
1283 signalPassFailure();
1286 if (
failed(applyLevelCheck(op)))
1287 signalPassFailure();
1290 if (
failed(applyAttributeCheck(op)))
1291 signalPassFailure();
1294 if (
failed(applyVariableCheck(op)))
1295 signalPassFailure();
1298 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1299 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.