15 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
30 #define GEN_PASS_DEF_TOSAVALIDATION
31 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
41 if (
auto padOp = dyn_cast<tosa::PadOp>(op)) {
44 return op->
emitOpError(
"padding of pad is not constant");
48 if (padOp.getPadConst() &&
50 return op->
emitOpError(
"pad_const of pad is not constant");
56 if (
auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
59 return op->
emitOpError(
"perms of transpose is not constant");
65 if (
auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
68 return op->
emitOpError(
"weight of fully_connected is not constant");
72 return op->
emitOpError(
"bias of fully_connected is not constant");
// Size limits for one TOSA "level". All limits are inclusive maxima; the
// all-zero TOSA_LEVEL_NONE value is used as a sentinel meaning "no level
// checking" (compared against in applyLevelCheck).
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // const-qualified so comparisons also work when the left-hand side is a
  // const/constexpr level (the original non-const overload forbade that).
  bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
};

// The "8K" level defined by the TOSA specification.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
// Sentinel: all limits zero, i.e. level checking disabled.
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
98 struct TosaValidation :
public tosa::impl::TosaValidationBase<TosaValidation> {
100 explicit TosaValidation() { populateConstantOperandChecks(); }
101 explicit TosaValidation(
const TosaValidationOptions &
options)
103 this->profile =
options.profile;
104 this->StrictOperationSpecAlignment =
options.StrictOperationSpecAlignment;
107 void runOnOperation() final;
110 for (
auto &checker : constCheckers) {
123 void populateConstantOperandChecks() {
124 constCheckers.emplace_back(checkConstantOperandPad);
125 constCheckers.emplace_back(checkConstantOperandTranspose);
126 constCheckers.emplace_back(checkConstantOperandFullyConnected);
129 bool levelCheckKernel(
Operation *op, int32_t v,
130 const std::string &checkDesc) {
131 if (v > tosaLevel.MAX_KERNEL) {
132 op->
emitOpError() <<
"failed level check: " << checkDesc;
138 bool levelCheckStride(
Operation *op, int32_t v,
139 const std::string &checkDesc) {
140 if (v > tosaLevel.MAX_STRIDE) {
141 op->
emitOpError() <<
"failed level check: " << checkDesc;
147 bool levelCheckScale(
Operation *op, int32_t v,
const std::string &checkDesc) {
148 if (v > tosaLevel.MAX_SCALE) {
149 op->
emitOpError() <<
"failed level check: " << checkDesc;
156 const std::string &checkDesc) {
157 if (ShapedType type = dyn_cast<ShapedType>(v.
getType())) {
158 if (!type.hasRank()) {
159 op->
emitOpError() <<
"failed level check: unranked tensor";
162 if (type.getRank() > tosaLevel.MAX_RANK) {
163 op->
emitOpError() <<
"failed level check: " << checkDesc;
170 template <
typename T>
172 if (dyn_cast<T>(op)) {
175 if (!levelCheckRank(op, v,
"operand rank(shape) <= MAX_RANK"))
179 if (!levelCheckRank(op, v,
"result rank(shape) <= MAX_RANK"))
187 #define CHECK_RANKS_FOR(tosaOp) \
188 if (!levelCheckRanksFor<tosaOp##Op>(op)) \
255 #undef CHECK_RANKS_FOR
260 template <
typename T>
262 if (
auto poolOp = dyn_cast<T>(op)) {
263 for (
auto k : poolOp.getKernel()) {
264 if (!levelCheckKernel(op, k,
"kernel <= MAX_KERNEL")) {
268 for (
auto s : poolOp.getStride()) {
269 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
273 for (
auto p : poolOp.getPad()) {
274 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
283 template <
typename T>
285 if (
auto convOp = dyn_cast<T>(op)) {
287 for (
auto k : convOp.getDilation()) {
288 if (!levelCheckKernel(op, k,
"dilation <= MAX_KERNEL")) {
292 for (
auto p : convOp.getPad()) {
293 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
297 for (
auto s : convOp.getStride()) {
298 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
302 auto dilation = convOp.getDilation();
303 if (ShapedType weightType =
305 auto shape = weightType.getShape();
306 if (isa<tosa::Conv2DOp>(op)) {
307 assert(shape.size() == 4);
308 assert(dilation.size() == 2);
309 if (!levelCheckKernel(op, dilation[0] * shape[1],
310 "dilation_y * KH <= MAX_KERNEL)") ||
311 !levelCheckKernel(op, dilation[1] * shape[2],
312 "dilation_x * KW <= MAX_KERNEL)"))
314 }
else if (isa<tosa::Conv3DOp>(op)) {
315 assert(shape.size() == 5);
316 assert(dilation.size() == 3);
317 if (!levelCheckKernel(op, dilation[0] * shape[1],
318 "dilation_d * KD <= MAX_KERNEL)") ||
319 !levelCheckKernel(op, dilation[1] * shape[2],
320 "dilation_y * KH <= MAX_KERNEL)") ||
321 !levelCheckKernel(op, dilation[2] * shape[3],
322 "dilation_x * KW <= MAX_KERNEL)"))
324 }
else if (isa<tosa::DepthwiseConv2DOp>(op)) {
325 assert(shape.size() == 4);
326 assert(dilation.size() == 2);
327 if (!levelCheckKernel(op, dilation[0] * shape[0],
328 "dilation_y * KH <= MAX_KERNEL)") ||
329 !levelCheckKernel(op, dilation[1] * shape[1],
330 "dilation_x * KW <= MAX_KERNEL)"))
339 template <
typename T>
343 if (ShapedType type = dyn_cast<ShapedType>(v.
getType())) {
344 auto shape = type.getShape();
345 assert(shape.size() == 3);
346 if (!levelCheckKernel(op, shape[1],
"H <= MAX_KERNEL") ||
347 !levelCheckKernel(op, shape[2],
"W <= MAX_KERNEL")) {
357 bool levelCheckTransposeConv2d(
Operation *op) {
358 if (
auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
359 if (ShapedType filterType =
360 transpose.getFilter().getType().dyn_cast<ShapedType>()) {
361 auto shape = filterType.getShape();
362 assert(shape.size() == 4);
364 if (!levelCheckKernel(op, shape[1],
"KH <= MAX_KERNEL") ||
365 !levelCheckKernel(op, shape[2],
"KW <= MAX_KERNEL")) {
369 for (
auto p : transpose.getOutPad()) {
370 if (!levelCheckKernel(op, p,
"pad <= MAX_KERNEL")) {
374 for (
auto s : transpose.getStride()) {
375 if (!levelCheckStride(op, s,
"stride <= MAX_STRIDE")) {
385 if (
auto resize = dyn_cast<tosa::ResizeOp>(op)) {
386 auto scale = resize.getScale();
387 int16_t scaleYN = scale[0];
388 int16_t scaleYD = scale[1];
389 int16_t scaleXN = scale[2];
390 int16_t scaleXD = scale[3];
391 if (!levelCheckScale(op, scaleYN / scaleYD,
392 "scale_y_n/scale_y_d <= MAX_SCALE") ||
393 !levelCheckScale(op, scaleXN / scaleXD,
394 "scale_x_n/scale_x_d <= MAX_SCALE")) {
403 void configLevelAndProfile() {
404 tosaLevel = TOSA_LEVEL_NONE;
405 if (level == TosaLevelEnum::EightK) {
406 tosaLevel = TOSA_LEVEL_EIGHTK;
411 bool CheckVariableReadOrWrite(
Operation *op);
419 if (tosaLevel == TOSA_LEVEL_NONE) {
424 if (!levelCheckRanks(op)) {
429 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
430 !levelCheckConv<tosa::Conv2DOp>(op) ||
431 !levelCheckConv<tosa::Conv3DOp>(op) ||
432 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
433 !levelCheckFFT<tosa::FFT2dOp>(op) ||
434 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
435 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
436 !levelCheckResize(op)) {
443 inline bool CompatibleTypes(
const mlir::Type &type,
446 return type == declaredType;
449 bool TosaValidation::CheckVariable(
Operation *op) {
450 if (isa<mlir::tosa::VariableOp>(op)) {
451 auto nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
453 if (variablesMap.count(nameAttr)) {
454 op->
emitOpError() <<
"name has already been declared";
458 auto typeAttr = cast<mlir::TypeAttr>(op->
getAttr(
"type"));
461 variablesMap[nameAttr] = type;
467 bool TosaValidation::CheckVariableReadOrWrite(
Operation *op) {
468 if (isa<mlir::tosa::VariableReadOp>(op) ||
469 isa<mlir::tosa::VariableWriteOp>(op)) {
470 auto nameAttr = cast<mlir::StringAttr>(op->
getAttr(
"name"));
472 if (!variablesMap.count(nameAttr)) {
477 auto varType = variablesMap[nameAttr];
481 if (!CompatibleTypes(type, varType)) {
482 op->
emitOpError() <<
"operand type does not equal variable type";
489 if (!CompatibleTypes(type, varType)) {
490 op->
emitOpError() <<
"result type does not equal variable type";
500 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
506 void TosaValidation::runOnOperation() {
507 configLevelAndProfile();
510 if ((profile == TosaProfileEnum::BaseInference) &&
511 isa<FloatType>(getElementTypeOrSelf(operand))) {
512 return signalPassFailure();
515 return signalPassFailure();
521 if (StrictOperationSpecAlignment &&
failed(applyConstantOperandCheck(op)))
525 if (
failed(applyLevelCheck(op)))
529 if (
failed(applyVariableCheck(op)))
static llvm::ManagedStatic< PassManagerOptions > options
#define CHECK_RANKS_FOR(tosaOp)
An attribute that represents a reference to a dense vector or tensor object.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from being chosen when not desirable.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.