29#include "llvm/ADT/StringExtras.h"
33#define GEN_PASS_DEF_TOSAVALIDATION
34#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45 for (
const auto index : operandIndices) {
48 return op->
emitOpError(
"expected compile time resolvable constant, but "
49 "got variable value for operand #")
56static LogicalResult checkConstantOperandMul(
Operation *op,
58 if (!env.
allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60 return checkConstantOperands(op, {2});
65static LogicalResult checkConstantOperandTable(
Operation *op,
67 if (!env.
allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69 return checkConstantOperands(op, {1});
74static LogicalResult checkConstantOperandPad(
Operation *op,
76 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
78 if (!env.
allows(Extension::dynamic) && padOp.getPadConst())
81 return checkConstantOperands(op, {2});
86static LogicalResult checkConstantOperandRescale(
Operation *op,
88 if (!env.
allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90 return checkConstantOperands(op, {1, 2, 3, 4});
96static LogicalResult checkConstantOperandConvOps(
Operation *op,
98 if (!env.
allows(Extension::dynamic) && isa<T>(op)) {
100 return checkConstantOperands(op, {3, 4});
105static LogicalResult checkConstantOperandMatMul(
Operation *op,
107 if (!env.
allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109 return checkConstantOperands(op, {2, 3});
114static LogicalResult checkConstantOperandAvgPool2d(
Operation *op,
116 if (!env.
allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118 return checkConstantOperands(op, {1, 2});
123static LogicalResult checkConstantOperandNegate(
Operation *op,
125 if (!env.
allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127 return checkConstantOperands(op, {1, 2});
136struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
138 explicit TosaValidation() { populateConstantOperandChecks(); }
140 explicit TosaValidation(
const TosaValidationOptions &
options)
142 this->strictOpSpecAlignment =
options.strictOpSpecAlignment;
143 this->allowInvalidOpDatatypeCombinations =
144 options.allowInvalidOpDatatypeCombinations;
// Pass entry point. Resolves the target environment for the module, then
// walks every nested op. For each op it runs the element-type legality,
// constant-operand, level, attribute and variable checks, plus further
// option-gated checks. ERROR_IF checks run only when strictOpSpecAlignment
// is set. Any failure is reported through signalPassFailure().
146 void runOnOperation() final;
148 LogicalResult applyConstantOperandCheck(Operation *op) {
149 for (
auto &checker : constCheckers) {
150 if (
failed(checker(op, targetEnv)))
// Level checks. Covers:
//  - per-op rank/size limits;
//  - kernel/stride/scale limits for pooling, convolution, FFT,
//    transpose-conv and resize ops;
//  - list-size limits;
//  - MAX_NESTING for tosa::IfOp / tosa::WhileOp.
156 LogicalResult applyLevelCheck(Operation *op);
// Attribute checks that depend on enabled extensions. The visible
// implementation checks the tosa::RescaleOp rounding_mode (DOUBLE_ROUND /
// INEXACT_ROUND).
157 LogicalResult applyAttributeCheck(Operation *op);
// Variable checks: rejects duplicate variable declarations. Also requires
// variable read/write operand/result types to equal the declared type.
160 LogicalResult applyVariableCheck(Operation *op);
// Spec ERROR_IF checks for resize, mul, table, rescale, pad, cond_if,
// while_loop and scatter ops.
163 LogicalResult applyErrorIfCheck(Operation *op);
166 void populateConstantOperandChecks() {
167 constCheckers.emplace_back(checkConstantOperandMul);
168 constCheckers.emplace_back(checkConstantOperandTable);
169 constCheckers.emplace_back(checkConstantOperandPad);
170 constCheckers.emplace_back(checkConstantOperandRescale);
171 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
172 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
173 constCheckers.emplace_back(
174 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
175 constCheckers.emplace_back(
176 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
177 constCheckers.emplace_back(checkConstantOperandMatMul);
178 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
179 constCheckers.emplace_back(checkConstantOperandNegate);
182 LogicalResult levelCheckKernel(Operation *op, int32_t v,
183 const StringRef checkDesc) {
184 if (v > targetEnv.getLevel().MAX_KERNEL)
185 return op->
emitOpError() <<
"failed level check: " << checkDesc;
189 LogicalResult levelCheckStride(Operation *op, int32_t v,
190 const StringRef checkDesc) {
191 if (v > targetEnv.getLevel().MAX_STRIDE)
192 return op->
emitOpError() <<
"failed level check: " << checkDesc;
196 LogicalResult levelCheckScale(Operation *op, int32_t v,
197 const StringRef checkDesc) {
198 if (v > targetEnv.getLevel().MAX_SCALE)
199 return op->
emitOpError() <<
"failed level check: " << checkDesc;
203 LogicalResult levelCheckListSize(Operation *op, int32_t v,
204 const StringRef checkDesc) {
205 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
207 <<
"failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
212 LogicalResult levelCheckRank(Operation *op,
const Type typeToCheck,
213 const StringRef operandOrResult,
214 int32_t highest_rank) {
215 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
217 return op->
emitOpError() <<
"failed level check: unranked tensor";
218 if (type.getRank() > highest_rank)
219 return op->
emitOpError() <<
"failed level check: " << operandOrResult
220 <<
" rank(shape) <= MAX_RANK";
221 }
else if (tosa::shapeType shapeType =
222 dyn_cast<tosa::shapeType>(typeToCheck)) {
223 if (shapeType.getRank() > highest_rank)
225 <<
"failed shape type level check: " << typeToCheck
226 <<
" exceeds MAX_RANK";
232 LogicalResult levelCheckRank(Operation *op,
const Value &v,
233 const StringRef operandOrResult,
234 int32_t highest_rank) {
235 return levelCheckRank(op, v.
getType(), operandOrResult, highest_rank);
// Level check on a shaped type. It:
//  - rejects unranked tensors;
//  - rejects dynamic dimensions when targeting TOSA 1.0 (exact version
//    condition is partly outside this excerpt);
//  - requires element_bytes * num_elements to stay within the MAX_LOG2_SIZE
//    limit.
// operandOrResult is used only to label the diagnostic.
239 LogicalResult levelCheckSize(Operation *op,
const Type &typeToCheck,
240 const StringRef operandOrResult);
243 LogicalResult levelCheckSize(Operation *op,
const Value &v,
244 const StringRef operandOrResult) {
245 return levelCheckSize(op, v.
getType(), operandOrResult);
249 template <
typename T>
250 LogicalResult levelCheckSizes(T tosaOp) {
251 auto op = tosaOp.getOperation();
253 if (
failed(levelCheckSize(op, v,
"operand")))
258 if (
failed(levelCheckSize(op, v,
"result")))
265 template <
typename T>
266 LogicalResult levelCheckRanks(T tosaOp) {
267 auto op = tosaOp.getOperation();
268 const TosaLevel tosaLevel = targetEnv.getLevel();
// Dispatches on the concrete TOSA op type and runs the matching per-op
// levelCheckRanks / levelCheckSizes overloads. The CHECK_RANKS_AND_SIZES,
// CHECK_SIZES and CHECK_RANKS macros expand to that dispatch.
282 LogicalResult levelCheckRanksAndSizes(Operation *op);
285 template <
typename T>
286 LogicalResult levelCheckPool(Operation *op) {
287 if (
auto poolOp = dyn_cast<T>(op)) {
288 for (
auto k : poolOp.getKernel()) {
289 if (
failed(levelCheckKernel(op, k,
"kernel <= MAX_KERNEL"))) {
293 for (
auto s : poolOp.getStride()) {
294 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
298 for (
auto p : poolOp.getPad()) {
299 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
308 template <
typename T>
309 LogicalResult levelCheckConv(Operation *op) {
310 if (
auto convOp = dyn_cast<T>(op)) {
312 for (
auto k : convOp.getDilation()) {
313 if (
failed(levelCheckKernel(op, k,
"dilation <= MAX_KERNEL"))) {
317 for (
auto p : convOp.getPad()) {
318 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
322 for (
auto s : convOp.getStride()) {
323 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
327 auto dilation = convOp.getDilation();
328 if (ShapedType weightType =
330 auto shape = weightType.getShape();
331 if (isa<tosa::Conv2DOp>(op)) {
332 assert(shape.size() == 4);
333 assert(dilation.size() == 2);
334 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
335 "dilation_y * KH <= MAX_KERNEL)")) ||
336 failed(levelCheckKernel(op, dilation[1] * shape[2],
337 "dilation_x * KW <= MAX_KERNEL)")))
339 }
else if (isa<tosa::Conv3DOp>(op)) {
340 assert(shape.size() == 5);
341 assert(dilation.size() == 3);
342 if (
failed(levelCheckKernel(op, dilation[0] * shape[1],
343 "dilation_d * KD <= MAX_KERNEL)")) ||
344 failed(levelCheckKernel(op, dilation[1] * shape[2],
345 "dilation_y * KH <= MAX_KERNEL)")) ||
346 failed(levelCheckKernel(op, dilation[2] * shape[3],
347 "dilation_x * KW <= MAX_KERNEL)")))
349 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
350 assert(shape.size() == 4);
351 assert(dilation.size() == 2);
352 if (
failed(levelCheckKernel(op, dilation[0] * shape[0],
353 "dilation_y * KH <= MAX_KERNEL)")) ||
354 failed(levelCheckKernel(op, dilation[1] * shape[1],
355 "dilation_x * KW <= MAX_KERNEL)")))
364 template <
typename T>
365 LogicalResult levelCheckFFT(Operation *op) {
368 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
369 auto shape = type.getShape();
370 assert(shape.size() == 3);
371 if (
failed(levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL")) ||
372 failed(levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL"))) {
382 LogicalResult levelCheckTransposeConv2d(Operation *op) {
383 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
384 if (ShapedType filterType =
385 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
386 auto shape = filterType.getShape();
387 assert(shape.size() == 4);
389 if (
failed(levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL")) ||
390 failed(levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL"))) {
394 for (
auto p : transpose.getOutPad()) {
395 if (
failed(levelCheckKernel(op, p,
"pad <= MAX_KERNEL"))) {
399 for (
auto s : transpose.getStride()) {
400 if (
failed(levelCheckStride(op, s,
"stride <= MAX_STRIDE"))) {
409 LogicalResult levelCheckResize(Operation *op) {
410 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
411 SmallVector<int64_t> scale;
416 const int64_t scaleYN = scale[0];
417 const int64_t scaleYD = scale[1];
418 const int64_t scaleXN = scale[2];
419 const int64_t scaleXD = scale[3];
420 if (
failed(levelCheckScale(op, scaleYN / scaleYD,
421 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
422 failed(levelCheckScale(op, scaleXN / scaleXD,
423 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
434 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
435 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
443 getMaxNestedDepth(op, depth);
446 LogicalResult levelCheckMaxNesting(Operation *op) {
447 int32_t maxNestedDepth = 0;
448 getMaxNestedDepth(op, maxNestedDepth);
450 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
451 op->
emitOpError() <<
"failed level check: " << maxNestedDepth
452 <<
" >= MAX_NESTING";
458 LogicalResult levelCheckListSize(Operation *op) {
459 if (
auto concat = dyn_cast<tosa::ConcatOp>(op)) {
460 return levelCheckListSize(op,
concat.getInput1().size(),
"input1");
462 if (
auto custom = dyn_cast<tosa::CustomOp>(op)) {
463 if (
failed(levelCheckListSize(op, custom.getInputList().size(),
465 failed(levelCheckListSize(op, custom.getOutputList().size(),
470 if (
auto condIf = dyn_cast<tosa::IfOp>(op)) {
472 levelCheckListSize(op, condIf.getInputList().size(),
"inputs")) ||
473 failed(levelCheckListSize(op, condIf.getOutputList().size(),
478 if (
auto w = dyn_cast<tosa::WhileOp>(op)) {
479 if (
failed(levelCheckListSize(op, w.getInputList().size(),
"inputs")) ||
480 failed(levelCheckListSize(op, w.getOutputList().size(),
"outputs"))) {
487 LogicalResult attributeCheckRescale(Operation *op) {
488 if (
auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
489 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
490 !targetEnv.allows(Extension::doubleround)) {
492 <<
"failed attribute check: rounding_mode = DOUBLE_ROUND "
493 <<
"requires extension [doubleround]";
496 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
497 !targetEnv.allows(Extension::inexactround)) {
499 <<
"failed attribute check: rounding_mode = INEXACT_ROUND "
500 <<
"requires extension [inexactround]";
// Records a tosa::VariableOp declaration (name -> ranked tensor type).
// Fails if the name was already declared.
507 LogicalResult CheckVariable(Operation *op);
// For variable read/write ops: fails if the referenced name was never
// declared. Also fails if an operand/result type is not exactly equal to the
// declared variable type.
508 LogicalResult CheckVariableReadOrWrite(Operation *op);
// Returns whether `type` is an element type accepted by the validator. That
// means selected float types, signless integers of supported widths, and
// the TOSA shape and mxint8 types. Unsigned integers are considered only
// when allowUnsigned is true.
509 bool isValidElementType(Type type,
const bool allowUnsigned =
false);
512 std::function<LogicalResult(Operation *,
const tosa::TargetEnv &)>>
// Presumably the profile/extension compliance checker. The name suggests
// this, but its uses are outside this excerpt — NOTE(review): confirm.
515 TosaProfileCompliance profileComp;
// Target environment (level limits, enabled extensions, spec version).
// It is assigned at the start of runOnOperation().
516 tosa::TargetEnv targetEnv;
520LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
521 auto *op = tosaOp.getOperation();
522 if (
failed(levelCheckRank(op, tosaOp.getInput(),
"operand",
527 if (
failed(levelCheckRank(op, tosaOp.getOutput(),
"result",
535LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
536 auto *op = tosaOp.getOperation();
539 if (
failed(levelCheckRank(op, tosaOp.getCondition(),
"operand",
547LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
548 auto *op = tosaOp.getOperation();
550 if (
failed(levelCheckRank(op, variableType,
"variable type",
558LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
559 auto *op = tosaOp.getOperation();
561 if (
failed(levelCheckSize(op, variableType,
"variable type")))
567LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
568#define CHECK_RANKS_AND_SIZES(tosaOp) \
569 if (isa<tosa::tosaOp##Op>(op)) { \
570 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
572 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
576#define CHECK_SIZES(tosaOp) \
577 if (isa<tosa::tosaOp##Op>(op)) { \
578 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
582#define CHECK_RANKS(tosaOp) \
583 if (isa<tosa::tosaOp##Op>(op)) { \
584 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
699#undef CHECK_RANKS_AND_SIZES
706LogicalResult TosaValidation::levelCheckSize(Operation *op,
707 const Type &typeToCheck,
708 const StringRef operandOrResult) {
709 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
711 return op->
emitOpError() <<
"failed level check: unranked tensor";
712 auto shape = type.getShape();
713 for (
auto dim : shape) {
714 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
715 const TosaSpecificationVersion targetVersion = targetEnv.
getSpecVersion();
716 const TosaSpecificationVersion minRequiredVersion(1, 1);
726 return op->
emitOpError() <<
"failed level check: " << operandOrResult
727 <<
" shape dimension cannot be dynamic when"
728 <<
" targeting TOSA specification version 1.0"
733 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
734 int64_t size = element_bytes * type.getNumElements();
741 const int64_t max_size =
745 <<
"failed level check: " << operandOrResult
746 <<
" tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
751LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
758 if (
failed(levelCheckRanksAndSizes(op)))
762 if (
failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
763 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
764 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
765 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
766 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
767 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
768 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
769 failed(levelCheckTransposeConv2d(op)) ||
failed(levelCheckResize(op))) {
774 if (
failed(levelCheckListSize(op))) {
778 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
779 if (
failed(levelCheckMaxNesting(op))) {
787LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
788 if (
failed(attributeCheckRescale(op)))
793inline bool CompatibleTypes(
const mlir::Type &type,
794 const mlir::Type &declaredType) {
796 return type == declaredType;
799LogicalResult TosaValidation::CheckVariable(Operation *op) {
800 if (
auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
801 mlir::StringAttr nameAttr = variableOp.getNameAttr();
803 if (variablesMap.count(nameAttr))
804 return op->
emitOpError() <<
"name has already been declared";
806 auto elementType = variableOp.getType();
807 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
808 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
809 RankedTensorType variableType =
810 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
812 variablesMap[nameAttr] = variableType;
818LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
819 if (isa<mlir::tosa::VariableReadOp>(op) ||
820 isa<mlir::tosa::VariableWriteOp>(op)) {
821 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
822 if (!variablesMap.count(nameAttr))
823 return op->
emitOpError() <<
"name has not been declared";
825 auto varType = variablesMap[nameAttr];
828 auto type = v.getType();
829 if (!CompatibleTypes(type, varType))
830 return op->
emitOpError() <<
"operand type does not equal variable type";
834 auto type = v.getType();
835 if (!CompatibleTypes(type, varType))
836 return op->
emitOpError() <<
"result type does not equal variable type";
843LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
844 if (
failed(CheckVariable(op)) ||
failed(CheckVariableReadOrWrite(op)))
849LogicalResult checkErrorIfResize(Operation *op) {
850 auto resize = dyn_cast<tosa::ResizeOp>(op);
854 const Value input = resize.getInput();
855 const Value output = resize.getOutput();
856 const RankedTensorType inputType =
857 llvm::dyn_cast<RankedTensorType>(input.
getType());
858 const RankedTensorType outputType =
859 llvm::dyn_cast<RankedTensorType>(output.
getType());
861 if (!inputType || !outputType)
862 return op->
emitOpError(
"expect ranked input/output tensor");
866 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
867 const SmallVector<int64_t, 4> sizes = {
868 outputType.getDimSize(1), outputType.getDimSize(2),
869 inputType.getDimSize(1), inputType.getDimSize(2)};
870 const int64_t *maxDim = llvm::max_element(sizes);
871 if (maxDim != sizes.end() && *maxDim >= 16384)
873 "expect input/output height/width dims to be < 16384, ")
874 <<
"got [OH, OW, IH, IW] = " << sizes;
877 SmallVector<int64_t> scale;
881 const int64_t scaleYN = scale[0];
882 const int64_t scaleYD = scale[1];
883 const int64_t scaleXN = scale[2];
884 const int64_t scaleXD = scale[3];
887 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
889 "expect all scale numerator values to be <= (1 << 11), "
891 << scaleYN <<
", scale_x_n=" << scaleXN;
893 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
894 return op->
emitOpError(
"expect a downscale ratio larger than 1/16, got y=")
895 << scaleYN <<
"/" << scaleYD <<
", x=" << scaleXN <<
"/" << scaleXD;
897 SmallVector<int64_t> offset;
898 SmallVector<int64_t> border;
903 const int64_t offsetY = offset[0];
904 const int64_t offsetX = offset[1];
907 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
909 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
910 << offsetY <<
"/" << scaleYN;
911 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
913 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
914 << offsetX <<
"/" << scaleXN;
916 const int64_t borderY = border[0];
917 const int64_t borderX = border[1];
918 if (borderY < -16 * scaleYN || borderY >= scaleYN)
920 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
921 << borderY <<
"/" << scaleYN;
922 if (borderX < -16 * scaleXN || borderX >= scaleXN)
924 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
925 << borderX <<
"/" << scaleXN;
938 const int64_t
rhs) -> std::optional<int64_t> {
944 const int64_t oh = outputType.getDimSize(1);
945 const int64_t ow = outputType.getDimSize(2);
946 const int64_t ih = inputType.getDimSize(1);
947 const int64_t iw = inputType.getDimSize(2);
949 if (ih != ShapedType::kDynamic) {
950 const std::optional<int64_t> calculatedOutHeightMinusOne =
951 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
952 if (!calculatedOutHeightMinusOne.has_value())
954 "expected (input_height - 1) * scale_y_n - offset_y + "
956 <<
"to be wholly divisible by scale_y_d, got ((" << ih
957 <<
" - 1) * " << scaleYN <<
" - " << offsetY <<
" + " << borderY
958 <<
") / " << scaleYD;
959 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
960 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
962 "calculated output height did not match expected: ")
963 <<
"calculated=" << calculatedOutHeight <<
", expected=" << oh;
966 if (iw != ShapedType::kDynamic) {
967 const std::optional<int64_t> calculatedOutWidthMinusOne =
968 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
969 if (!calculatedOutWidthMinusOne.has_value())
971 "expected (input_width - 1) * scale_x_n - offset_x + "
973 <<
"to be wholly divisible by scale_x_d, got ((" << iw
974 <<
" - 1) * " << scaleXN <<
" - " << offsetX <<
" + " << borderX
975 <<
") / " << scaleXD;
976 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
977 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
978 return op->
emitOpError(
"calculated output width did not match expected: ")
979 <<
"calculated=" << calculatedOutWidth <<
", expected=" << ow;
985LogicalResult checkErrorIfMul(Operation *op) {
986 auto mul = dyn_cast<tosa::MulOp>(op);
992 ElementsAttr shift_elem;
995 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
997 if (inputElemType.isInteger(32)) {
999 if (shift < 0 || shift > 63)
1001 <<
"requires 0 <= shift && shift <= 63, but got: " << shift;
1006 <<
"requires shift = 0 for all input data types that "
1007 "are not int32_t, but got: "
1014LogicalResult checkErrorIfTable(Operation *op) {
1015 auto table = dyn_cast<tosa::TableOp>(op);
1021 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1023 const ShapeAdaptor tableShape(table.getTable().getType());
1024 if (tableShape.hasStaticShape()) {
1025 const auto numElements = tableShape.getNumElements();
1026 if (numElements != tableSize)
1027 return op->
emitOpError() <<
"requires table size of " << tableSize
1028 <<
", got " << numElements;
1034LogicalResult checkErrorIfRescale(Operation *op) {
1035 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1039 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1040 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1041 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1042 !outputType.getElementType().isInteger())
1045 auto inElemType = inputType.getElementType();
1046 auto outElemType = outputType.getElementType();
1047 auto inWidth = inElemType.getIntOrFloatBitWidth();
1048 auto outWidth = outElemType.getIntOrFloatBitWidth();
1050 bool inputUnsigned = rescale.getInputUnsigned();
1051 bool outputUnsigned = rescale.getOutputUnsigned();
1053 bool scale32 = rescale.getScale32();
1054 auto roundingMode = rescale.getRoundingMode();
1057 if (scale32 && inWidth == 48)
1058 return op->
emitOpError() <<
"scale32 is not allowed with 48-bit input.";
1061 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1063 <<
"DOUBLE_ROUND is only allowed with scale32=true.";
1066 if (inputUnsigned && outputUnsigned)
1067 return op->
emitOpError() <<
"input and output cannot be both unsigned.";
1070 if (outWidth == 32 && inputUnsigned)
1072 <<
"i32 output type is not allowed with unsigned input.";
1075 if (inWidth == 32 && outputUnsigned)
1077 <<
"i32 input type is not allowed with unsigned output.";
1080 if (inWidth == 48 && outputUnsigned)
1082 <<
"i48 input type is not allowed with unsigned output.";
1085 if (inWidth == 48 && inputUnsigned)
1086 return op->
emitOpError() <<
"i48 input type cannot be unsigned.";
1089 if (inWidth == 32 && inputUnsigned)
1090 return op->
emitOpError() <<
"i32 input type cannot be unsigned.";
1093 if (outWidth == 32 && outputUnsigned)
1094 return op->
emitOpError() <<
"i32 output type cannot be unsigned.";
1099LogicalResult checkErrorIfPad(Operation *op) {
1100 auto pad = dyn_cast<tosa::PadOp>(op);
1104 DenseIntElementsAttr paddingAttr;
1109 for (
const APInt &val : paddingAttr.getValues<APInt>()) {
1110 if (val.getSExtValue() < 0)
1111 return op->
emitOpError() <<
"padding value must all be non-negative, got "
1112 << val.getSExtValue();
1118static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1119 return llvm::all_of(op->
getOperands(), [&](
auto operand) {
1120 Region *operandRegion = operand.getParentRegion();
1121 return operandRegion && region->isAncestor(operandRegion);
1125static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) {
1126 bool noLiveInValue =
true;
1127 regionToCheck.
walk([&noLiveInValue, ®ionToCheck](Operation *op) {
1128 if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) {
1129 noLiveInValue =
false;
1134 return noLiveInValue ?
success() : failure();
1137LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck,
1138 StringRef regionName) {
1139 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1142 <<
"is not conformant to the TOSA specification. It requires the '"
1143 << regionName <<
"' region is isolated from above.\n";
1146LogicalResult checkErrorIfCondIf(Operation *op) {
1147 auto ifOp = dyn_cast<tosa::IfOp>(op);
1180 if (
failed(checkIsolatedRegion(op, ifOp.getThenGraph(),
"then")) ||
1181 failed(checkIsolatedRegion(op, ifOp.getElseGraph(),
"else")))
1186LogicalResult checkErrorIfWhileLoop(Operation *op) {
1187 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1191 if (
failed(checkIsolatedRegion(op, whileOp.getCondGraph(),
"cond")) ||
1192 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(),
"body")))
1197LogicalResult checkErrorIfScatter(Operation *op) {
1198 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1203 DenseIntElementsAttr indicesAttr;
1207 auto const indicesType =
1208 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1209 if (!indicesType || !indicesType.hasRank()) {
1215 op->
emitOpError(
"indices values contain duplicates");
1222LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1223 if (
failed(checkErrorIfResize(op)) ||
failed(checkErrorIfMul(op)) ||
1224 failed(checkErrorIfTable(op)) ||
failed(checkErrorIfRescale(op)) ||
1225 failed(checkErrorIfPad(op)) ||
failed(checkErrorIfCondIf(op)) ||
1226 failed(checkErrorIfWhileLoop(op)) ||
failed(checkErrorIfScatter(op)))
1231bool TosaValidation::isValidElementType(Type type,
const bool allowUnsigned) {
1232 if (isa<FloatType>(type)) {
1233 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1234 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1235 Float6E3M2FNType, Float8E8M0FNUType>(type);
1236 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
1237 if (intTy.isSignless()) {
1238 switch (intTy.getWidth()) {
1248 }
else if (allowUnsigned && intTy.isUnsigned()) {
1249 switch (intTy.getWidth()) {
1256 }
else if (isa<tosa::shapeType>(type))
1258 else if (isa<tosa::mxint8Type>(type))
1263void TosaValidation::runOnOperation() {
1264 ModuleOp modOp = getOperation();
1266 const auto maybeTargetEnv =
1268 if (
failed(maybeTargetEnv))
1269 return signalPassFailure();
1270 targetEnv = *maybeTargetEnv;
1272 TosaDialect *tosaDialect =
getContext().getLoadedDialect<TosaDialect>();
1276 modOp.walk([&](Operation *op) {
1285 const bool allowUnsigned =
1286 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1289 if (!isValidElementType(elementTy, allowUnsigned)) {
1290 op->
emitOpError() <<
"is not profile-aligned: element type "
1291 << elementTy <<
" is not legal";
1292 return signalPassFailure();
1297 if (!isValidElementType(elementTy, allowUnsigned)) {
1298 op->
emitOpError() <<
"is not profile-aligned: element type "
1299 << elementTy <<
" is not legal";
1300 return signalPassFailure();
1304 if (strictOpSpecAlignment &&
1306 return signalPassFailure();
1308 if (strictOpSpecAlignment &&
1310 return signalPassFailure();
1312 if (!allowInvalidOpDatatypeCombinations &&
1314 return signalPassFailure();
1318 if (
failed(applyConstantOperandCheck(op)))
1319 signalPassFailure();
1322 if (
failed(applyLevelCheck(op)))
1323 signalPassFailure();
1326 if (
failed(applyAttributeCheck(op)))
1327 signalPassFailure();
1330 if (
failed(applyVariableCheck(op)))
1331 signalPassFailure();
1334 if (strictOpSpecAlignment &&
failed(applyErrorIfCheck(op)))
1335 signalPassFailure();
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_RANKS(tosaOp)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
This class represents the capability enabled in the target implementation such as profile,...
TosaLevel getLevel() const
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
TosaSpecificationVersion getSpecVersion() const
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.