MLIR 23.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Validate if TOSA dialect input matches with the specification for given
10// requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Matchers.h"
27#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/StringExtras.h"
30
31namespace mlir {
32namespace tosa {
33#define GEN_PASS_DEF_TOSAVALIDATION
34#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
35} // namespace tosa
36} // namespace mlir
37
38using namespace mlir;
39using namespace mlir::tosa;
40
41namespace {
42
43static LogicalResult
44checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
45 for (const auto index : operandIndices) {
46 Attribute attr;
47 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
48 return op->emitOpError("expected compile time resolvable constant, but "
49 "got variable value for operand #")
50 << index;
51 }
52 }
53 return success();
54}
55
56static LogicalResult checkConstantOperandMul(Operation *op,
57 const TargetEnv &env) {
58 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
59 // Check 'shift'
60 return checkConstantOperands(op, {2});
61 }
62 return success();
63}
64
65static LogicalResult checkConstantOperandTable(Operation *op,
66 const TargetEnv &env) {
67 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
68 // Check 'table'
69 return checkConstantOperands(op, {1});
70 }
71 return success();
72}
73
74static LogicalResult checkConstantOperandPad(Operation *op,
75 const TargetEnv &env) {
76 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
77 // Assume this op is zero-padding if padConst is not presented
78 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
79 // Check 'pad_const'
80 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
81 return checkConstantOperands(op, {2});
82 }
83 return success();
84}
85
86static LogicalResult checkConstantOperandRescale(Operation *op,
87 const TargetEnv &env) {
88 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
89 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
90 return checkConstantOperands(op, {1, 2, 3, 4});
91 }
92 return success();
93}
94
95template <typename T>
96static LogicalResult checkConstantOperandConvOps(Operation *op,
97 const TargetEnv &env) {
98 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
99 // Check 'input_zp' and 'weight_zp'
100 return checkConstantOperands(op, {3, 4});
101 }
102 return success();
103}
104
105static LogicalResult checkConstantOperandMatMul(Operation *op,
106 const TargetEnv &env) {
107 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
108 // Check 'A_zp' and 'B_zp'
109 return checkConstantOperands(op, {2, 3});
110 }
111 return success();
112}
113
114static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
115 const TargetEnv &env) {
116 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
117 // Check 'input_zp' and 'output_zp'
118 return checkConstantOperands(op, {1, 2});
119 }
120 return success();
121}
122
123static LogicalResult checkConstantOperandNegate(Operation *op,
124 const TargetEnv &env) {
125 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
126 // Check 'input1_zp' and 'output_zp'
127 return checkConstantOperands(op, {1, 2});
128 }
129 return success();
130}
131
132static LogicalResult checkConstantOperandSilceShape(Operation *op,
133 const TargetEnv &env) {
134 if (!env.allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
135 // Check 'start' and 'size'
136 return checkConstantOperands(op, {1, 2});
137 }
138 return success();
139}
140
141//===----------------------------------------------------------------------===//
142// TOSA Validation Pass.
143//===----------------------------------------------------------------------===//
144
145struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
146public:
147 explicit TosaValidation() { populateConstantOperandChecks(); }
148
149 explicit TosaValidation(const TosaValidationOptions &options)
150 : TosaValidation() {
151 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
152 this->allowInvalidOpDatatypeCombinations =
153 options.allowInvalidOpDatatypeCombinations;
154 }
155 void runOnOperation() final;
156
157 LogicalResult applyConstantOperandCheck(Operation *op) {
158 for (auto &checker : constCheckers) {
159 if (failed(checker(op, targetEnv)))
160 return failure();
161 }
162 return success();
163 }
164
165 LogicalResult applyLevelCheck(Operation *op);
166 LogicalResult applyAttributeCheck(Operation *op);
167
168 // check variable read/write data types against variable declarations
169 LogicalResult applyVariableCheck(Operation *op);
170
171 // check error if conditions
172 LogicalResult applyErrorIfCheck(Operation *op);
173
174private:
175 void populateConstantOperandChecks() {
176 constCheckers.emplace_back(checkConstantOperandMul);
177 constCheckers.emplace_back(checkConstantOperandTable);
178 constCheckers.emplace_back(checkConstantOperandPad);
179 constCheckers.emplace_back(checkConstantOperandRescale);
180 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
181 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
182 constCheckers.emplace_back(
183 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
184 constCheckers.emplace_back(
185 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
186 constCheckers.emplace_back(checkConstantOperandMatMul);
187 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
188 constCheckers.emplace_back(checkConstantOperandNegate);
189 constCheckers.emplace_back(checkConstantOperandSilceShape);
190 }
191
192 LogicalResult levelCheckKernel(Operation *op, int32_t v,
193 const StringRef checkDesc) {
194 if (v > targetEnv.getLevel().MAX_KERNEL)
195 return op->emitOpError() << "failed level check: " << checkDesc;
196 return success();
197 }
198
199 LogicalResult levelCheckStride(Operation *op, int32_t v,
200 const StringRef checkDesc) {
201 if (v > targetEnv.getLevel().MAX_STRIDE)
202 return op->emitOpError() << "failed level check: " << checkDesc;
203 return success();
204 }
205
206 LogicalResult levelCheckScale(Operation *op, int32_t v,
207 const StringRef checkDesc) {
208 if (v > targetEnv.getLevel().MAX_SCALE)
209 return op->emitOpError() << "failed level check: " << checkDesc;
210 return success();
211 }
212
213 LogicalResult levelCheckListSize(Operation *op, int32_t v,
214 const StringRef checkDesc) {
215 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
216 return op->emitOpError()
217 << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
218 return success();
219 }
220
221 // Perform the Level Rank check on the tensor type.
222 LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
223 const StringRef operandOrResult,
224 int32_t highest_rank) {
225 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
226 if (!type.hasRank())
227 return op->emitOpError() << "failed level check: unranked tensor";
228 if (type.getRank() > highest_rank)
229 return op->emitOpError() << "failed level check: " << operandOrResult
230 << " rank(shape) <= MAX_RANK";
231 }
232 return success();
233 }
234
235 // Perform the Level Rank check on the tensor value.
236 LogicalResult levelCheckRank(Operation *op, const Value &v,
237 const StringRef operandOrResult,
238 int32_t highest_rank) {
239 return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
240 }
241
242 // Perform the Level tensor size check on the tensor type.
243 LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
244 const StringRef operandOrResult);
245
246 // Perform the Level tensor size check on the tensor value.
247 LogicalResult levelCheckSize(Operation *op, const Value &v,
248 const StringRef operandOrResult) {
249 return levelCheckSize(op, v.getType(), operandOrResult);
250 }
251
252 // Perform the Level shape length check on a value.
253 LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
254 const StringRef operandOrResult) {
255 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
256 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
257 return op->emitOpError()
258 << "failed shape type level check: " << typeToCheck
259 << " exceeds MAX_SHAPE_LEN";
260 }
261 return success();
262 }
263
264 // Level check sizes of all operands and results of the operation.
265 template <typename T>
266 LogicalResult levelCheckSizes(T tosaOp) {
267 auto op = tosaOp.getOperation();
268 for (auto v : op->getOperands()) {
269 if (failed(levelCheckSize(op, v, "operand")))
270 return failure();
271 }
272
273 for (auto v : op->getResults()) {
274 if (failed(levelCheckSize(op, v, "result")))
275 return failure();
276 }
277 return success();
278 }
279
280 // Level check ranks of all operands, attribute and results of the operation.
281 template <typename T>
282 LogicalResult levelCheckRanks(T tosaOp) {
283 auto op = tosaOp.getOperation();
284 const TosaLevel tosaLevel = targetEnv.getLevel();
285 for (auto v : op->getOperands()) {
286 if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
287 return failure();
288 }
289
290 for (auto v : op->getResults()) {
291 if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
292 return failure();
293 }
294 return success();
295 }
296 // Level check shape lengths of all operands and results of an operation that
297 // are tosa.shape type.
298 template <typename T>
299 LogicalResult levelCheckShapeLengths(T tosaOp) {
300 for (const auto &v : tosaOp->getOperands()) {
301 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
302 return failure();
303 }
304 for (const auto &v : tosaOp->getResults()) {
305 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
306 return failure();
307 }
308
309 return success();
310 }
311
312 // Level check ranks and sizes.
313 LogicalResult levelCheckRanksAndSizes(Operation *op);
314
315 // Pool Op: level check kernel/stride/pad values
316 template <typename T>
317 LogicalResult levelCheckPool(Operation *op) {
318 if (auto poolOp = dyn_cast<T>(op)) {
319 for (auto k : poolOp.getKernel()) {
320 if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
321 return failure();
322 }
323 }
324 for (auto s : poolOp.getStride()) {
325 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
326 return failure();
327 }
328 }
329 for (auto p : poolOp.getPad()) {
330 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
331 return failure();
332 }
333 }
334 }
335 return success();
336 }
337
338 // Conv Op: level check dilation/stride/pad values
339 template <typename T>
340 LogicalResult levelCheckConv(Operation *op) {
341 if (auto convOp = dyn_cast<T>(op)) {
342
343 for (auto k : convOp.getDilation()) {
344 if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
345 return failure();
346 }
347 }
348 for (auto p : convOp.getPad()) {
349 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
350 return failure();
351 }
352 }
353 for (auto s : convOp.getStride()) {
354 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
355 return failure();
356 }
357 }
358 auto dilation = convOp.getDilation();
359 if (ShapedType weightType =
360 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
361 auto shape = weightType.getShape();
362 if (isa<tosa::Conv2DOp>(op)) {
363 assert(shape.size() == 4);
364 assert(dilation.size() == 2);
365 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
366 "dilation_y * KH <= MAX_KERNEL)")) ||
367 failed(levelCheckKernel(op, dilation[1] * shape[2],
368 "dilation_x * KW <= MAX_KERNEL)")))
369 return failure();
370 } else if (isa<tosa::Conv3DOp>(op)) {
371 assert(shape.size() == 5);
372 assert(dilation.size() == 3);
373 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
374 "dilation_d * KD <= MAX_KERNEL)")) ||
375 failed(levelCheckKernel(op, dilation[1] * shape[2],
376 "dilation_y * KH <= MAX_KERNEL)")) ||
377 failed(levelCheckKernel(op, dilation[2] * shape[3],
378 "dilation_x * KW <= MAX_KERNEL)")))
379 return failure();
380 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
381 assert(shape.size() == 4);
382 assert(dilation.size() == 2);
383 if (failed(levelCheckKernel(op, dilation[0] * shape[0],
384 "dilation_y * KH <= MAX_KERNEL)")) ||
385 failed(levelCheckKernel(op, dilation[1] * shape[1],
386 "dilation_x * KW <= MAX_KERNEL)")))
387 return failure();
388 }
389 }
390 }
391 return success();
392 }
393
394 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
395 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
396 if (!convOp)
397 return success();
398
399 SmallVector<int64_t> padValues;
400 if (tosa::getConstShapeValues(convOp.getPad().getDefiningOp(), padValues)) {
401 for (const auto p : padValues)
402 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
403 return failure();
404 }
405
406 SmallVector<int64_t> strideValues;
407 if (tosa::getConstShapeValues(convOp.getStride().getDefiningOp(),
408 strideValues)) {
409 for (const auto s : strideValues)
410 if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
411 return failure();
412 }
413
414 SmallVector<int64_t> dilationValues;
415 if (tosa::getConstShapeValues(convOp.getDilation().getDefiningOp(),
416 dilationValues)) {
417 int64_t KH = ShapedType::kDynamic;
418 int64_t KW = ShapedType::kDynamic;
419 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
420 KH = weightDataShape.getDimSize(1);
421 KW = weightDataShape.getDimSize(2);
422 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
423 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
424 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
425
426 if (!ShapedType::isDynamic(KH) &&
427 failed(levelCheckKernel(op, dilationValues[0] * KH,
428 "dilation_y * KH <= MAX_KERNEL)")))
429 return failure();
430
431 if (!ShapedType::isDynamic(KW) &&
432 failed(levelCheckKernel(op, dilationValues[1] * KW,
433 "dilation_x * KW <= MAX_KERNEL)")))
434 return failure();
435 }
436
437 return success();
438 }
439
440 // FFT op: level check H, W in input shape [N,H,W]
441 template <typename T>
442 LogicalResult levelCheckFFT(Operation *op) {
443 if (isa<T>(op)) {
444 for (auto v : op->getOperands()) {
445 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
446 auto shape = type.getShape();
447 assert(shape.size() == 3);
448 if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
449 failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
450 return failure();
451 }
452 }
453 }
454 }
455 return success();
456 }
457
458 // TransposeConv2d op: level check kH/kW, outpad, and stride
459 LogicalResult levelCheckTransposeConv2d(Operation *op) {
460 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
461 if (ShapedType filterType =
462 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
463 auto shape = filterType.getShape();
464 assert(shape.size() == 4);
465 // level check kernel sizes for kH and KW
466 if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
467 failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
468 return failure();
469 }
470 }
471 for (auto p : transpose.getOutPad()) {
472 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
473 return failure();
474 }
475 }
476 for (auto s : transpose.getStride()) {
477 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
478 return failure();
479 }
480 }
481 }
482 return success();
483 }
484
485 // Resize op: level check max scales
486 LogicalResult levelCheckResize(Operation *op) {
487 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
488 SmallVector<int64_t> scale;
489 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
490 scale)) {
491 return failure();
492 }
493 const int64_t scaleYN = scale[0];
494 const int64_t scaleYD = scale[1];
495 const int64_t scaleXN = scale[2];
496 const int64_t scaleXD = scale[3];
497 if (failed(levelCheckScale(op, scaleYN / scaleYD,
498 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
499 failed(levelCheckScale(op, scaleXN / scaleXD,
500 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
501 return failure();
502 }
503 }
504 return success();
505 }
506
507 // Recursively perform a bottom-up search to determine the maximum nesting
508 // depth, starting from a specific operation and continuing up to the function
509 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
510 // time a new nested `region` is encountered.
511 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
512 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
513 return;
514
515 op = op->getParentOp();
516 if (!op)
517 return;
518
519 depth++;
520 getMaxNestedDepth(op, depth);
521 }
522
523 LogicalResult levelCheckMaxNesting(Operation *op) {
524 int32_t maxNestedDepth = 0;
525 getMaxNestedDepth(op, maxNestedDepth);
526
527 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
528 op->emitOpError() << "failed level check: " << maxNestedDepth
529 << " >= MAX_NESTING";
530 return failure();
531 }
532 return success();
533 }
534
535 LogicalResult levelCheckListSize(Operation *op) {
536 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
537 return levelCheckListSize(op, concat.getInput1().size(), "input1");
538 }
539 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
540 if (failed(levelCheckListSize(op, custom.getInputList().size(),
541 "input_list")) ||
542 failed(levelCheckListSize(op, custom.getOutputList().size(),
543 "output_list"))) {
544 return failure();
545 }
546 }
547 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
548 if (failed(
549 levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
550 failed(levelCheckListSize(op, condIf.getOutputList().size(),
551 "outputs"))) {
552 return failure();
553 }
554 }
555 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
556 if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
557 failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
558 return failure();
559 }
560 }
561 if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
562 return levelCheckListSize(op, concat_shape.getInput().size(), "input");
563 return success();
564 }
565
566 LogicalResult attributeCheckRescale(Operation *op) {
567 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
568 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
569 !targetEnv.allows(Extension::doubleround)) {
570 op->emitOpError()
571 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
572 << "requires extension [doubleround]";
573 return failure();
574 }
575 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
576 !targetEnv.allows(Extension::inexactround)) {
577 op->emitOpError()
578 << "failed attribute check: rounding_mode = INEXACT_ROUND "
579 << "requires extension [inexactround]";
580 return failure();
581 }
582 }
583 return success();
584 }
585
586 LogicalResult CheckVariable(Operation *op);
587 LogicalResult CheckVariableReadOrWrite(Operation *op);
588 bool isValidElementType(Type type, const bool allowUnsigned = false);
589
590 SmallVector<
591 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
592 constCheckers;
594 TosaProfileCompliance profileComp;
595 tosa::TargetEnv targetEnv;
596};
597
598template <>
599LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
600 auto *op = tosaOp.getOperation();
601 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
602 targetEnv.getLevel().MAX_RANK)))
603 return failure();
604
605 // rank(output) = rank(input) - 1
606 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
607 targetEnv.getLevel().MAX_RANK - 1)))
608 return failure();
609
610 return success();
611}
612
613template <>
614LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
615 auto *op = tosaOp.getOperation();
616
617 // Only the condition input has rank limitation.
618 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
619 targetEnv.getLevel().MAX_RANK)))
620 return failure();
621
622 return success();
623}
624
625template <>
626LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
627 auto *op = tosaOp.getOperation();
628 auto variableType = getVariableType(tosaOp);
629 if (failed(levelCheckRank(op, variableType, "variable type",
630 targetEnv.getLevel().MAX_RANK)))
631 return failure();
632
633 return success();
634}
635
636template <>
637LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
638 auto *op = tosaOp.getOperation();
639 auto variableType = getVariableType(tosaOp);
640 if (failed(levelCheckSize(op, variableType, "variable type")))
641 return failure();
642
643 return success();
644}
645
646LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
647#define CHECK_RANKS_AND_SIZES(tosaOp) \
648 if (isa<tosa::tosaOp##Op>(op)) { \
649 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
650 return failure(); \
651 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
652 return failure(); \
653 }
654
655#define CHECK_SIZES(tosaOp) \
656 if (isa<tosa::tosaOp##Op>(op)) { \
657 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
658 return failure(); \
659 }
660
661#define CHECK_SHAPE_LEN(tosaOp) \
662 if (isa<tosa::tosaOp##Op>(op)) { \
663 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
664 return failure(); \
665 }
666
667 // Tensor Operators
668 CHECK_RANKS_AND_SIZES(ArgMax);
669 // Activation Functions
672 CHECK_RANKS_AND_SIZES(Sigmoid);
674 // Elementwise Binary Operators
676 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
677 CHECK_RANKS_AND_SIZES(BitwiseAnd);
678 CHECK_RANKS_AND_SIZES(BitwiseOr);
679 CHECK_RANKS_AND_SIZES(BitwiseXor);
680 CHECK_RANKS_AND_SIZES(IntDiv);
681 CHECK_RANKS_AND_SIZES(LogicalAnd);
682 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
683 CHECK_RANKS_AND_SIZES(LogicalRightShift);
684 CHECK_RANKS_AND_SIZES(LogicalOr);
685 CHECK_RANKS_AND_SIZES(LogicalXor);
686 CHECK_RANKS_AND_SIZES(Maximum);
687 CHECK_RANKS_AND_SIZES(Minimum);
692 // Elementwise Unary Operators
694 CHECK_RANKS_AND_SIZES(BitwiseNot);
701 CHECK_RANKS_AND_SIZES(LogicalNot);
702 CHECK_RANKS_AND_SIZES(Negate);
703 CHECK_RANKS_AND_SIZES(Reciprocal);
706 // Elementwise Ternary Operators
707 CHECK_RANKS_AND_SIZES(Select);
708 // Comparison Operators
710 CHECK_RANKS_AND_SIZES(Greater);
711 CHECK_RANKS_AND_SIZES(GreaterEqual);
712 // Reduction Operators
713 CHECK_RANKS_AND_SIZES(ReduceAll);
714 CHECK_RANKS_AND_SIZES(ReduceAny);
715 CHECK_RANKS_AND_SIZES(ReduceMax);
716 CHECK_RANKS_AND_SIZES(ReduceMin);
717 CHECK_RANKS_AND_SIZES(ReduceProduct);
718 CHECK_RANKS_AND_SIZES(ReduceSum);
719 // Data Layout Operators
720 CHECK_RANKS_AND_SIZES(Concat);
722 CHECK_RANKS_AND_SIZES(Reshape);
723 CHECK_RANKS_AND_SIZES(Reverse);
726 CHECK_RANKS_AND_SIZES(Transpose);
727 // Type Conversion
729 CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
730 CHECK_RANKS_AND_SIZES(CastToBlockScaled);
731 CHECK_RANKS_AND_SIZES(Rescale);
732 // Data Nodes
734 CHECK_RANKS_AND_SIZES(Identity);
735 // Control Flow Operators
737 // Variable Operators
738 CHECK_RANKS_AND_SIZES(Variable);
739 CHECK_RANKS_AND_SIZES(VariableWrite);
740 CHECK_RANKS_AND_SIZES(VariableRead);
741 // Shape Operators
743
744 // For the following operators, check whether the size of each tensor
745 // operand is valid in a given Level.
746
747 // Tensor Operators
748 CHECK_SIZES(AvgPool2d);
749 CHECK_SIZES(Conv2D);
750 CHECK_SIZES(Conv2DBlockScaled);
751 CHECK_SIZES(Conv3D);
752 CHECK_SIZES(DepthwiseConv2D);
753 CHECK_SIZES(TransposeConv2D);
754 CHECK_SIZES(FFT2d);
755 CHECK_SIZES(MatMul);
756 CHECK_SIZES(MatmulTBlockScaled);
757 CHECK_SIZES(MaxPool2d);
758 CHECK_SIZES(RFFT2d);
759 // Scatter/Gather Operators
761 CHECK_SIZES(Scatter);
762 // Image Operators
763 CHECK_SIZES(Resize);
764 // Custom Operators
765 CHECK_SIZES(Custom);
766 // Control Flow Operators
767 CHECK_SIZES(While);
768 // Shape Operators
769 CHECK_SIZES(ConstShape);
770
771 // For the following operations, check whether the shape length of each
772 // operand is valid given a level.
773
774 // Shape Operators
775 CHECK_SHAPE_LEN(AddShape);
776 CHECK_SHAPE_LEN(AssertEqualShape);
777 CHECK_SHAPE_LEN(ConcatShape);
778 CHECK_SHAPE_LEN(DivCeilShape);
779 CHECK_SHAPE_LEN(DivFloorShape);
780 CHECK_SHAPE_LEN(Exp2Shape);
781 CHECK_SHAPE_LEN(Log2CeilShape);
782 CHECK_SHAPE_LEN(Log2FloorShape);
783 CHECK_SHAPE_LEN(MaxShape);
784 CHECK_SHAPE_LEN(MinShape);
785 CHECK_SHAPE_LEN(ModShape);
786 CHECK_SHAPE_LEN(MulShape);
787 CHECK_SHAPE_LEN(SliceShape);
788 CHECK_SHAPE_LEN(SubShape);
789
790#undef CHECK_RANKS_AND_SIZES
791#undef CHECK_SIZES
792#undef CHECK_SHAPE_LEN
793 return success();
794}
795
796// Perform the Level tensor size check on the tensor type.
797LogicalResult TosaValidation::levelCheckSize(Operation *op,
798 const Type &typeToCheck,
799 const StringRef operandOrResult) {
800 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
801 if (!type.hasRank())
802 return op->emitOpError() << "failed level check: unranked tensor";
803 auto shape = type.getShape();
804 for (auto dim : shape) {
805 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
806 const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
807 const TosaSpecificationVersion minRequiredVersion(1, 1);
808 if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
809 dimIsDynamic)
810 // TOSA 1.1 and above supports dynamic dimensions, however, they must be
811 // resolved at backend compile time. Runtime dynamism is not currently
812 // supported. Checking this requirement is met is delegated to backends.
813 return success();
814
815 // When targeting TOSA 1.0 or below, dynamic dims are not supported
816 if (dimIsDynamic)
817 return op->emitOpError() << "failed level check: " << operandOrResult
818 << " shape dimension cannot be dynamic when"
819 << " targeting TOSA specification version 1.0"
820 << " or below";
821 }
822
823 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
824 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
825 int64_t size = element_bytes * type.getNumElements();
826
827 // According to 1.11. Tensor Definitions of Tosa spec, the value of
828 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
829 // defined in 1.7. Levels.
830 // For each tensor, the number of tensor elements multiplied by the
831 // element size in bytes must be representable as a tensor_size_t.
832 const int64_t max_size =
833 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
834 if (size > max_size)
835 return op->emitOpError()
836 << "failed level check: " << operandOrResult
837 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
838 }
839 return success();
840}
841
842LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
843 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
844 // no need to do level checks
845 return success();
846 }
847
848 // check rank and sizes early so later checks can assume shaped operands
849 if (failed(levelCheckRanksAndSizes(op)))
850 return failure();
851
852 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
853 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
854 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
855 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
856 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
857 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
858 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
859 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
860 failed(levelCheckConv2DBlockScaled(op))) {
861 return failure();
862 }
863
864 // level check MAX_TENSOR_LIST_SIZE
865 if (failed(levelCheckListSize(op))) {
866 return failure();
867 }
868
869 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
870 if (failed(levelCheckMaxNesting(op))) {
871 return failure();
872 }
873 }
874
875 return success();
876}
877
878LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
879 if (failed(attributeCheckRescale(op)))
880 return failure();
881 return success();
882}
883
884inline bool CompatibleTypes(const mlir::Type &type,
885 const mlir::Type &declaredType) {
886 // for now, simply use type equality comparison
887 return type == declaredType;
888}
889
890LogicalResult TosaValidation::CheckVariable(Operation *op) {
891 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
892 mlir::StringAttr nameAttr = variableOp.getNameAttr();
893
894 if (variablesMap.count(nameAttr))
895 return op->emitOpError() << "name has already been declared";
896
897 auto elementType = variableOp.getType();
898 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
899 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
900 RankedTensorType variableType =
901 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
902
903 variablesMap[nameAttr] = variableType;
904 }
905
906 return success();
907}
908
909LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
910 if (isa<mlir::tosa::VariableReadOp>(op) ||
911 isa<mlir::tosa::VariableWriteOp>(op)) {
912 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
913 if (!variablesMap.count(nameAttr))
914 return op->emitOpError() << "name has not been declared";
915
916 auto varType = variablesMap[nameAttr];
917
918 for (auto v : op->getOperands()) {
919 auto type = v.getType();
920 if (!CompatibleTypes(type, varType))
921 return op->emitOpError() << "operand type does not equal variable type";
922 }
923
924 for (auto v : op->getResults()) {
925 auto type = v.getType();
926 if (!CompatibleTypes(type, varType))
927 return op->emitOpError() << "result type does not equal variable type";
928 }
929 }
930
931 return success();
932}
933
934LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
935 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
936 return failure();
937 return success();
938}
939
940LogicalResult checkErrorIfResize(Operation *op) {
941 auto resize = dyn_cast<tosa::ResizeOp>(op);
942 if (!resize)
943 return success();
944
945 const Value input = resize.getInput();
946 const Value output = resize.getOutput();
947 const RankedTensorType inputType =
948 llvm::dyn_cast<RankedTensorType>(input.getType());
949 const RankedTensorType outputType =
950 llvm::dyn_cast<RankedTensorType>(output.getType());
951
952 if (!inputType || !outputType)
953 return op->emitOpError("expect ranked input/output tensor");
954
955 // Ensure the image size is supported by GPU APIs and that for integer
956 // implementations, position * stride does not overflow int32_t.
957 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
958 const SmallVector<int64_t, 4> sizes = {
959 outputType.getDimSize(1), outputType.getDimSize(2),
960 inputType.getDimSize(1), inputType.getDimSize(2)};
961 const int64_t *maxDim = llvm::max_element(sizes);
962 if (maxDim != sizes.end() && *maxDim >= 16384)
963 return op->emitOpError(
964 "expect input/output height/width dims to be < 16384, ")
965 << "got [OH, OW, IH, IW] = " << sizes;
966 }
967
968 SmallVector<int64_t> scale;
969 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
970 return failure();
971
972 const int64_t scaleYN = scale[0];
973 const int64_t scaleYD = scale[1];
974 const int64_t scaleXN = scale[2];
975 const int64_t scaleXD = scale[3];
976
977 // Ensure scale values don't overflow int32 accumulator
978 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
979 return op->emitOpError(
980 "expect all scale numerator values to be <= (1 << 11), "
981 "got scale_y_n=")
982 << scaleYN << ", scale_x_n=" << scaleXN;
983
984 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
985 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
986 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
987
988 SmallVector<int64_t> offset;
989 SmallVector<int64_t> border;
990 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
991 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
992 return failure();
993
994 const int64_t offsetY = offset[0];
995 const int64_t offsetX = offset[1];
996 // Set a consistent lower limit of 1/16 downscale to simplify
997 // implementations
998 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
999 return op->emitOpError(
1000 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1001 << offsetY << "/" << scaleYN;
1002 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1003 return op->emitOpError(
1004 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1005 << offsetX << "/" << scaleXN;
1006
1007 const int64_t borderY = border[0];
1008 const int64_t borderX = border[1];
1009 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1010 return op->emitOpError(
1011 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1012 << borderY << "/" << scaleYN;
1013 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1014 return op->emitOpError(
1015 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1016 << borderX << "/" << scaleXN;
1017
1018 // The following section of code is mostly duplicated with ResizeOp::verify().
1019 //
1020 // In TOSA specification, we do not support broadcast behavior.
1021 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
1022 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
1023 // existing code, we keep the rewrite pattern untouched. So, we need
1024 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
1025 //
1026 // Here is a strict checking to conform TOSA specification.
1027 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
1028 auto idivCheck = [](const int64_t lhs,
1029 const int64_t rhs) -> std::optional<int64_t> {
1030 if (lhs % rhs != 0)
1031 return std::nullopt;
1032 return lhs / rhs;
1033 };
1034
1035 const int64_t oh = outputType.getDimSize(1);
1036 const int64_t ow = outputType.getDimSize(2);
1037 const int64_t ih = inputType.getDimSize(1);
1038 const int64_t iw = inputType.getDimSize(2);
1039
1040 if (ih != ShapedType::kDynamic) {
1041 const std::optional<int64_t> calculatedOutHeightMinusOne =
1042 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1043 if (!calculatedOutHeightMinusOne.has_value())
1044 return op->emitOpError(
1045 "expected (input_height - 1) * scale_y_n - offset_y + "
1046 "border_y ")
1047 << "to be wholly divisible by scale_y_d, got ((" << ih
1048 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1049 << ") / " << scaleYD;
1050 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1051 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1052 return op->emitOpError(
1053 "calculated output height did not match expected: ")
1054 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1055 }
1056
1057 if (iw != ShapedType::kDynamic) {
1058 const std::optional<int64_t> calculatedOutWidthMinusOne =
1059 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1060 if (!calculatedOutWidthMinusOne.has_value())
1061 return op->emitOpError(
1062 "expected (input_width - 1) * scale_x_n - offset_x + "
1063 "border_x ")
1064 << "to be wholly divisible by scale_x_d, got ((" << iw
1065 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1066 << ") / " << scaleXD;
1067 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1068 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1069 return op->emitOpError("calculated output width did not match expected: ")
1070 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1071 }
1072
1073 return success();
1074}
1075
1076LogicalResult checkErrorIfMul(Operation *op) {
1077 auto mul = dyn_cast<tosa::MulOp>(op);
1078 if (!mul)
1079 return success();
1080
1081 // REQUIRE(0 <= shift && shift <= 63);
1082 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1083 ElementsAttr shift_elem;
1084 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
1085 return success();
1086 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1087 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1088 if (inputElemType.isInteger(32)) {
1089 // 0 <= shift <= 63 for int32_t type
1090 if (shift < 0 || shift > 63)
1091 return op->emitOpError()
1092 << "requires 0 <= shift && shift <= 63, but got: " << shift;
1093 } else {
1094 // shift must be 0 for all other types
1095 if (shift != 0)
1096 return op->emitOpError()
1097 << "requires shift = 0 for all input data types that "
1098 "are not int32_t, but got: "
1099 << shift;
1100 }
1101
1102 return success();
1103}
1104
1105LogicalResult checkErrorIfTable(Operation *op) {
1106 auto table = dyn_cast<tosa::TableOp>(op);
1107 if (!table)
1108 return success();
1109
1110 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1111 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1112 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1113
1114 const ShapeAdaptor tableShape(table.getTable().getType());
1115 if (tableShape.hasStaticShape()) {
1116 const auto numElements = tableShape.getNumElements();
1117 if (numElements != tableSize)
1118 return op->emitOpError() << "requires table size of " << tableSize
1119 << ", got " << numElements;
1120 }
1121
1122 return success();
1123}
1124
1125LogicalResult checkErrorIfRescale(Operation *op) {
1126 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1127 if (!rescale)
1128 return success();
1129
1130 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1131 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1132 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1133 !outputType.getElementType().isInteger())
1134 return success();
1135
1136 auto inElemType = inputType.getElementType();
1137 auto outElemType = outputType.getElementType();
1138 auto inWidth = inElemType.getIntOrFloatBitWidth();
1139 auto outWidth = outElemType.getIntOrFloatBitWidth();
1140
1141 bool inputUnsigned = rescale.getInputUnsigned();
1142 bool outputUnsigned = rescale.getOutputUnsigned();
1143
1144 bool scale32 = rescale.getScale32();
1145 auto roundingMode = rescale.getRoundingMode();
1146
1147 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1148 if (scale32 && inWidth == 48)
1149 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1150
1151 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1152 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1153 return op->emitOpError()
1154 << "DOUBLE_ROUND is only allowed with scale32=true.";
1155
1156 // ERROR_IF(input_unsigned && output_unsigned)
1157 if (inputUnsigned && outputUnsigned)
1158 return op->emitOpError() << "input and output cannot be both unsigned.";
1159
1160 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1161 if (outWidth == 32 && inputUnsigned)
1162 return op->emitOpError()
1163 << "i32 output type is not allowed with unsigned input.";
1164
1165 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1166 if (inWidth == 32 && outputUnsigned)
1167 return op->emitOpError()
1168 << "i32 input type is not allowed with unsigned output.";
1169
1170 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1171 if (inWidth == 48 && outputUnsigned)
1172 return op->emitOpError()
1173 << "i48 input type is not allowed with unsigned output.";
1174
1175 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1176 if (inWidth == 48 && inputUnsigned)
1177 return op->emitOpError() << "i48 input type cannot be unsigned.";
1178
1179 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1180 if (inWidth == 32 && inputUnsigned)
1181 return op->emitOpError() << "i32 input type cannot be unsigned.";
1182
1183 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1184 if (outWidth == 32 && outputUnsigned)
1185 return op->emitOpError() << "i32 output type cannot be unsigned.";
1186
1187 return success();
1188}
1189
1190LogicalResult checkErrorIfPad(Operation *op) {
1191 auto pad = dyn_cast<tosa::PadOp>(op);
1192 if (!pad)
1193 return success();
1194
1195 DenseIntElementsAttr paddingAttr;
1196 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1197 // Pad verifier will catch this
1198 return success();
1199
1200 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1201 if (val.getSExtValue() < 0)
1202 return op->emitOpError() << "padding value must all be non-negative, got "
1203 << val.getSExtValue();
1204 }
1205
1206 return success();
1207}
1208
1209static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1210 return llvm::all_of(op->getOperands(), [&](auto operand) {
1211 Region *operandRegion = operand.getParentRegion();
1212 return operandRegion && region->isAncestor(operandRegion);
1213 });
1214}
1215
1216static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1217 bool noLiveInValue = true;
1218 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1219 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1220 noLiveInValue = false;
1221 return WalkResult::interrupt();
1222 }
1223 return WalkResult::advance();
1224 });
1225 return noLiveInValue ? success() : failure();
1226}
1227
1228LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1229 StringRef regionName) {
1230 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1231 return success();
1232 return op->emitOpError()
1233 << "is not conformant to the TOSA specification. It requires the '"
1234 << regionName << "' region is isolated from above.\n";
1235}
1236
1237LogicalResult checkErrorIfCondIf(Operation *op) {
1238 auto ifOp = dyn_cast<tosa::IfOp>(op);
1239 if (!ifOp)
1240 return success();
1241
1242 // Currently the dialect supports declaring cond_if operations that
1243 // have then/else regions that reference values from outside these
1244 // regions. According to the specification, all values used by the
1245 // then/else regions must be explicitly declared within the regions.
1246 // Therefore we must check that the then/else regions are
1247 // "isolated from above", in order to be conformant to the
1248 // specification.
1249 //
1250 // Note: the dialect currently supports two styles of syntax for
1251 // declaring "cond_if" operations. We'll refer to these as follows:
1252 //
1253 // Generic:
1254 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1255 // ^bb0(%arg3, %arg4):
1256 // tosa.yield %arg3
1257 // }, {
1258 // ^bb0(%arg3, %arg4):
1259 // tosa.yield %arg4
1260 // })
1261 //
1262 // Simplified:
1263 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1264 // ^bb0(%arg3, %arg4):
1265 // tosa.yield %arg3
1266 // } else {
1267 // ^bb0(%arg3, %arg4):
1268 // tosa.yield %arg4
1269 // }
1270
1271 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1272 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1273 return failure();
1274 return success();
1275}
1276
1277LogicalResult checkErrorIfWhileLoop(Operation *op) {
1278 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1279 if (!whileOp)
1280 return success();
1281
1282 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1283 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1284 return failure();
1285 return success();
1286}
1287
1288LogicalResult checkErrorIfScatter(Operation *op) {
1289 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1290 if (!scatterOp)
1291 return success();
1292
1293 // for constant indices, check that there are no duplicate values
1294 DenseIntElementsAttr indicesAttr;
1295 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1296 return success();
1297
1298 auto const indicesType =
1299 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1300 if (!indicesType || !indicesType.hasRank()) {
1301 op->emitOpError("expect ranked indices tensor");
1302 return failure();
1303 }
1304
1305 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1306 op->emitOpError("indices values contain duplicates");
1307 return failure();
1308 }
1309
1310 return success();
1311}
1312
1313LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1314 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1315 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1316 failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
1317 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1318 return failure();
1319 return success();
1320}
1321
1322bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1323 if (isa<FloatType>(type)) {
1324 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1325 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1326 Float6E3M2FNType, Float8E8M0FNUType>(type);
1327 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1328 if (intTy.isSignless()) {
1329 switch (intTy.getWidth()) {
1330 case 1:
1331 case 4:
1332 case 8:
1333 case 16:
1334 case 32:
1335 case 48:
1336 case 64:
1337 return true;
1338 }
1339 } else if (allowUnsigned && intTy.isUnsigned()) {
1340 switch (intTy.getWidth()) {
1341 case 8:
1342 case 16:
1343 case 32:
1344 return true;
1345 }
1346 }
1347 } else if (isa<tosa::shapeType>(type))
1348 return true;
1349 else if (isa<tosa::mxint8Type>(type))
1350 return true;
1351 return false;
1352}
1353
1354void TosaValidation::runOnOperation() {
1355 ModuleOp modOp = getOperation();
1356 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1357 if (!tosaDialect)
1358 return;
1359
1360 const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
1361 const auto maybeTargetEnv =
1362 tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
1363 if (failed(maybeTargetEnv))
1364 return signalPassFailure();
1365 targetEnv = *maybeTargetEnv;
1366
1367 modOp.walk([&](Operation *op) {
1368 if (op->getDialect() != tosaDialect)
1369 return;
1370
1371 // validate operator element types:
1372 // - rescale operator is allowed to have ui8/ui16/ui32
1373 // operands/results when strictOpSpecAlignment is false
1374 // - perform valid element type check at the beginning to
1375 // protect rest of code against quantized element types
1376 const bool allowUnsigned =
1377 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1378 for (Value operand : op->getOperands()) {
1379 auto elementTy = getElementTypeOrSelf(operand);
1380 if (!isValidElementType(elementTy, allowUnsigned)) {
1381 op->emitOpError() << "is not profile-aligned: element type "
1382 << elementTy << " is not legal";
1383 return signalPassFailure();
1384 }
1385 }
1386 for (Type resultTy : op->getResultTypes()) {
1387 auto elementTy = getElementTypeOrSelf(resultTy);
1388 if (!isValidElementType(elementTy, allowUnsigned)) {
1389 op->emitOpError() << "is not profile-aligned: element type "
1390 << elementTy << " is not legal";
1391 return signalPassFailure();
1392 }
1393 }
1394
1395 if (strictOpSpecAlignment &&
1396 failed(profileComp.checkProfile(op, targetEnv)))
1397 return signalPassFailure();
1398
1399 if (strictOpSpecAlignment &&
1400 failed(profileComp.checkExtension(op, targetEnv)))
1401 return signalPassFailure();
1402
1403 if (!allowInvalidOpDatatypeCombinations &&
1404 failed(profileComp.checkInvalid(op)))
1405 return signalPassFailure();
1406
1407 // Some uses of TOSA rely on the constant operands of particular
1408 // operations.
1409 if (failed(applyConstantOperandCheck(op)))
1410 signalPassFailure();
1411
1412 // do level checks
1413 if (failed(applyLevelCheck(op)))
1414 signalPassFailure();
1415
1416 // check additional attribute restrictions
1417 if (failed(applyAttributeCheck(op)))
1418 signalPassFailure();
1419
1420 // do variable type checks
1421 if (failed(applyVariableCheck(op)))
1422 signalPassFailure();
1423
1424 // do error if checks
1425 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1426 signalPassFailure();
1427 });
1428}
1429} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:567
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Value getOperand(unsigned idx)
Definition Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:100
TosaLevel getLevel() const
Definition TargetEnv.h:117
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:65
bool allows(Profile prof) const
Definition TargetEnv.h:127
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:113
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:67
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:45
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:619
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369