MLIR 23.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Validate whether TOSA dialect input conforms to the specification for the
// given requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Matchers.h"
27#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/StringExtras.h"
30#include "llvm/Support/FormatVariadic.h"
31
32namespace mlir {
33namespace tosa {
34#define GEN_PASS_DEF_TOSAVALIDATION
35#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
36} // namespace tosa
37} // namespace mlir
38
39using namespace mlir;
40using namespace mlir::tosa;
41
42namespace {
43
44static LogicalResult
45checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
46 for (const auto index : operandIndices) {
47 Attribute attr;
48 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
49 return op->emitOpError("expected compile time resolvable constant, but "
50 "got variable value for operand #")
51 << index;
52 }
53 }
54 return success();
55}
56
57static LogicalResult checkConstantOperandMul(Operation *op,
58 const TargetEnv &env) {
59 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60 // Check 'shift'
61 return checkConstantOperands(op, {2});
62 }
63 return success();
64}
65
66static LogicalResult checkConstantOperandTable(Operation *op,
67 const TargetEnv &env) {
68 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69 // Check 'table'
70 return checkConstantOperands(op, {1});
71 }
72 return success();
73}
74
75static LogicalResult checkConstantOperandPad(Operation *op,
76 const TargetEnv &env) {
77 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
78 // Assume this op is zero-padding if padConst is not presented
79 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
80 // Check 'pad_const'
81 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82 return checkConstantOperands(op, {2});
83 }
84 return success();
85}
86
87static LogicalResult checkConstantOperandRescale(Operation *op,
88 const TargetEnv &env) {
89 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91 return checkConstantOperands(op, {1, 2, 3, 4});
92 }
93 return success();
94}
95
96template <typename T>
97static LogicalResult checkConstantOperandConvOps(Operation *op,
98 const TargetEnv &env) {
99 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
100 // Check 'input_zp' and 'weight_zp'
101 return checkConstantOperands(op, {3, 4});
102 }
103 return success();
104}
105
106static LogicalResult checkConstantOperandMatMul(Operation *op,
107 const TargetEnv &env) {
108 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109 // Check 'A_zp' and 'B_zp'
110 return checkConstantOperands(op, {2, 3});
111 }
112 return success();
113}
114
115static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
116 const TargetEnv &env) {
117 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118 // Check 'input_zp' and 'output_zp'
119 return checkConstantOperands(op, {1, 2});
120 }
121 return success();
122}
123
124static LogicalResult checkConstantOperandNegate(Operation *op,
125 const TargetEnv &env) {
126 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127 // Check 'input1_zp' and 'output_zp'
128 return checkConstantOperands(op, {1, 2});
129 }
130 return success();
131}
132
133static LogicalResult checkConstantOperandSilceShape(Operation *op,
134 const TargetEnv &env) {
135 if (!env.allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
136 // Check 'start' and 'size'
137 return checkConstantOperands(op, {1, 2});
138 }
139 return success();
140}
141
142//===----------------------------------------------------------------------===//
143// TOSA Validation Pass.
144//===----------------------------------------------------------------------===//
145
146struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
147public:
148 explicit TosaValidation() { populateConstantOperandChecks(); }
149
150 explicit TosaValidation(const TosaValidationOptions &options)
151 : TosaValidation() {
152 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
153 this->allowInvalidOpDatatypeCombinations =
154 options.allowInvalidOpDatatypeCombinations;
155 }
156 void runOnOperation() final;
157
158 LogicalResult applyConstantOperandCheck(Operation *op) {
159 for (auto &checker : constCheckers) {
160 if (failed(checker(op, targetEnv)))
161 return failure();
162 }
163 return success();
164 }
165
166 LogicalResult applyLevelCheck(Operation *op);
167 LogicalResult applyAttributeCheck(Operation *op);
168
169 // check variable read/write data types against variable declarations
170 LogicalResult applyVariableCheck(Operation *op);
171
172 // check error if conditions
173 LogicalResult applyErrorIfCheck(Operation *op);
174
175private:
176 void populateConstantOperandChecks() {
177 constCheckers.emplace_back(checkConstantOperandMul);
178 constCheckers.emplace_back(checkConstantOperandTable);
179 constCheckers.emplace_back(checkConstantOperandPad);
180 constCheckers.emplace_back(checkConstantOperandRescale);
181 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
182 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
183 constCheckers.emplace_back(
184 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
185 constCheckers.emplace_back(
186 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
187 constCheckers.emplace_back(checkConstantOperandMatMul);
188 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
189 constCheckers.emplace_back(checkConstantOperandNegate);
190 constCheckers.emplace_back(checkConstantOperandSilceShape);
191 }
192
193 LogicalResult levelCheck(Operation *op, const int32_t calculatedValue,
194 const int32_t maxLevel, const StringRef inputName,
195 const StringRef levelName) {
196 if (calculatedValue > maxLevel)
197 return op->emitOpError()
198 << "failed level check: " << inputName << " <= " << levelName
199 << " (" << maxLevel << "), got " << calculatedValue;
200 return success();
201 }
202
203 LogicalResult levelCheckKernel(Operation *op, int32_t v,
204 const StringRef inputName) {
205 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
206 "MAX_KERNEL");
207 }
208
209 LogicalResult levelCheckStride(Operation *op, int32_t v,
210 const StringRef inputName) {
211 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
212 "MAX_STRIDE");
213 }
214
215 LogicalResult levelCheckScale(Operation *op, int32_t v,
216 const StringRef inputName) {
217 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
218 "MAX_SCALE");
219 }
220
221 LogicalResult levelCheckListSize(Operation *op, int32_t v,
222 const StringRef inputName) {
223 const std::string inputDesc =
224 llvm::formatv("length(tensor_list_shape({0}))", inputName);
225 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
226 inputDesc, "MAX_TENSOR_LIST_SIZE");
227 }
228
229 // Perform the Level Rank check on the tensor type.
230 LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
231 const StringRef operandOrResult,
232 int32_t highest_rank) {
233 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
234 if (!type.hasRank())
235 return op->emitOpError() << "failed level check: unranked tensor";
236 if (type.getRank() > highest_rank)
237 return op->emitOpError() << "failed level check: " << operandOrResult
238 << " rank(shape) <= MAX_RANK";
239 }
240 return success();
241 }
242
243 // Perform the Level Rank check on the tensor value.
244 LogicalResult levelCheckRank(Operation *op, const Value &v,
245 const StringRef operandOrResult,
246 int32_t highest_rank) {
247 return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
248 }
249
250 // Perform the Level tensor size check on the tensor type.
251 LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
252 const StringRef operandOrResult);
253
254 // Perform the Level tensor size check on the tensor value.
255 LogicalResult levelCheckSize(Operation *op, const Value &v,
256 const StringRef operandOrResult) {
257 return levelCheckSize(op, v.getType(), operandOrResult);
258 }
259
260 // Perform the Level shape length check on a value.
261 LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
262 const StringRef operandOrResult) {
263 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
264 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
265 return op->emitOpError()
266 << "failed shape type level check: " << typeToCheck
267 << " exceeds MAX_SHAPE_LEN";
268 }
269 return success();
270 }
271
272 // Level check sizes of all operands and results of the operation.
273 template <typename T>
274 LogicalResult levelCheckSizes(T tosaOp) {
275 auto op = tosaOp.getOperation();
276 for (auto v : op->getOperands()) {
277 if (failed(levelCheckSize(op, v, "operand")))
278 return failure();
279 }
280
281 for (auto v : op->getResults()) {
282 if (failed(levelCheckSize(op, v, "result")))
283 return failure();
284 }
285 return success();
286 }
287
288 // Level check ranks of all operands, attribute and results of the operation.
289 template <typename T>
290 LogicalResult levelCheckRanks(T tosaOp) {
291 auto op = tosaOp.getOperation();
292 const TosaLevel tosaLevel = targetEnv.getLevel();
293 for (auto v : op->getOperands()) {
294 if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
295 return failure();
296 }
297
298 for (auto v : op->getResults()) {
299 if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
300 return failure();
301 }
302 return success();
303 }
304 // Level check shape lengths of all operands and results of an operation that
305 // are tosa.shape type.
306 template <typename T>
307 LogicalResult levelCheckShapeLengths(T tosaOp) {
308 for (const auto &v : tosaOp->getOperands()) {
309 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
310 return failure();
311 }
312 for (const auto &v : tosaOp->getResults()) {
313 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
314 return failure();
315 }
316
317 return success();
318 }
319
320 // Level check ranks and sizes.
321 LogicalResult levelCheckRanksAndSizes(Operation *op);
322
323 // Pool Op: level check kernel/stride/pad values
324 template <typename T>
325 LogicalResult levelCheckPool(Operation *op) {
326 if (auto poolOp = dyn_cast<T>(op)) {
327 for (auto k : poolOp.getKernel()) {
328 if (failed(levelCheckKernel(op, k, "kernel"))) {
329 return failure();
330 }
331 }
332 for (auto s : poolOp.getStride()) {
333 if (failed(levelCheckStride(op, s, "stride"))) {
334 return failure();
335 }
336 }
337 for (auto p : poolOp.getPad()) {
338 if (failed(levelCheckKernel(op, p, "pad"))) {
339 return failure();
340 }
341 }
342 }
343 return success();
344 }
345
346 // Conv Op: level check dilation/stride/pad values
347 template <typename T>
348 LogicalResult levelCheckConv(Operation *op) {
349 if (auto convOp = dyn_cast<T>(op)) {
350
351 for (auto k : convOp.getDilation()) {
352 if (failed(levelCheckKernel(op, k, "dilation"))) {
353 return failure();
354 }
355 }
356 for (auto p : convOp.getPad()) {
357 if (failed(levelCheckKernel(op, p, "pad"))) {
358 return failure();
359 }
360 }
361 for (auto s : convOp.getStride()) {
362 if (failed(levelCheckStride(op, s, "stride"))) {
363 return failure();
364 }
365 }
366 auto dilation = convOp.getDilation();
367 if (ShapedType weightType =
368 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
369 auto shape = weightType.getShape();
370 if (isa<tosa::Conv2DOp>(op)) {
371 assert(shape.size() == 4);
372 assert(dilation.size() == 2);
373 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
374 "dilation_y * KH")) ||
375 failed(levelCheckKernel(op, dilation[1] * shape[2],
376 "dilation_x * KW")))
377 return failure();
378 } else if (isa<tosa::Conv3DOp>(op)) {
379 assert(shape.size() == 5);
380 assert(dilation.size() == 3);
381 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
382 "dilation_d * KD")) ||
383 failed(levelCheckKernel(op, dilation[1] * shape[2],
384 "dilation_y * KH")) ||
385 failed(levelCheckKernel(op, dilation[2] * shape[3],
386 "dilation_x * KW")))
387 return failure();
388 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
389 assert(shape.size() == 4);
390 assert(dilation.size() == 2);
391 if (failed(levelCheckKernel(op, dilation[0] * shape[0],
392 "dilation_y * KH")) ||
393 failed(levelCheckKernel(op, dilation[1] * shape[1],
394 "dilation_x * KW")))
395 return failure();
396 }
397 }
398 }
399 return success();
400 }
401
402 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
403 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
404 if (!convOp)
405 return success();
406
407 SmallVector<int64_t> padValues;
408 if (tosa::getConstShapeValues(convOp.getPad().getDefiningOp(), padValues)) {
409 for (const auto p : padValues)
410 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
411 return failure();
412 }
413
414 SmallVector<int64_t> strideValues;
415 if (tosa::getConstShapeValues(convOp.getStride().getDefiningOp(),
416 strideValues)) {
417 for (const auto s : strideValues)
418 if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
419 return failure();
420 }
421
422 SmallVector<int64_t> dilationValues;
423 if (tosa::getConstShapeValues(convOp.getDilation().getDefiningOp(),
424 dilationValues)) {
425 int64_t KH = ShapedType::kDynamic;
426 int64_t KW = ShapedType::kDynamic;
427 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
428 KH = weightDataShape.getDimSize(1);
429 KW = weightDataShape.getDimSize(2);
430 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
431 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
432 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
433
434 if (!ShapedType::isDynamic(KH) &&
435 failed(levelCheckKernel(op, dilationValues[0] * KH,
436 "dilation_y * KH <= MAX_KERNEL)")))
437 return failure();
438
439 if (!ShapedType::isDynamic(KW) &&
440 failed(levelCheckKernel(op, dilationValues[1] * KW,
441 "dilation_x * KW <= MAX_KERNEL)")))
442 return failure();
443 }
444
445 return success();
446 }
447
448 // FFT op: level check H, W in input shape [N,H,W]
449 template <typename T>
450 LogicalResult levelCheckFFT(Operation *op) {
451 if (isa<T>(op)) {
452 for (auto v : op->getOperands()) {
453 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
454 auto shape = type.getShape();
455 assert(shape.size() == 3);
456 if (failed(levelCheckKernel(op, shape[1], "H")) ||
457 failed(levelCheckKernel(op, shape[2], "W"))) {
458 return failure();
459 }
460 }
461 }
462 }
463 return success();
464 }
465
466 // TransposeConv2d op: level check kH/kW, outpad, and stride
467 LogicalResult levelCheckTransposeConv2d(Operation *op) {
468 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
469 if (ShapedType filterType =
470 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
471 auto shape = filterType.getShape();
472 assert(shape.size() == 4);
473 // level check kernel sizes for kH and KW
474 if (failed(levelCheckKernel(op, shape[1], "KH")) ||
475 failed(levelCheckKernel(op, shape[2], "KW"))) {
476 return failure();
477 }
478 }
479 for (auto p : transpose.getOutPad()) {
480 if (failed(levelCheckKernel(op, p, "pad"))) {
481 return failure();
482 }
483 }
484 for (auto s : transpose.getStride()) {
485 if (failed(levelCheckStride(op, s, "stride"))) {
486 return failure();
487 }
488 }
489 }
490 return success();
491 }
492
493 // Resize op: level check max scales
494 LogicalResult levelCheckResize(Operation *op) {
495 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
496 SmallVector<int64_t> scale;
497 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
498 scale)) {
499 return failure();
500 }
501 const int64_t scaleYN = scale[0];
502 const int64_t scaleYD = scale[1];
503 const int64_t scaleXN = scale[2];
504 const int64_t scaleXD = scale[3];
505 if (failed(
506 levelCheckScale(op, scaleYN / scaleYD, "scale_y_n/scale_y_d")) ||
507 failed(
508 levelCheckScale(op, scaleXN / scaleXD, "scale_x_n/scale_x_d"))) {
509 return failure();
510 }
511 }
512 return success();
513 }
514
515 // Recursively perform a bottom-up search to determine the maximum nesting
516 // depth, starting from a specific operation and continuing up to the function
517 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
518 // time a new nested `region` is encountered.
519 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
520 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
521 return;
522
523 op = op->getParentOp();
524 if (!op)
525 return;
526
527 depth++;
528 getMaxNestedDepth(op, depth);
529 }
530
531 LogicalResult levelCheckMaxNesting(Operation *op) {
532 int32_t maxNestedDepth = 0;
533 getMaxNestedDepth(op, maxNestedDepth);
534
535 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
536 if (maxNestedDepth >= maxNestingLevel)
537 return op->emitOpError()
538 << "failed level check: tosa_nesting_depth < MAX_NESTING" << " ("
539 << maxNestingLevel << "), got " << maxNestedDepth;
540 return success();
541 }
542
543 LogicalResult levelCheckListSize(Operation *op) {
544 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
545 return levelCheckListSize(op, concat.getInput1().size(), "input1");
546 }
547 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
548 if (failed(levelCheckListSize(op, custom.getInputList().size(),
549 "input_list")) ||
550 failed(levelCheckListSize(op, custom.getOutputList().size(),
551 "output_list"))) {
552 return failure();
553 }
554 }
555 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
556 if (failed(
557 levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
558 failed(levelCheckListSize(op, condIf.getOutputList().size(),
559 "outputs"))) {
560 return failure();
561 }
562 }
563 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
564 if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
565 failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
566 return failure();
567 }
568 }
569 if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
570 return levelCheckListSize(op, concat_shape.getInput().size(), "input");
571 return success();
572 }
573
574 LogicalResult attributeCheckRescale(Operation *op) {
575 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
576 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
577 !targetEnv.allows(Extension::doubleround)) {
578 op->emitOpError()
579 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
580 << "requires extension [doubleround]";
581 return failure();
582 }
583 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
584 !targetEnv.allows(Extension::inexactround)) {
585 op->emitOpError()
586 << "failed attribute check: rounding_mode = INEXACT_ROUND "
587 << "requires extension [inexactround]";
588 return failure();
589 }
590 }
591 return success();
592 }
593
594 LogicalResult CheckVariable(Operation *op);
595 LogicalResult CheckVariableReadOrWrite(Operation *op);
596 bool isValidElementType(Type type, const bool allowUnsigned = false);
597
598 SmallVector<
599 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
600 constCheckers;
602 TosaProfileCompliance profileComp;
603 tosa::TargetEnv targetEnv;
604};
605
606template <>
607LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
608 auto *op = tosaOp.getOperation();
609 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
610 targetEnv.getLevel().MAX_RANK)))
611 return failure();
612
613 // rank(output) = rank(input) - 1
614 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
615 targetEnv.getLevel().MAX_RANK - 1)))
616 return failure();
617
618 return success();
619}
620
621template <>
622LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
623 auto *op = tosaOp.getOperation();
624
625 // Only the condition input has rank limitation.
626 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
627 targetEnv.getLevel().MAX_RANK)))
628 return failure();
629
630 return success();
631}
632
633template <>
634LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
635 auto *op = tosaOp.getOperation();
636 auto variableType = getVariableType(tosaOp);
637 if (failed(levelCheckRank(op, variableType, "variable type",
638 targetEnv.getLevel().MAX_RANK)))
639 return failure();
640
641 return success();
642}
643
644template <>
645LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
646 auto *op = tosaOp.getOperation();
647 auto variableType = getVariableType(tosaOp);
648 if (failed(levelCheckSize(op, variableType, "variable type")))
649 return failure();
650
651 return success();
652}
653
654LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
655#define CHECK_RANKS_AND_SIZES(tosaOp) \
656 if (isa<tosa::tosaOp##Op>(op)) { \
657 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
658 return failure(); \
659 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
660 return failure(); \
661 }
662
663#define CHECK_SIZES(tosaOp) \
664 if (isa<tosa::tosaOp##Op>(op)) { \
665 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
666 return failure(); \
667 }
668
669#define CHECK_SHAPE_LEN(tosaOp) \
670 if (isa<tosa::tosaOp##Op>(op)) { \
671 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
672 return failure(); \
673 }
674
675 // Tensor Operators
676 CHECK_RANKS_AND_SIZES(ArgMax);
677 // Activation Functions
680 CHECK_RANKS_AND_SIZES(Sigmoid);
682 // Elementwise Binary Operators
684 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
685 CHECK_RANKS_AND_SIZES(BitwiseAnd);
686 CHECK_RANKS_AND_SIZES(BitwiseOr);
687 CHECK_RANKS_AND_SIZES(BitwiseXor);
688 CHECK_RANKS_AND_SIZES(IntDiv);
689 CHECK_RANKS_AND_SIZES(LogicalAnd);
690 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
691 CHECK_RANKS_AND_SIZES(LogicalRightShift);
692 CHECK_RANKS_AND_SIZES(LogicalOr);
693 CHECK_RANKS_AND_SIZES(LogicalXor);
694 CHECK_RANKS_AND_SIZES(Maximum);
695 CHECK_RANKS_AND_SIZES(Minimum);
700 // Elementwise Unary Operators
702 CHECK_RANKS_AND_SIZES(BitwiseNot);
709 CHECK_RANKS_AND_SIZES(LogicalNot);
710 CHECK_RANKS_AND_SIZES(Negate);
711 CHECK_RANKS_AND_SIZES(Reciprocal);
714 // Elementwise Ternary Operators
715 CHECK_RANKS_AND_SIZES(Select);
716 // Comparison Operators
718 CHECK_RANKS_AND_SIZES(Greater);
719 CHECK_RANKS_AND_SIZES(GreaterEqual);
720 // Reduction Operators
721 CHECK_RANKS_AND_SIZES(ReduceAll);
722 CHECK_RANKS_AND_SIZES(ReduceAny);
723 CHECK_RANKS_AND_SIZES(ReduceMax);
724 CHECK_RANKS_AND_SIZES(ReduceMin);
725 CHECK_RANKS_AND_SIZES(ReduceProduct);
726 CHECK_RANKS_AND_SIZES(ReduceSum);
727 // Data Layout Operators
728 CHECK_RANKS_AND_SIZES(Concat);
730 CHECK_RANKS_AND_SIZES(Reshape);
731 CHECK_RANKS_AND_SIZES(Reverse);
734 CHECK_RANKS_AND_SIZES(Transpose);
735 // Type Conversion
737 CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
738 CHECK_RANKS_AND_SIZES(CastToBlockScaled);
739 CHECK_RANKS_AND_SIZES(Rescale);
740 // Data Nodes
742 CHECK_RANKS_AND_SIZES(Identity);
743 // Control Flow Operators
745 // Variable Operators
746 CHECK_RANKS_AND_SIZES(Variable);
747 CHECK_RANKS_AND_SIZES(VariableWrite);
748 CHECK_RANKS_AND_SIZES(VariableRead);
749 // Shape Operators
751
752 // For the following operators, check whether the size of each tensor
753 // operand is valid in a given Level.
754
755 // Tensor Operators
756 CHECK_SIZES(AvgPool2d);
757 CHECK_SIZES(Conv2D);
758 CHECK_SIZES(Conv2DBlockScaled);
759 CHECK_SIZES(Conv3D);
760 CHECK_SIZES(DepthwiseConv2D);
761 CHECK_SIZES(TransposeConv2D);
762 CHECK_SIZES(FFT2d);
763 CHECK_SIZES(MatMul);
764 CHECK_SIZES(MatmulTBlockScaled);
765 CHECK_SIZES(MaxPool2d);
766 CHECK_SIZES(RFFT2d);
767 // Scatter/Gather Operators
769 CHECK_SIZES(Scatter);
770 // Image Operators
771 CHECK_SIZES(Resize);
772 // Custom Operators
773 CHECK_SIZES(Custom);
774 // Control Flow Operators
775 CHECK_SIZES(While);
776 // Shape Operators
777 CHECK_SIZES(ConstShape);
778
779 // For the following operations, check whether the shape length of each
780 // operand is valid given a level.
781
782 // Shape Operators
783 CHECK_SHAPE_LEN(AddShape);
784 CHECK_SHAPE_LEN(AssertEqualShape);
785 CHECK_SHAPE_LEN(ConcatShape);
786 CHECK_SHAPE_LEN(DivCeilShape);
787 CHECK_SHAPE_LEN(DivFloorShape);
788 CHECK_SHAPE_LEN(Exp2Shape);
789 CHECK_SHAPE_LEN(Log2CeilShape);
790 CHECK_SHAPE_LEN(Log2FloorShape);
791 CHECK_SHAPE_LEN(MaxShape);
792 CHECK_SHAPE_LEN(MinShape);
793 CHECK_SHAPE_LEN(ModShape);
794 CHECK_SHAPE_LEN(MulShape);
795 CHECK_SHAPE_LEN(SliceShape);
796 CHECK_SHAPE_LEN(SubShape);
797
798#undef CHECK_RANKS_AND_SIZES
799#undef CHECK_SIZES
800#undef CHECK_SHAPE_LEN
801 return success();
802}
803
804// Perform the Level tensor size check on the tensor type.
805LogicalResult TosaValidation::levelCheckSize(Operation *op,
806 const Type &typeToCheck,
807 const StringRef operandOrResult) {
808 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
809 if (!type.hasRank())
810 return op->emitOpError() << "failed level check: unranked tensor";
811 auto shape = type.getShape();
812 for (auto dim : shape) {
813 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
814 const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
815 const TosaSpecificationVersion minRequiredVersion(1, 1);
816 if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
817 dimIsDynamic)
818 // TOSA 1.1 and above supports dynamic dimensions, however, they must be
819 // resolved at backend compile time. Runtime dynamism is not currently
820 // supported. Checking this requirement is met is delegated to backends.
821 return success();
822
823 // When targeting TOSA 1.0 or below, dynamic dims are not supported
824 if (dimIsDynamic)
825 return op->emitOpError() << "failed level check: " << operandOrResult
826 << " shape dimension cannot be dynamic when"
827 << " targeting TOSA specification version 1.0"
828 << " or below";
829 }
830
831 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
832 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
833 int64_t size = element_bytes * type.getNumElements();
834
835 // According to 1.11. Tensor Definitions of Tosa spec, the value of
836 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
837 // defined in 1.7. Levels.
838 // For each tensor, the number of tensor elements multiplied by the
839 // element size in bytes must be representable as a tensor_size_t.
840 const int64_t max_size =
841 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
842 if (size > max_size)
843 return op->emitOpError()
844 << "failed level check: " << operandOrResult
845 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
846 }
847 return success();
848}
849
850LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
851 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
852 // no need to do level checks
853 return success();
854 }
855
856 // check rank and sizes early so later checks can assume shaped operands
857 if (failed(levelCheckRanksAndSizes(op)))
858 return failure();
859
860 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
861 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
862 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
863 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
864 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
865 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
866 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
867 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
868 failed(levelCheckConv2DBlockScaled(op))) {
869 return failure();
870 }
871
872 // level check MAX_TENSOR_LIST_SIZE
873 if (failed(levelCheckListSize(op))) {
874 return failure();
875 }
876
877 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
878 if (failed(levelCheckMaxNesting(op))) {
879 return failure();
880 }
881 }
882
883 return success();
884}
885
886LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
887 if (failed(attributeCheckRescale(op)))
888 return failure();
889 return success();
890}
891
892inline bool CompatibleTypes(const mlir::Type &type,
893 const mlir::Type &declaredType) {
894 // for now, simply use type equality comparison
895 return type == declaredType;
896}
897
898LogicalResult TosaValidation::CheckVariable(Operation *op) {
899 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
900 mlir::StringAttr nameAttr = variableOp.getNameAttr();
901
902 if (variablesMap.count(nameAttr))
903 return op->emitOpError() << "name has already been declared";
904
905 auto elementType = variableOp.getType();
906 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
907 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
908 RankedTensorType variableType =
909 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
910
911 variablesMap[nameAttr] = variableType;
912 }
913
914 return success();
915}
916
917LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
918 if (isa<mlir::tosa::VariableReadOp>(op) ||
919 isa<mlir::tosa::VariableWriteOp>(op)) {
920 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
921 if (!variablesMap.count(nameAttr))
922 return op->emitOpError() << "name has not been declared";
923
924 auto varType = variablesMap[nameAttr];
925
926 for (auto v : op->getOperands()) {
927 auto type = v.getType();
928 if (!CompatibleTypes(type, varType))
929 return op->emitOpError() << "operand type does not equal variable type";
930 }
931
932 for (auto v : op->getResults()) {
933 auto type = v.getType();
934 if (!CompatibleTypes(type, varType))
935 return op->emitOpError() << "result type does not equal variable type";
936 }
937 }
938
939 return success();
940}
941
942LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
943 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
944 return failure();
945 return success();
946}
947
948LogicalResult checkErrorIfResize(Operation *op) {
949 auto resize = dyn_cast<tosa::ResizeOp>(op);
950 if (!resize)
951 return success();
952
953 const Value input = resize.getInput();
954 const Value output = resize.getOutput();
955 const RankedTensorType inputType =
956 llvm::dyn_cast<RankedTensorType>(input.getType());
957 const RankedTensorType outputType =
958 llvm::dyn_cast<RankedTensorType>(output.getType());
959
960 if (!inputType || !outputType)
961 return op->emitOpError("expect ranked input/output tensor");
962
963 // Ensure the image size is supported by GPU APIs and that for integer
964 // implementations, position * stride does not overflow int32_t.
965 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
966 const SmallVector<int64_t, 4> sizes = {
967 outputType.getDimSize(1), outputType.getDimSize(2),
968 inputType.getDimSize(1), inputType.getDimSize(2)};
969 const int64_t *maxDim = llvm::max_element(sizes);
970 if (maxDim != sizes.end() && *maxDim >= 16384)
971 return op->emitOpError(
972 "expect input/output height/width dims to be < 16384, ")
973 << "got [OH, OW, IH, IW] = " << sizes;
975
977 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
978 return failure();
980 const int64_t scaleYN = scale[0];
981 const int64_t scaleYD = scale[1];
982 const int64_t scaleXN = scale[2];
983 const int64_t scaleXD = scale[3];
984
985 // Ensure scale values don't overflow int32 accumulator
986 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
987 return op->emitOpError(
988 "expect all scale numerator values to be <= (1 << 11), "
989 "got scale_y_n=")
990 << scaleYN << ", scale_x_n=" << scaleXN;
992 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
993 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
994 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
995
998 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
999 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
1000 return failure();
1001
1002 const int64_t offsetY = offset[0];
1003 const int64_t offsetX = offset[1];
1004 // Set a consistent lower limit of 1/16 downscale to simplify
1005 // implementations
1006 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1007 return op->emitOpError(
1008 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1009 << offsetY << "/" << scaleYN;
1010 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1011 return op->emitOpError(
1012 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1013 << offsetX << "/" << scaleXN;
1014
1015 const int64_t borderY = border[0];
1016 const int64_t borderX = border[1];
1017 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1018 return op->emitOpError(
1019 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1020 << borderY << "/" << scaleYN;
1021 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1022 return op->emitOpError(
1023 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1024 << borderX << "/" << scaleXN;
1026 // The following section of code is mostly duplicated with ResizeOp::verify().
1027 //
1028 // In TOSA specification, we do not support broadcast behavior.
1029 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
1030 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
1031 // existing code, we keep the rewrite pattern untouched. So, we need
1032 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
1033 //
1034 // Here is a strict checking to conform TOSA specification.
1035 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
1036 auto idivCheck = [](const int64_t lhs,
1037 const int64_t rhs) -> std::optional<int64_t> {
1038 if (lhs % rhs != 0)
1039 return std::nullopt;
1040 return lhs / rhs;
1041 };
1043 const int64_t oh = outputType.getDimSize(1);
1044 const int64_t ow = outputType.getDimSize(2);
1045 const int64_t ih = inputType.getDimSize(1);
1046 const int64_t iw = inputType.getDimSize(2);
1047
1048 if (ih != ShapedType::kDynamic) {
1049 const std::optional<int64_t> calculatedOutHeightMinusOne =
1050 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1051 if (!calculatedOutHeightMinusOne.has_value())
1052 return op->emitOpError(
1053 "expected (input_height - 1) * scale_y_n - offset_y + "
1054 "border_y ")
1055 << "to be wholly divisible by scale_y_d, got ((" << ih
1056 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1057 << ") / " << scaleYD;
1058 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1059 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1060 return op->emitOpError(
1061 "calculated output height did not match expected: ")
1062 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1063 }
1064
1065 if (iw != ShapedType::kDynamic) {
1066 const std::optional<int64_t> calculatedOutWidthMinusOne =
1067 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1068 if (!calculatedOutWidthMinusOne.has_value())
1069 return op->emitOpError(
1070 "expected (input_width - 1) * scale_x_n - offset_x + "
1071 "border_x ")
1072 << "to be wholly divisible by scale_x_d, got ((" << iw
1073 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1074 << ") / " << scaleXD;
1075 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1076 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1077 return op->emitOpError("calculated output width did not match expected: ")
1078 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1079 }
1080
1081 return success();
1082}
1083
1084LogicalResult checkErrorIfMul(Operation *op) {
1085 auto mul = dyn_cast<tosa::MulOp>(op);
1086 if (!mul)
1087 return success();
1088
1089 // REQUIRE(0 <= shift && shift <= 63);
1090 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1091 ElementsAttr shift_elem;
1092 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
1093 return success();
1094 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1095 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1096 if (inputElemType.isInteger(32)) {
1097 // 0 <= shift <= 63 for int32_t type
1098 if (shift < 0 || shift > 63)
1099 return op->emitOpError()
1100 << "requires 0 <= shift && shift <= 63, but got: " << shift;
1101 } else {
1102 // shift must be 0 for all other types
1103 if (shift != 0)
1104 return op->emitOpError()
1105 << "requires shift = 0 for all input data types that "
1106 "are not int32_t, but got: "
1107 << shift;
1108 }
1109
1110 return success();
1111}
1112
1113LogicalResult checkErrorIfTable(Operation *op) {
1114 auto table = dyn_cast<tosa::TableOp>(op);
1115 if (!table)
1116 return success();
1117
1118 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1119 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1120 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1121
1122 const ShapeAdaptor tableShape(table.getTable().getType());
1123 if (tableShape.hasStaticShape()) {
1124 const auto numElements = tableShape.getNumElements();
1125 if (numElements != tableSize)
1126 return op->emitOpError() << "requires table size of " << tableSize
1127 << ", got " << numElements;
1128 }
1129
1130 return success();
1131}
1132
1133LogicalResult checkErrorIfRescale(Operation *op) {
1134 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1135 if (!rescale)
1136 return success();
1137
1138 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1139 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1140 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1141 !outputType.getElementType().isInteger())
1142 return success();
1143
1144 auto inElemType = inputType.getElementType();
1145 auto outElemType = outputType.getElementType();
1146 auto inWidth = inElemType.getIntOrFloatBitWidth();
1147 auto outWidth = outElemType.getIntOrFloatBitWidth();
1148
1149 bool inputUnsigned = rescale.getInputUnsigned();
1150 bool outputUnsigned = rescale.getOutputUnsigned();
1151
1152 bool scale32 = rescale.getScale32();
1153 auto roundingMode = rescale.getRoundingMode();
1154
1155 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1156 if (scale32 && inWidth == 48)
1157 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1158
1159 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1160 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1161 return op->emitOpError()
1162 << "DOUBLE_ROUND is only allowed with scale32=true.";
1163
1164 // ERROR_IF(input_unsigned && output_unsigned)
1165 if (inputUnsigned && outputUnsigned)
1166 return op->emitOpError() << "input and output cannot be both unsigned.";
1167
1168 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1169 if (outWidth == 32 && inputUnsigned)
1170 return op->emitOpError()
1171 << "i32 output type is not allowed with unsigned input.";
1172
1173 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1174 if (inWidth == 32 && outputUnsigned)
1175 return op->emitOpError()
1176 << "i32 input type is not allowed with unsigned output.";
1177
1178 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1179 if (inWidth == 48 && outputUnsigned)
1180 return op->emitOpError()
1181 << "i48 input type is not allowed with unsigned output.";
1182
1183 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1184 if (inWidth == 48 && inputUnsigned)
1185 return op->emitOpError() << "i48 input type cannot be unsigned.";
1186
1187 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1188 if (inWidth == 32 && inputUnsigned)
1189 return op->emitOpError() << "i32 input type cannot be unsigned.";
1190
1191 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1192 if (outWidth == 32 && outputUnsigned)
1193 return op->emitOpError() << "i32 output type cannot be unsigned.";
1194
1195 return success();
1196}
1197
1198LogicalResult checkErrorIfPad(Operation *op) {
1199 auto pad = dyn_cast<tosa::PadOp>(op);
1200 if (!pad)
1201 return success();
1202
1203 DenseIntElementsAttr paddingAttr;
1204 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1205 // Pad verifier will catch this
1206 return success();
1207
1208 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1209 if (val.getSExtValue() < 0)
1210 return op->emitOpError() << "padding value must all be non-negative, got "
1211 << val.getSExtValue();
1212 }
1213
1214 return success();
1215}
1216
1217static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1218 return llvm::all_of(op->getOperands(), [&](auto operand) {
1219 Region *operandRegion = operand.getParentRegion();
1220 return operandRegion && region->isAncestor(operandRegion);
1221 });
1222}
1223
1224static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1225 bool noLiveInValue = true;
1226 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1227 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1228 noLiveInValue = false;
1229 return WalkResult::interrupt();
1230 }
1231 return WalkResult::advance();
1232 });
1233 return noLiveInValue ? success() : failure();
1234}
1235
1236LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1237 StringRef regionName) {
1238 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1239 return success();
1240 return op->emitOpError()
1241 << "is not conformant to the TOSA specification. It requires the '"
1242 << regionName << "' region is isolated from above.\n";
1243}
1244
1245LogicalResult checkErrorIfCondIf(Operation *op) {
1246 auto ifOp = dyn_cast<tosa::IfOp>(op);
1247 if (!ifOp)
1248 return success();
1249
1250 // Currently the dialect supports declaring cond_if operations that
1251 // have then/else regions that reference values from outside these
1252 // regions. According to the specification, all values used by the
1253 // then/else regions must be explicitly declared within the regions.
1254 // Therefore we must check that the then/else regions are
1255 // "isolated from above", in order to be conformant to the
1256 // specification.
1257 //
1258 // Note: the dialect currently supports two styles of syntax for
1259 // declaring "cond_if" operations. We'll refer to these as follows:
1260 //
1261 // Generic:
1262 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1263 // ^bb0(%arg3, %arg4):
1264 // tosa.yield %arg3
1265 // }, {
1266 // ^bb0(%arg3, %arg4):
1267 // tosa.yield %arg4
1268 // })
1269 //
1270 // Simplified:
1271 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1272 // ^bb0(%arg3, %arg4):
1273 // tosa.yield %arg3
1274 // } else {
1275 // ^bb0(%arg3, %arg4):
1276 // tosa.yield %arg4
1277 // }
1278
1279 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1280 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1281 return failure();
1282 return success();
1283}
1284
1285LogicalResult checkErrorIfWhileLoop(Operation *op) {
1286 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1287 if (!whileOp)
1288 return success();
1289
1290 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1291 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1292 return failure();
1293 return success();
1294}
1295
1296LogicalResult checkErrorIfScatter(Operation *op) {
1297 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1298 if (!scatterOp)
1299 return success();
1300
1301 // for constant indices, check that there are no duplicate values
1302 DenseIntElementsAttr indicesAttr;
1303 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1304 return success();
1305
1306 auto const indicesType =
1307 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1308 if (!indicesType || !indicesType.hasRank()) {
1309 op->emitOpError("expect ranked indices tensor");
1310 return failure();
1311 }
1312
1313 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1314 op->emitOpError("indices values contain duplicates");
1315 return failure();
1316 }
1317
1318 return success();
1319}
1320
1321LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1322 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1323 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1324 failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
1325 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1326 return failure();
1327 return success();
1328}
1329
1330bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1331 if (isa<FloatType>(type)) {
1332 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1333 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1334 Float6E3M2FNType, Float8E8M0FNUType>(type);
1335 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1336 if (intTy.isSignless()) {
1337 switch (intTy.getWidth()) {
1338 case 1:
1339 case 4:
1340 case 8:
1341 case 16:
1342 case 32:
1343 case 48:
1344 case 64:
1345 return true;
1346 }
1347 } else if (allowUnsigned && intTy.isUnsigned()) {
1348 switch (intTy.getWidth()) {
1349 case 8:
1350 case 16:
1351 case 32:
1352 return true;
1353 }
1354 }
1355 } else if (isa<tosa::shapeType>(type))
1356 return true;
1357 else if (isa<tosa::mxint8Type>(type))
1358 return true;
1359 return false;
1360}
1361
1362void TosaValidation::runOnOperation() {
1363 ModuleOp modOp = getOperation();
1364 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1365 if (!tosaDialect)
1366 return;
1367
1368 const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
1369 const auto maybeTargetEnv =
1370 tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
1371 if (failed(maybeTargetEnv))
1372 return signalPassFailure();
1373 targetEnv = *maybeTargetEnv;
1374
1375 modOp.walk([&](Operation *op) {
1376 if (op->getDialect() != tosaDialect)
1377 return;
1378
1379 // validate operator element types:
1380 // - rescale operator is allowed to have ui8/ui16/ui32
1381 // operands/results when strictOpSpecAlignment is false
1382 // - perform valid element type check at the beginning to
1383 // protect rest of code against quantized element types
1384 const bool allowUnsigned =
1385 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1386 for (Value operand : op->getOperands()) {
1387 auto elementTy = getElementTypeOrSelf(operand);
1388 if (!isValidElementType(elementTy, allowUnsigned)) {
1389 op->emitOpError() << "is not profile-aligned: element type "
1390 << elementTy << " is not legal";
1391 return signalPassFailure();
1392 }
1393 }
1394 for (Type resultTy : op->getResultTypes()) {
1395 auto elementTy = getElementTypeOrSelf(resultTy);
1396 if (!isValidElementType(elementTy, allowUnsigned)) {
1397 op->emitOpError() << "is not profile-aligned: element type "
1398 << elementTy << " is not legal";
1399 return signalPassFailure();
1400 }
1401 }
1402
1403 if (strictOpSpecAlignment &&
1404 failed(profileComp.checkProfile(op, targetEnv)))
1405 return signalPassFailure();
1406
1407 if (strictOpSpecAlignment &&
1408 failed(profileComp.checkExtension(op, targetEnv)))
1409 return signalPassFailure();
1410
1411 if (!allowInvalidOpDatatypeCombinations &&
1412 failed(profileComp.checkInvalid(op)))
1413 return signalPassFailure();
1414
1415 // Some uses of TOSA rely on the constant operands of particular
1416 // operations.
1417 if (failed(applyConstantOperandCheck(op)))
1418 signalPassFailure();
1419
1420 // do level checks
1421 if (failed(applyLevelCheck(op)))
1422 signalPassFailure();
1423
1424 // check additional attribute restrictions
1425 if (failed(applyAttributeCheck(op)))
1426 signalPassFailure();
1427
1428 // do variable type checks
1429 if (failed(applyVariableCheck(op)))
1430 signalPassFailure();
1431
1432 // do error if checks
1433 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1434 signalPassFailure();
1435 });
1436}
1437} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:568
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:220
Value getOperand(unsigned idx)
Definition Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback provided.
Definition Region.h:285
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile, extension, and level.
Definition TargetEnv.h:100
TosaLevel getLevel() const
Definition TargetEnv.h:117
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:65
bool allows(Profile prof) const
Definition TargetEnv.h:127
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:113
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:67
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:45
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:620
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369