15 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
30 #define GEN_PASS_DEF_TOSAVALIDATION
31 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
40 static LogicalResult checkConstantOperandPad(
Operation *op) {
41 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
44 return op->
emitOpError(
"padding of pad is not constant");
48 if (padOp.getPadConst() &&
50 return op->
emitOpError(
"pad_const of pad is not constant");
55 static LogicalResult checkConstantOperandTranspose(
Operation *op) {
56 if (
auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
59 return op->
emitOpError(
"perms of transpose is not constant");
66 int32_t MAX_KERNEL = 0;
67 int32_t MAX_STRIDE = 0;
68 int32_t MAX_SCALE = 0;
73 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
74 MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
// Limits for the TOSA "8K" level. configLevelAndProfile() selects these when
// `level == TosaLevelEnum::EightK`, and the levelCheck* helpers compare
// attribute/shape values against the individual fields.
// NOTE(review): the initializer order presumably maps to
// {MAX_RANK, MAX_KERNEL, MAX_STRIDE, MAX_SCALE}; the MAX_RANK field
// declaration is not visible in this view — confirm against TosaLevel.
78 static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
// All-zero sentinel meaning "no level selected". configLevelAndProfile()
// assigns it by default, and applyLevelCheck() tests for it first
// (presumably to skip level checking entirely — confirm).
79 static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
85 struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
87 explicit TosaValidation() { populateConstantOperandChecks(); }
88 explicit TosaValidation(
const TosaValidationOptions &
options)
90 this->profile =
options.profile;
91 this->StrictOperationSpecAlignment =
options.StrictOperationSpecAlignment;
94 void runOnOperation() final;
96 LogicalResult applyConstantOperandCheck(
Operation *op) {
97 for (
auto &checker : constCheckers) {
98 if (failed(checker(op)))
104 LogicalResult applyLevelCheck(
Operation *op);
107 LogicalResult applyVariableCheck(
Operation *op);
110 void populateConstantOperandChecks() {
111 constCheckers.emplace_back(checkConstantOperandPad);
112 constCheckers.emplace_back(checkConstantOperandTranspose);
115 bool levelCheckKernel(
Operation *op, int32_t v,
116 const std::string &checkDesc) {
117 if (v > tosaLevel.MAX_KERNEL) {
118 op->
emitOpError() <<
"failed level check: " << checkDesc;
124 bool levelCheckStride(
Operation *op, int32_t v,
125 const std::string &checkDesc) {
126 if (v > tosaLevel.MAX_STRIDE) {
127 op->
emitOpError() <<
"failed level check: " << checkDesc;
133 bool levelCheckScale(
Operation *op, int32_t v,
const std::string &checkDesc) {
134 if (v > tosaLevel.MAX_SCALE) {
135 op->
emitOpError() <<
"failed level check: " << checkDesc;
142 const std::string &checkDesc) {
143 if (ShapedType type = dyn_cast<ShapedType>(v.
getType())) {
144 if (!type.hasRank()) {
145 op->
emitOpError() <<
"failed level check: unranked tensor";
148 if (type.getRank() > tosaLevel.MAX_RANK) {
149 op->
emitOpError() <<
"failed level check: " << checkDesc;
156 template <
typename T>
158 if (dyn_cast<T>(op)) {
161 if (!levelCheckRank(op, v,
"operand rank(shape) <= MAX_RANK"))
165 if (!levelCheckRank(op, v,
"result rank(shape) <= MAX_RANK"))
173 #define CHECK_RANKS_FOR(tosaOp) \
174 if (!levelCheckRanksFor<tosaOp##Op>(op)) \
241 #undef CHECK_RANKS_FOR
246 template <
typename T>
248 if (
auto poolOp = dyn_cast<T>(op)) {
249 for (
auto k : poolOp.getKernel()) {
250 if (!levelCheckKernel(op, k,
"kernel <= MAX_KERNEL")) {
254 for (
auto s : poolOp.getStride()) {
255 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
259 for (
auto p : poolOp.getPad()) {
260 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
269 template <
typename T>
271 if (
auto convOp = dyn_cast<T>(op)) {
273 for (
auto k : convOp.getDilation()) {
274 if (!levelCheckKernel(op, k,
"dilation <= MAX_KERNEL")) {
278 for (
auto p : convOp.getPad()) {
279 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
283 for (
auto s : convOp.getStride()) {
284 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
288 auto dilation = convOp.getDilation();
289 if (ShapedType weightType =
291 auto shape = weightType.getShape();
292 if (isa<tosa::Conv2DOp>(op)) {
293 assert(shape.size() == 4);
294 assert(dilation.size() == 2);
295 if (!levelCheckKernel(op, dilation[0] * shape[1],
296 "dilation_y * KH <= MAX_KERNEL)") ||
297 !levelCheckKernel(op, dilation[1] * shape[2],
298 "dilation_x * KW <= MAX_KERNEL)"))
300 }
else if (isa<tosa::Conv3DOp>(op)) {
301 assert(shape.size() == 5);
302 assert(dilation.size() == 3);
303 if (!levelCheckKernel(op, dilation[0] * shape[1],
304 "dilation_d * KD <= MAX_KERNEL)") ||
305 !levelCheckKernel(op, dilation[1] * shape[2],
306 "dilation_y * KH <= MAX_KERNEL)") ||
307 !levelCheckKernel(op, dilation[2] * shape[3],
308 "dilation_x * KW <= MAX_KERNEL)"))
310 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
311 assert(shape.size() == 4);
312 assert(dilation.size() == 2);
313 if (!levelCheckKernel(op, dilation[0] * shape[0],
314 "dilation_y * KH <= MAX_KERNEL)") ||
315 !levelCheckKernel(op, dilation[1] * shape[1],
316 "dilation_x * KW <= MAX_KERNEL)"))
325 template <
typename T>
329 if (ShapedType type = dyn_cast<ShapedType>(v.
getType())) {
330 auto shape = type.getShape();
331 assert(shape.size() == 3);
332 if (!levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL") ||
333 !levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL")) {
343 bool levelCheckTransposeConv2d(
Operation *op) {
344 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
345 if (ShapedType filterType =
346 dyn_cast<ShapedType>(
transpose.getWeight().getType())) {
347 auto shape = filterType.getShape();
348 assert(shape.size() == 4);
350 if (!levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL") ||
351 !levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL")) {
356 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
361 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
371 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
372 auto scale = resize.getScale();
373 int16_t scaleYN = scale[0];
374 int16_t scaleYD = scale[1];
375 int16_t scaleXN = scale[2];
376 int16_t scaleXD = scale[3];
377 if (!levelCheckScale(op, scaleYN / scaleYD,
378 "scale_y_n/scale_y_d <= MAX_SCALE") ||
379 !levelCheckScale(op, scaleXN / scaleXD,
380 "scale_x_n/scale_x_d <= MAX_SCALE")) {
389 void configLevelAndProfile() {
390 tosaLevel = TOSA_LEVEL_NONE;
391 if (level == TosaLevelEnum::EightK) {
392 tosaLevel = TOSA_LEVEL_EIGHTK;
395 if (!profile.empty()) {
396 for (std::string &prof : profile) {
397 auto profSymbol = symbolizeTosaProfileEnum(prof);
399 enabled_profiles.push_back(profSymbol.value());
406 bool CheckVariableReadOrWrite(
Operation *op);
408 bool isValidElementType(
Type type);
409 bool isEnabledProfile(TosaProfileEnum prof) {
410 return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
411 std::end(enabled_profiles);
420 LogicalResult TosaValidation::applyLevelCheck(
Operation *op) {
421 if (tosaLevel == TOSA_LEVEL_NONE) {
426 if (!levelCheckRanks(op)) {
431 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
432 !levelCheckConv<tosa::Conv2DOp>(op) ||
433 !levelCheckConv<tosa::Conv3DOp>(op) ||
434 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
435 !levelCheckFFT<tosa::FFT2dOp>(op) ||
436 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
437 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
438 !levelCheckResize(op)) {
445 inline bool CompatibleTypes(
const mlir::Type &type,
448 return type == declaredType;
451 bool TosaValidation::CheckVariable(
Operation *op) {
452 if (isa<mlir::tosa::VariableOp>(op)) {
453 auto nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
455 if (variablesMap.count(nameAttr)) {
456 op->
emitOpError() <<
"name has already been declared";
460 auto typeAttr = cast<mlir::TypeAttr>(op->
getAttr(
"type"));
463 variablesMap[nameAttr] = type;
469 bool TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
470 if (isa<mlir::tosa::VariableReadOp>(op) ||
471 isa<mlir::tosa::VariableWriteOp>(op)) {
472 auto nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
474 if (!variablesMap.count(nameAttr)) {
479 auto varType = variablesMap[nameAttr];
483 if (!CompatibleTypes(type, varType)) {
484 op->
emitOpError() <<
"operand type does not equal variable type";
491 if (!CompatibleTypes(type, varType)) {
492 op->
emitOpError() <<
"result type does not equal variable type";
501 LogicalResult TosaValidation::applyVariableCheck(
Operation *op) {
502 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
508 bool TosaValidation::isValidElementType(
Type type) {
509 if (isa<FloatType>(type)) {
510 if (!isEnabledProfile(TosaProfileEnum::MainInference))
513 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
514 if (intTy.isSignless()) {
515 switch (intTy.getWidth()) {
525 }
else if (mlir::isa<tosa::shapeType>(type)) {
531 void TosaValidation::runOnOperation() {
532 configLevelAndProfile();
543 auto elementTy = getElementTypeOrSelf(operand);
544 if (!isValidElementType(elementTy)) {
545 op->emitOpError() <<
"is not profile-aligned: element type "
546 << elementTy <<
" is not legal";
547 return signalPassFailure();
551 auto elementTy = getElementTypeOrSelf(resultTy);
552 if (!isValidElementType(elementTy)) {
553 op->emitOpError() <<
"is not profile-aligned: element type "
554 << elementTy <<
" is not legal";
555 return signalPassFailure();
561 if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
565 if (failed(applyLevelCheck(op)))
569 if (failed(applyVariableCheck(op)))
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
#define CHECK_RANKS_FOR(tosaOp)
An attribute that represents a reference to a dense vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from being chosen when not desirable.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.