15 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
30 #define GEN_PASS_DEF_TOSAVALIDATION
31 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
// Checks that a tosa.pad op's operands are compile-time constants; emits an
// op error and fails otherwise.
// NOTE(review): this listing is a lossy scrape -- the matchPattern(...,
// m_Constant()) guards that trigger these error returns (original lines
// 42-43, 45-47, 49) are missing from view; intent inferred, not verified.
40 static LogicalResult checkConstantOperandPad(
Operation *op) {
41 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
// Error path: the padding operand was found to be non-constant.
44 return op->
emitOpError(
"padding of pad is not constant");
// The optional explicit pad value, when present, must also be constant.
48 if (padOp.getPadConst() &&
50 return op->
emitOpError(
"pad_const of pad is not constant");
// Checks that a tosa.transpose op's permutation operand is a compile-time
// constant; emits an op error and fails otherwise.
// NOTE(review): the condition guarding this return (original lines 57-58)
// is missing from the scrape -- presumably a matchPattern/m_Constant check.
55 static LogicalResult checkConstantOperandTranspose(
Operation *op) {
56 if (
auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
// Error path: perms operand is not produced by a constant op.
59 return op->
emitOpError(
"perms of transpose is not constant");
// Checks that a tosa.fully_connected op's weight and bias operands are
// compile-time constants; emits an op error and fails otherwise.
// NOTE(review): the guard conditions (original lines 66-67, 69-71) are
// missing from the scrape; each return below is an error path taken when
// the corresponding operand is non-constant.
64 static LogicalResult checkConstantOperandFullyConnected(
Operation *op) {
65 if (
auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
// Non-constant weight operand.
68 return op->
emitOpError(
"weight of fully_connected is not constant");
// Non-constant bias operand.
72 return op->
emitOpError(
"bias of fully_connected is not constant");
// TOSA "level" limits: per-level caps on kernel size, stride, and resize
// scale (a MAX_RANK field is referenced by operator== below but its
// declaration, original ~line 78, is missing from this scrape).
79 int32_t MAX_KERNEL = 0;
80 int32_t MAX_STRIDE = 0;
81 int32_t MAX_SCALE = 0;
// Equality compares all four level limits (body of operator==; the
// signature line, original ~line 85, is missing from view).
86 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
87 MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
// The "8K" level from the TOSA spec: rank <= 6, kernel/stride <= 8192,
// scale <= 256. TOSA_LEVEL_NONE (all zeros) means level checking disabled.
91 static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
92 static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
// Pass that validates TOSA ops against profile/level constraints.
// Derives from the tablegen-generated pass base (GEN_PASS_DEF_TOSAVALIDATION).
98 struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
// Default ctor: registers the constant-operand checkers.
100 explicit TosaValidation() { populateConstantOperandChecks(); }
// Options ctor: copies pass options into the generated option members.
// NOTE(review): the ctor's opening brace / delegation (original line 102)
// is missing from this scrape.
101 explicit TosaValidation(
const TosaValidationOptions &
options)
103 this->profile =
options.profile;
104 this->StrictOperationSpecAlignment =
options.StrictOperationSpecAlignment;
// Pass entry point, defined out-of-line below.
107 void runOnOperation() final;
// Runs every registered constant-operand checker on `op`; presumably fails
// on the first checker that fails (the failure/success returns, original
// lines 112-114, are missing from this scrape).
109 LogicalResult applyConstantOperandCheck(
Operation *op) {
110 for (
auto &checker : constCheckers) {
111 if (failed(checker(op)))
// Level and variable checks, defined out-of-line below.
117 LogicalResult applyLevelCheck(
Operation *op);
120 LogicalResult applyVariableCheck(
Operation *op);
// Registers the three constant-operand checkers (pad, transpose,
// fully_connected) into constCheckers; order matches registration order.
123 void populateConstantOperandChecks() {
124 constCheckers.emplace_back(checkConstantOperandPad);
125 constCheckers.emplace_back(checkConstantOperandTranspose);
126 constCheckers.emplace_back(checkConstantOperandFullyConnected);
// Three level-limit predicates. Each emits "failed level check: <desc>" and
// (presumably) returns false when `v` exceeds the corresponding tosaLevel
// limit, true otherwise -- the return statements (original lines 133-137,
// 142-146, 150-154) are missing from this scrape.
129 bool levelCheckKernel(
Operation *op, int32_t v,
130 const std::string &checkDesc) {
131 if (v > tosaLevel.MAX_KERNEL) {
132 op->
emitOpError() <<
"failed level check: " << checkDesc;
// Same shape as levelCheckKernel, but against MAX_STRIDE.
138 bool levelCheckStride(
Operation *op, int32_t v,
139 const std::string &checkDesc) {
140 if (v > tosaLevel.MAX_STRIDE) {
141 op->
emitOpError() <<
"failed level check: " << checkDesc;
// Same shape again, against MAX_SCALE (used by the resize check).
147 bool levelCheckScale(
Operation *op, int32_t v,
148 if (v > tosaLevel.MAX_SCALE) {
149 op->
emitOpError() <<
"failed level check: " << checkDesc;
// Rank-limit check for a value `v` (the signature line with the Operation*
// and Value parameters, original ~line 155, is missing from this scrape).
// Unranked tensors fail outright; ranked tensors fail when their rank
// exceeds tosaLevel.MAX_RANK.
156 const std::string &checkDesc) {
157 if (ShapedType type = dyn_cast<ShapedType>(v.
getType())) {
158 if (!type.hasRank()) {
// Unranked tensors cannot be level-checked.
159 op->
emitOpError() <<
"failed level check: unranked tensor";
162 if (type.getRank() > tosaLevel.MAX_RANK) {
163 op->
emitOpError() <<
"failed level check: " << checkDesc;
// Templated rank check: when `op` is of type T, checks every operand and
// every result against MAX_RANK (the surrounding loops over operands and
// results, original lines 171-185, are partially missing from this scrape).
170 template <
typename T>
172 if (dyn_cast<T>(op)) {
175 if (!levelCheckRank(op, v,
"operand rank(shape) <= MAX_RANK"))
179 if (!levelCheckRank(op, v,
"result rank(shape) <= MAX_RANK"))
// Convenience macro used by levelCheckRanks (original lines 189-254, the
// per-op CHECK_RANKS_FOR invocations, are missing from this scrape).
187 #define CHECK_RANKS_FOR(tosaOp) \
188 if (!levelCheckRanksFor<tosaOp##Op>(op)) \
255 #undef CHECK_RANKS_FOR
// Pooling-op level check (T is AvgPool2dOp or MaxPool2dOp per the call
// sites in applyLevelCheck): kernel and pad entries are checked against
// MAX_KERNEL, stride entries against MAX_STRIDE. The early-return-false
// bodies and final return (between the loops) are missing from this scrape.
260 template <
typename T>
262 if (
auto poolOp = dyn_cast<T>(op)) {
263 for (
auto k : poolOp.getKernel()) {
264 if (!levelCheckKernel(op, k,
"kernel <= MAX_KERNEL")) {
268 for (
auto s : poolOp.getStride()) {
269 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
273 for (
auto p : poolOp.getPad()) {
274 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
// Convolution-op level check (T is Conv2DOp, Conv3DOp, or
// DepthwiseConv2DOp per the call sites): dilation and pad against
// MAX_KERNEL, stride against MAX_STRIDE, then effective kernel size
// (dilation * weight-dim) against MAX_KERNEL. The weight-type dyn_cast
// expression (original line 304) and several return paths are missing
// from this scrape.
283 template <
typename T>
285 if (
auto convOp = dyn_cast<T>(op)) {
287 for (
auto k : convOp.getDilation()) {
288 if (!levelCheckKernel(op, k,
"dilation <= MAX_KERNEL")) {
292 for (
auto p : convOp.getPad()) {
293 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
297 for (
auto s : convOp.getStride()) {
298 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
302 auto dilation = convOp.getDilation();
303 if (ShapedType weightType =
305 auto shape = weightType.getShape();
// Conv2D weights are laid out with KH at index 1 and KW at index 2
// (per the indices used below); dilation is {y, x}.
306 if (isa<tosa::Conv2DOp>(op)) {
307 assert(shape.size() == 4);
308 assert(dilation.size() == 2);
309 if (!levelCheckKernel(op, dilation[0] * shape[1],
310 "dilation_y * KH <= MAX_KERNEL)") ||
311 !levelCheckKernel(op, dilation[1] * shape[2],
312 "dilation_x * KW <= MAX_KERNEL)"))
314 }
// Conv3D adds a depth dimension: KD/KH/KW at indices 1/2/3.
else if (isa<tosa::Conv3DOp>(op)) {
315 assert(shape.size() == 5);
316 assert(dilation.size() == 3);
317 if (!levelCheckKernel(op, dilation[0] * shape[1],
318 "dilation_d * KD <= MAX_KERNEL)") ||
319 !levelCheckKernel(op, dilation[1] * shape[2],
320 "dilation_y * KH <= MAX_KERNEL)") ||
321 !levelCheckKernel(op, dilation[2] * shape[3],
322 "dilation_x * KW <= MAX_KERNEL)"))
324 }
// Depthwise weights put KH/KW at indices 0/1 (note the different
// indices vs Conv2D above).
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
325 assert(shape.size() == 4);
326 assert(dilation.size() == 2);
327 if (!levelCheckKernel(op, dilation[0] * shape[0],
328 "dilation_y * KH <= MAX_KERNEL)") ||
329 !levelCheckKernel(op, dilation[1] * shape[1],
330 "dilation_x * KW <= MAX_KERNEL)"))
// FFT-op level check (T is FFT2dOp or RFFT2dOp per the call sites): the
// checked value's H (dim 1) and W (dim 2) must not exceed MAX_KERNEL.
// NOTE(review): the function signature and the loop providing `v`
// (original lines 340-342) are missing from this scrape.
339 template <
typename T>
343 if (ShapedType type = dyn_cast<ShapedType>(v.
getType())) {
344 auto shape = type.getShape();
// Shape is expected to be rank 3: presumably (batch, H, W).
345 assert(shape.size() == 3);
346 if (!levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL") ||
347 !levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL")) {
// Level check for tosa.transpose_conv2d: filter KH (dim 1) and KW (dim 2)
// against MAX_KERNEL, then each pad entry against MAX_KERNEL and each
// stride entry against MAX_STRIDE. The loops providing `p` and `s`
// (original lines 369, 374) and the return paths are missing from this
// scrape.
357 bool levelCheckTransposeConv2d(
Operation *op) {
358 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
359 if (ShapedType filterType =
360 dyn_cast<ShapedType>(
transpose.getFilter().getType())) {
361 auto shape = filterType.getShape();
// Filter is rank 4; KH/KW at indices 1/2 per the checks below.
362 assert(shape.size() == 4);
364 if (!levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL") ||
365 !levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL")) {
370 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
375 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
// Level check for tosa.resize (the enclosing signature, original ~line
// 384, is missing from this scrape): the scale attribute packs
// {y_n, y_d, x_n, x_d}; each numerator/denominator ratio (integer
// division as written) must not exceed MAX_SCALE.
385 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
386 auto scale = resize.getScale();
387 int16_t scaleYN = scale[0];
388 int16_t scaleYD = scale[1];
389 int16_t scaleXN = scale[2];
390 int16_t scaleXD = scale[3];
391 if (!levelCheckScale(op, scaleYN / scaleYD,
392 "scale_y_n/scale_y_d <= MAX_SCALE") ||
393 !levelCheckScale(op, scaleXN / scaleXD,
394 "scale_x_n/scale_x_d <= MAX_SCALE")) {
// Translates the pass's `level` and `profile` options into runtime state:
// tosaLevel defaults to "none" (checks disabled) unless the EightK level
// is selected; each recognized profile string is symbolized and recorded.
// NOTE(review): the handling of unrecognized profile strings (original
// line 412, between the symbolize call and the push_back) is missing from
// this scrape.
403 void configLevelAndProfile() {
404 tosaLevel = TOSA_LEVEL_NONE;
405 if (level == TosaLevelEnum::EightK) {
406 tosaLevel = TOSA_LEVEL_EIGHTK;
409 if (!profile.empty()) {
410 for (std::string &prof : profile) {
411 auto profSymbol = symbolizeTosaProfileEnum(prof);
413 enabled_profiles.push_back(profSymbol.value());
// Out-of-line member declarations, plus the inline profile query.
420 bool CheckVariableReadOrWrite(
Operation *op);
422 bool isValidElementType(
Type type);
// True when `prof` was enabled via the pass's profile option (linear scan
// of the small enabled_profiles list).
423 bool isEnabledProfile(TosaProfileEnum prof) {
424 return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
425 std::end(enabled_profiles);
// Applies all level checks to `op`. When no level is configured
// (TOSA_LEVEL_NONE) checking is skipped entirely; otherwise rank checks
// run first, then the per-op-kind pool/conv/fft/transpose-conv/resize
// checks. The early-return bodies are missing from this scrape
// (presumably success for the NONE case, failure when a check fails).
434 LogicalResult TosaValidation::applyLevelCheck(
Operation *op) {
435 if (tosaLevel == TOSA_LEVEL_NONE) {
440 if (!levelCheckRanks(op)) {
445 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
446 !levelCheckConv<tosa::Conv2DOp>(op) ||
447 !levelCheckConv<tosa::Conv3DOp>(op) ||
448 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
449 !levelCheckFFT<tosa::FFT2dOp>(op) ||
450 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
451 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
452 !levelCheckResize(op)) {
// Type-compatibility predicate used by the variable checks below; with
// MLIR's uniqued types this is plain pointer-identity equality. The second
// parameter line (`declaredType`, original ~lines 460-461) is missing from
// this scrape.
459 inline bool CompatibleTypes(
const mlir::Type &type,
462 return type == declaredType;
// Validates a tosa.variable declaration: the "name" attribute must not
// already be in variablesMap (duplicate declaration is an error);
// otherwise the declared "type" is recorded under that name. The
// intermediate extraction of `type` from typeAttr (original ~line 475)
// and the return statements are missing from this scrape.
465 bool TosaValidation::CheckVariable(
Operation *op) {
466 if (isa<mlir::tosa::VariableOp>(op)) {
467 auto nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
469 if (variablesMap.count(nameAttr)) {
470 op->
emitOpError() <<
"name has already been declared";
474 auto typeAttr = cast<mlir::TypeAttr>(op->
getAttr(
"type"));
// Record the declared type for later read/write validation.
477 variablesMap[nameAttr] = type;
// Validates tosa.variable_read / tosa.variable_write: the referenced
// "name" must have been declared (recorded in variablesMap), and the
// op's operand/result types must match the declared variable type.
// NOTE(review): the undeclared-name error emission (after line 488), the
// loops binding `type` from operands/results, and the return statements
// are missing from this scrape.
483 bool TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
484 if (isa<mlir::tosa::VariableReadOp>(op) ||
485 isa<mlir::tosa::VariableWriteOp>(op)) {
486 auto nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
488 if (!variablesMap.count(nameAttr)) {
493 auto varType = variablesMap[nameAttr];
// Operand type (presumably the write value) vs declared type.
497 if (!CompatibleTypes(type, varType)) {
498 op->
emitOpError() <<
"operand type does not equal variable type";
// Result type (presumably the read result) vs declared type.
505 if (!CompatibleTypes(type, varType)) {
506 op->
emitOpError() <<
"result type does not equal variable type";
// Combines both variable checks; presumably returns failure when either
// reports a problem (the return bodies are missing from this scrape).
515 LogicalResult TosaValidation::applyVariableCheck(
Operation *op) {
516 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
// Profile-aware element-type validity: float element types require the
// MainInference profile to be enabled; signless integers are dispatched
// on bit width. The switch cases (original lines 530-542) and the final
// return are missing from this scrape.
522 bool TosaValidation::isValidElementType(
Type type) {
523 if (isa<FloatType>(type)) {
524 if (!isEnabledProfile(TosaProfileEnum::MainInference))
527 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
528 if (intTy.isSignless()) {
529 switch (intTy.getWidth()) {
// Pass driver: configures level/profile state, then (inside a walk whose
// header, original lines 545-550, is missing from this scrape) validates
// every operand and result element type against the enabled profiles
// before running the constant-operand, level, and variable checks.
// Failures signal pass failure (the signalPassFailure calls for the last
// three checks, original lines 570-579, are also missing from view).
543 void TosaValidation::runOnOperation() {
544 configLevelAndProfile();
// Per-operand element-type profile check.
551 auto elementTy = getElementTypeOrSelf(operand);
552 if (!isValidElementType(elementTy)) {
553 op->emitOpError() <<
"is not profile-aligned: element type "
554 << elementTy <<
" is not legal";
555 return signalPassFailure();
// Per-result element-type profile check (same shape as above).
559 auto elementTy = getElementTypeOrSelf(resultTy);
560 if (!isValidElementType(elementTy)) {
561 op->emitOpError() <<
"is not profile-aligned: element type "
562 << elementTy <<
" is not legal";
563 return signalPassFailure();
// Constant-operand checks only run under StrictOperationSpecAlignment.
569 if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
573 if (failed(applyLevelCheck(op)))
577 if (failed(applyVariableCheck(op)))
static llvm::ManagedStatic< PassManagerOptions > options
#define CHECK_RANKS_FOR(tosaOp)
An attribute that represents a reference to a dense vector or tensor object.
StringRef getNamespace() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from being selected.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.