29 #include "llvm/ADT/StringExtras.h"
33 #define GEN_PASS_DEF_TOSAVALIDATION
34 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45 for (
const auto index : operandIndices) {
48 return op->
emitOpError(
"expected compile time resolvable constant, but "
49 "got variable value for operand #")
56 static LogicalResult checkConstantOperandMul(
Operation *op,
58 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60 return checkConstantOperands(op, {2});
65 static LogicalResult checkConstantOperandTable(
Operation *op,
67 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69 return checkConstantOperands(op, {1});
74 static LogicalResult checkConstantOperandPad(
Operation *op,
76 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
78 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
81 return checkConstantOperands(op, {2});
86 static LogicalResult checkConstantOperandRescale(
Operation *op,
88 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90 return checkConstantOperands(op, {1, 2, 3, 4});
96 static LogicalResult checkConstantOperandConvOps(
Operation *op,
98 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
100 return checkConstantOperands(op, {3, 4});
105 static LogicalResult checkConstantOperandMatMul(
Operation *op,
107 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109 return checkConstantOperands(op, {2, 3});
114 static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
116 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118 return checkConstantOperands(op, {1, 2});
123 static LogicalResult checkConstantOperandNegate(
Operation *op,
125 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127 return checkConstantOperands(op, {1, 2});
136 struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
138 explicit TosaValidation() { populateConstantOperandChecks(); }
140 explicit TosaValidation(
const TosaValidationOptions &
options)
142 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
143 this->allowInvalidOpDatatypeCombinations =
144 options.allowInvalidOpDatatypeCombinations;
146 void runOnOperation() final;
148 LogicalResult applyConstantOperandCheck(
Operation *op) {
149 for (
auto &checker : constCheckers) {
150 if (
failed(checker(op, targetEnv)))
156 LogicalResult applyLevelCheck(
Operation *op);
157 LogicalResult applyAttributeCheck(
Operation *op);
160 LogicalResult applyVariableCheck(
Operation *op);
163 LogicalResult applyErrorIfCheck(
Operation *op);
166 void populateConstantOperandChecks() {
167 constCheckers.emplace_back(checkConstantOperandMul);
168 constCheckers.emplace_back(checkConstantOperandTable);
169 constCheckers.emplace_back(checkConstantOperandPad);
170 constCheckers.emplace_back(checkConstantOperandRescale);
171 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
172 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
173 constCheckers.emplace_back(
174 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
175 constCheckers.emplace_back(
176 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
177 constCheckers.emplace_back(checkConstantOperandMatMul);
178 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
179 constCheckers.emplace_back(checkConstantOperandNegate);
182 LogicalResult levelCheckKernel(
Operation *op, int32_t v,
183 const StringRef checkDesc) {
184 if (v > targetEnv.getLevel().MAX_KERNEL)
185 return op->
emitOpError() <<
"failed level check: " << checkDesc;
189 LogicalResult levelCheckStride(
Operation *op, int32_t v,
190 const StringRef checkDesc) {
191 if (v > targetEnv.getLevel().MAX_STRIDE)
192 return op->
emitOpError() <<
"failed level check: " << checkDesc;
196 LogicalResult levelCheckScale(
Operation *op, int32_t v,
197 const StringRef checkDesc) {
198 if (v > targetEnv.getLevel().MAX_SCALE)
199 return op->
emitOpError() <<
"failed level check: " << checkDesc;
203 LogicalResult levelCheckListSize(
Operation *op, int32_t v,
204 const StringRef checkDesc) {
205 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
207 <<
"failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
212 LogicalResult levelCheckRank(
Operation *op,
const Type typeToCheck,
213 const StringRef operandOrResult,
214 int32_t highest_rank) {
215 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
217 return op->
emitOpError() <<
"failed level check: unranked tensor";
218 if (type.getRank() > highest_rank)
219 return op->
emitOpError() <<
"failed level check: " << operandOrResult
220 <<
" rank(shape) <= MAX_RANK";
227 const StringRef operandOrResult,
228 int32_t highest_rank) {
229 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
233 LogicalResult levelCheckSize(
Operation *op,
const Type &typeToCheck,
234 const StringRef operandOrResult);
238 const StringRef operandOrResult) {
239 return levelCheckSize(op, v.
getType(), operandOrResult);
243 template <
typename T>
244 LogicalResult levelCheckSizes(T tosaOp) {
245 auto op = tosaOp.getOperation();
247 if (
failed(levelCheckSize(op, v,
"operand")))
252 if (
failed(levelCheckSize(op, v,
"result")))
259 template <
typename T>
260 LogicalResult levelCheckRanks(T tosaOp) {
261 auto op = tosaOp.getOperation();
262 const TosaLevel tosaLevel = targetEnv.getLevel();
276 LogicalResult levelCheckRanksAndSizes(
Operation *op);
279 template <
typename T>
280 LogicalResult levelCheckPool(
Operation *op) {
281 if (
auto poolOp = dyn_cast<T>(op)) {
282 for (
auto k : poolOp.getKernel()) {
283 if (
failed(levelCheckKernel(op, k,
"kernel <= MAX_KERNEL"))) {
287 for (
auto s : poolOp.getStride()) {
288 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
292 for (
auto p : poolOp.getPad()) {
293 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
302 template <
typename T>
303 LogicalResult levelCheckConv(
Operation *op) {
304 if (
auto convOp = dyn_cast<T>(op)) {
306 for (
auto k : convOp.getDilation()) {
307 if (
failed(levelCheckKernel(op, k,
"dilation <= MAX_KERNEL"))) {
311 for (
auto p : convOp.getPad()) {
312 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
316 for (
auto s : convOp.getStride()) {
317 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
321 auto dilation = convOp.getDilation();
322 if (ShapedType weightType =
324 auto shape = weightType.getShape();
325 if (isa<tosa::Conv2DOp>(op)) {
326 assert(shape.size() == 4);
327 assert(dilation.size() == 2);
328 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
329 "dilation_y * KH <= MAX_KERNEL)")) ||
330 failed(levelCheckKernel(op, dilation[1] * shape[2],
331 "dilation_x * KW <= MAX_KERNEL)")))
333 }
else if (isa<tosa::Conv3DOp>(op)) {
334 assert(shape.size() == 5);
335 assert(dilation.size() == 3);
336 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
337 "dilation_d * KD <= MAX_KERNEL)")) ||
338 failed(levelCheckKernel(op, dilation[1] * shape[2],
339 "dilation_y * KH <= MAX_KERNEL)")) ||
340 failed(levelCheckKernel(op, dilation[2] * shape[3],
341 "dilation_x * KW <= MAX_KERNEL)")))
343 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
344 assert(shape.size() == 4);
345 assert(dilation.size() == 2);
346 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
347 "dilation_y * KH <= MAX_KERNEL)")) ||
348 failed(levelCheckKernel(op, dilation[1] * shape[1],
349 "dilation_x * KW <= MAX_KERNEL)")))
358 template <
typename T>
359 LogicalResult levelCheckFFT(
Operation *op) {
362 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
363 auto shape = type.getShape();
364 assert(shape.size() == 3);
365 if (
failed(levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL")) ||
366 failed(levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL"))) {
376 LogicalResult levelCheckTransposeConv2d(
Operation *op) {
377 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
378 if (ShapedType filterType =
379 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
380 auto shape = filterType.getShape();
381 assert(shape.size() == 4);
383 if (
failed(levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL")) ||
384 failed(levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL"))) {
388 for (
auto p : transpose.getOutPad()) {
389 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
393 for (
auto s : transpose.getStride()) {
394 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
403 LogicalResult levelCheckResize(
Operation *op) {
404 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
410 const int64_t scaleYN = scale[0];
411 const int64_t scaleYD = scale[1];
412 const int64_t scaleXN = scale[2];
413 const int64_t scaleXD = scale[3];
414 if (
failed(levelCheckScale(op, scaleYN / scaleYD,
415 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
416 failed(levelCheckScale(op, scaleXN / scaleXD,
417 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
428 static void getMaxNestedDepth(
Operation *op, int32_t &depth) {
429 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
437 getMaxNestedDepth(op, depth);
440 LogicalResult levelCheckMaxNesting(
Operation *op) {
441 int32_t maxNestedDepth = 0;
442 getMaxNestedDepth(op, maxNestedDepth);
444 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
445 op->
emitOpError() <<
"failed level check: " << maxNestedDepth
446 <<
" >= MAX_NESTING";
452 LogicalResult levelCheckListSize(
Operation *op) {
453 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
456 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
457 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
459 failed(levelCheckListSize(op, custom.getOutputList().size(),
464 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
466 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
467 failed(levelCheckListSize(op, condIf.getOutputList().size(),
472 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
473 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
474 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
481 LogicalResult attributeCheckRescale(
Operation *op) {
482 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
483 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
484 !targetEnv.allows(Extension::doubleround)) {
486 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
487 <<
"requires extension [doubleround]";
490 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
491 !targetEnv.allows(Extension::inexactround)) {
493 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
494 <<
"requires extension [inexactround]";
501 LogicalResult CheckVariable(
Operation *op);
502 LogicalResult CheckVariableReadOrWrite(
Operation *op);
503 bool isValidElementType(
Type type,
const bool allowUnsigned =
false);
514 LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
515 auto *op = tosaOp.getOperation();
516 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
517 targetEnv.getLevel().MAX_RANK)))
521 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
522 targetEnv.getLevel().MAX_RANK - 1)))
529 LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
530 auto *op = tosaOp.getOperation();
533 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
534 targetEnv.getLevel().MAX_RANK)))
541 LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
542 auto *op = tosaOp.getOperation();
544 if (
failed(levelCheckRank(op, variableType,
"variable type",
545 targetEnv.getLevel().MAX_RANK)))
552 LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
553 auto *op = tosaOp.getOperation();
555 if (
failed(levelCheckSize(op, variableType,
"variable type")))
561 LogicalResult TosaValidation::levelCheckRanksAndSizes(
Operation *op) {
562 #define CHECK_RANKS_AND_SIZES(tosaOp) \
563 if (isa<tosa::tosaOp##Op>(op)) { \
564 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
566 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
570 #define CHECK_SIZES(tosaOp) \
571 if (isa<tosa::tosaOp##Op>(op)) { \
572 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
674 #undef CHECK_RANKS_AND_SIZES
680 LogicalResult TosaValidation::levelCheckSize(
Operation *op,
681 const Type &typeToCheck,
682 const StringRef operandOrResult) {
683 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
685 return op->
emitOpError() <<
"failed level check: unranked tensor";
686 auto shape = type.getShape();
687 for (
auto dim : shape) {
688 if (mlir::ShapedType::isDynamic(dim))
689 return op->
emitOpError() <<
"failed level check: " << operandOrResult
690 <<
" shape dimension cannot be dynamic";
693 int64_t element_bits = type.getElementTypeBitWidth();
694 int64_t element_bytes =
std::max(INT64_C(1), element_bits / 8);
695 int64_t size = element_bytes * type.getNumElements();
702 const int64_t max_size =
703 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
706 <<
"failed level check: " << operandOrResult
707 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
712 LogicalResult TosaValidation::applyLevelCheck(
Operation *op) {
719 if (
failed(levelCheckRanksAndSizes(op)))
723 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
724 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
725 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
726 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
727 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
728 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
729 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
730 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op))) {
735 if (
failed(levelCheckListSize(op))) {
739 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
740 if (
failed(levelCheckMaxNesting(op))) {
748 LogicalResult TosaValidation::applyAttributeCheck(
Operation *op) {
749 if (
failed(attributeCheckRescale(op)))
754 inline bool CompatibleTypes(
const mlir::Type &type,
757 return type == declaredType;
760 LogicalResult TosaValidation::CheckVariable(
Operation *op) {
761 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
762 mlir::StringAttr nameAttr = variableOp.getNameAttr();
764 if (variablesMap.count(nameAttr))
765 return op->
emitOpError() <<
"name has already been declared";
767 auto elementType = variableOp.getType();
770 RankedTensorType variableType =
773 variablesMap[nameAttr] = variableType;
779 LogicalResult TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
780 if (isa<mlir::tosa::VariableReadOp>(op) ||
781 isa<mlir::tosa::VariableWriteOp>(op)) {
782 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
783 if (!variablesMap.count(nameAttr))
784 return op->
emitOpError() <<
"name has not been declared";
786 auto varType = variablesMap[nameAttr];
789 auto type = v.getType();
790 if (!CompatibleTypes(type, varType))
791 return op->
emitOpError() <<
"operand type does not equal variable type";
795 auto type = v.getType();
796 if (!CompatibleTypes(type, varType))
797 return op->
emitOpError() <<
"result type does not equal variable type";
804 LogicalResult TosaValidation::applyVariableCheck(
Operation *op) {
805 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
810 LogicalResult checkErrorIfResize(
Operation *op) {
811 auto resize = dyn_cast<tosa::ResizeOp>(op);
815 const Value input = resize.getInput();
816 const Value output = resize.getOutput();
817 const RankedTensorType inputType =
818 llvm::dyn_cast<RankedTensorType>(input.
getType());
819 const RankedTensorType outputType =
820 llvm::dyn_cast<RankedTensorType>(output.getType());
822 if (!inputType || !outputType)
823 return op->
emitOpError(
"expect ranked input/output tensor");
827 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
829 outputType.getDimSize(1), outputType.getDimSize(2),
830 inputType.getDimSize(1), inputType.getDimSize(2)};
831 const int64_t *maxDim = llvm::max_element(sizes);
832 if (maxDim != sizes.end() && *maxDim >= 16384)
834 "expect input/output height/width dims to be < 16384, ")
835 <<
"got [OH, OW, IH, IW] = " << sizes;
842 const int64_t scaleYN = scale[0];
843 const int64_t scaleYD = scale[1];
844 const int64_t scaleXN = scale[2];
845 const int64_t scaleXD = scale[3];
848 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
850 "expect all scale numerator values to be <= (1 << 11), "
852 << scaleYN <<
", scale_x_n=" << scaleXN;
854 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
855 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
856 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
864 const int64_t offsetY = offset[0];
865 const int64_t offsetX = offset[1];
868 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
870 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
871 << offsetY <<
"/" << scaleYN;
872 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
874 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
875 << offsetX <<
"/" << scaleXN;
877 const int64_t borderY = border[0];
878 const int64_t borderX = border[1];
879 if (borderY < -16 * scaleYN || borderY >= scaleYN)
881 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
882 << borderY <<
"/" << scaleYN;
883 if (borderX < -16 * scaleXN || borderX >= scaleXN)
885 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
886 << borderX <<
"/" << scaleXN;
899 const int64_t rhs) -> std::optional<int64_t> {
905 const int64_t oh = outputType.getDimSize(1);
906 const int64_t ow = outputType.getDimSize(2);
907 const int64_t ih = inputType.getDimSize(1);
908 const int64_t iw = inputType.getDimSize(2);
910 if (ih != ShapedType::kDynamic) {
911 const std::optional<int64_t> calculatedOutHeightMinusOne =
912 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
913 if (!calculatedOutHeightMinusOne.has_value())
915 "expected (input_height - 1) * scale_y_n - offset_y + "
917 <<
"to be wholly divisible by scale_y_d, got ((" << ih
918 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
919 <<
") / " << scaleYD;
920 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
921 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
923 "calculated output height did not match expected: ")
924 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
927 if (iw != ShapedType::kDynamic) {
928 const std::optional<int64_t> calculatedOutWidthMinusOne =
929 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
930 if (!calculatedOutWidthMinusOne.has_value())
932 "expected (input_width - 1) * scale_x_n - offset_x + "
934 <<
"to be wholly divisible by scale_x_d, got ((" << iw
935 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
936 <<
") / " << scaleXD;
937 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
938 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
939 return op->
emitOpError(
"calculated output width did not match expected: ")
940 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
946 LogicalResult checkErrorIfMul(
Operation *op) {
947 auto mul = dyn_cast<tosa::MulOp>(op);
953 ElementsAttr shift_elem;
956 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
958 if (inputElemType.isInteger(32)) {
960 if (shift < 0 || shift > 63)
962 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
967 <<
"requires shift = 0 for all input data types that "
968 "are not int32_t, but got: "
975 LogicalResult checkErrorIfTable(
Operation *op) {
976 auto table = dyn_cast<tosa::TableOp>(op);
982 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
985 if (tableShape.hasStaticShape()) {
986 const auto numElements = tableShape.getNumElements();
987 if (numElements != tableSize)
988 return op->
emitOpError() <<
"requires table size of " << tableSize
989 <<
", got " << numElements;
995 LogicalResult checkErrorIfRescale(
Operation *op) {
996 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1000 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1001 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1002 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1003 !outputType.getElementType().isInteger())
1006 auto inElemType = inputType.getElementType();
1007 auto outElemType = outputType.getElementType();
1008 auto inWidth = inElemType.getIntOrFloatBitWidth();
1009 auto outWidth = outElemType.getIntOrFloatBitWidth();
1011 bool inputUnsigned = rescale.getInputUnsigned();
1012 bool outputUnsigned = rescale.getOutputUnsigned();
1014 bool scale32 = rescale.getScale32();
1015 auto roundingMode = rescale.getRoundingMode();
1018 if (scale32 && inWidth == 48)
1019 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1022 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1024 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1027 if (inputUnsigned && outputUnsigned)
1028 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1031 if (outWidth == 32 && inputUnsigned)
1033 <<
"i32 output type is not allowed with unsigned input.";
1036 if (inWidth == 32 && outputUnsigned)
1038 <<
"i32 input type is not allowed with unsigned output.";
1041 if (inWidth == 48 && outputUnsigned)
1043 <<
"i48 input type is not allowed with unsigned output.";
1046 if (inWidth == 48 && inputUnsigned)
1047 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1050 if (inWidth == 32 && inputUnsigned)
1051 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1054 if (outWidth == 32 && outputUnsigned)
1055 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1060 LogicalResult checkErrorIfPad(
Operation *op) {
1061 auto pad = dyn_cast<tosa::PadOp>(op);
1070 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1071 if (val.getSExtValue() < 0)
1072 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1073 << val.getSExtValue();
1080 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1081 Region *operandRegion = operand.getParentRegion();
1082 return operandRegion && region->isAncestor(operandRegion);
1086 static LogicalResult isRegionIsolatedFromAbove(
Region ®ionToCheck) {
1087 bool noLiveInValue =
true;
1088 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](
Operation *op) {
1089 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1090 noLiveInValue =
false;
1095 return noLiveInValue ? success() : failure();
1098 LogicalResult checkIsolatedRegion(
Operation *op,
Region ®ionToCheck,
1099 StringRef regionName) {
1100 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1103 <<
"is not conformant to the TOSA specification. It requires the '"
1104 << regionName <<
"' region is isolated from above.\n";
1107 LogicalResult checkErrorIfCondIf(
Operation *op) {
1108 auto ifOp = dyn_cast<tosa::IfOp>(op);
1141 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1142 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1147 LogicalResult checkErrorIfWhileLoop(
Operation *op) {
1148 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1152 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1153 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1158 LogicalResult checkErrorIfScatter(
Operation *op) {
1159 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1168 auto const indicesType =
1169 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1170 if (!indicesType || !indicesType.hasRank()) {
1176 op->
emitOpError(
"indices values contain duplicates");
1183 LogicalResult TosaValidation::applyErrorIfCheck(
Operation *op) {
1184 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1185 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1186 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfCondIf(op)) ||
1187 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1192 bool TosaValidation::isValidElementType(
Type type,
const bool allowUnsigned) {
1193 if (isa<FloatType>(type)) {
1194 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1195 Float8E5M2Type>(type);
1197 if (
auto intTy = dyn_cast<IntegerType>(type)) {
1198 if (intTy.isSignless()) {
1199 switch (intTy.getWidth()) {
1208 }
else if (allowUnsigned && intTy.isUnsigned()) {
1209 switch (intTy.getWidth()) {
1216 }
else if (mlir::isa<tosa::shapeType>(type)) {
1222 void TosaValidation::runOnOperation() {
1229 getOperation().walk([&](
Operation *op) {
1238 const bool allowUnsigned =
1239 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1241 auto elementTy = getElementTypeOrSelf(operand);
1242 if (!isValidElementType(elementTy, allowUnsigned)) {
1243 op->emitOpError() <<
"is not profile-aligned: element type "
1244 << elementTy <<
" is not legal";
1245 return signalPassFailure();
1249 auto elementTy = getElementTypeOrSelf(resultTy);
1250 if (!isValidElementType(elementTy, allowUnsigned)) {
1251 op->emitOpError() <<
"is not profile-aligned: element type "
1252 << elementTy <<
" is not legal";
1253 return signalPassFailure();
1257 if (strictOpSpecAlignment &&
1258 failed(profileComp.checkProfile(op, targetEnv)))
1259 return signalPassFailure();
1261 if (strictOpSpecAlignment &&
1262 failed(profileComp.checkExtension(op, targetEnv)))
1263 return signalPassFailure();
1265 if (!allowInvalidOpDatatypeCombinations &&
1266 failed(profileComp.checkInvalid(op)))
1267 return signalPassFailure();
1271 if (
failed(applyConstantOperandCheck(op)))
1272 signalPassFailure();
1275 if (
failed(applyLevelCheck(op)))
1276 signalPassFailure();
1279 if (
failed(applyAttributeCheck(op)))
1280 signalPassFailure();
1283 if (
failed(applyVariableCheck(op)))
1284 signalPassFailure();
1287 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1288 signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense integer vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
bool allows(Profile prof) const
NestedPattern If(const NestedPattern &child)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.