MLIR 22.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Validate whether TOSA dialect input matches the specification for the given
10// requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Matchers.h"
27#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/StringExtras.h"
30
31namespace mlir {
32namespace tosa {
33#define GEN_PASS_DEF_TOSAVALIDATION
34#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
35} // namespace tosa
36} // namespace mlir
37
38using namespace mlir;
39using namespace mlir::tosa;
40
41namespace {
42
43static LogicalResult
44checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
45 for (const auto index : operandIndices) {
46 Attribute attr;
47 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
48 return op->emitOpError("expected compile time resolvable constant, but "
49 "got variable value for operand #")
50 << index;
51 }
52 }
53 return success();
54}
55
56static LogicalResult checkConstantOperandMul(Operation *op,
57 const TargetEnv &env) {
58 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
59 // Check 'shift'
60 return checkConstantOperands(op, {2});
61 }
62 return success();
63}
64
65static LogicalResult checkConstantOperandTable(Operation *op,
66 const TargetEnv &env) {
67 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
68 // Check 'table'
69 return checkConstantOperands(op, {1});
70 }
71 return success();
72}
73
74static LogicalResult checkConstantOperandPad(Operation *op,
75 const TargetEnv &env) {
76 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
77 // Assume this op is zero-padding if padConst is not presented
78 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
79 // Check 'pad_const'
80 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
81 return checkConstantOperands(op, {2});
82 }
83 return success();
84}
85
86static LogicalResult checkConstantOperandRescale(Operation *op,
87 const TargetEnv &env) {
88 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
89 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
90 return checkConstantOperands(op, {1, 2, 3, 4});
91 }
92 return success();
93}
94
95template <typename T>
96static LogicalResult checkConstantOperandConvOps(Operation *op,
97 const TargetEnv &env) {
98 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
99 // Check 'input_zp' and 'weight_zp'
100 return checkConstantOperands(op, {3, 4});
101 }
102 return success();
103}
104
105static LogicalResult checkConstantOperandMatMul(Operation *op,
106 const TargetEnv &env) {
107 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
108 // Check 'A_zp' and 'B_zp'
109 return checkConstantOperands(op, {2, 3});
110 }
111 return success();
112}
113
114static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
115 const TargetEnv &env) {
116 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
117 // Check 'input_zp' and 'output_zp'
118 return checkConstantOperands(op, {1, 2});
119 }
120 return success();
121}
122
123static LogicalResult checkConstantOperandNegate(Operation *op,
124 const TargetEnv &env) {
125 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
126 // Check 'input1_zp' and 'output_zp'
127 return checkConstantOperands(op, {1, 2});
128 }
129 return success();
130}
131
132//===----------------------------------------------------------------------===//
133// TOSA Validation Pass.
134//===----------------------------------------------------------------------===//
135
136struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
137public:
138 explicit TosaValidation() { populateConstantOperandChecks(); }
139
140 explicit TosaValidation(const TosaValidationOptions &options)
141 : TosaValidation() {
142 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
143 this->allowInvalidOpDatatypeCombinations =
144 options.allowInvalidOpDatatypeCombinations;
145 }
146 void runOnOperation() final;
147
148 LogicalResult applyConstantOperandCheck(Operation *op) {
149 for (auto &checker : constCheckers) {
150 if (failed(checker(op, targetEnv)))
151 return failure();
152 }
153 return success();
154 }
155
156 LogicalResult applyLevelCheck(Operation *op);
157 LogicalResult applyAttributeCheck(Operation *op);
158
159 // check variable read/write data types against variable declarations
160 LogicalResult applyVariableCheck(Operation *op);
161
162 // check error if conditions
163 LogicalResult applyErrorIfCheck(Operation *op);
164
165private:
166 void populateConstantOperandChecks() {
167 constCheckers.emplace_back(checkConstantOperandMul);
168 constCheckers.emplace_back(checkConstantOperandTable);
169 constCheckers.emplace_back(checkConstantOperandPad);
170 constCheckers.emplace_back(checkConstantOperandRescale);
171 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
172 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
173 constCheckers.emplace_back(
174 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
175 constCheckers.emplace_back(
176 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
177 constCheckers.emplace_back(checkConstantOperandMatMul);
178 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
179 constCheckers.emplace_back(checkConstantOperandNegate);
180 }
181
182 LogicalResult levelCheckKernel(Operation *op, int32_t v,
183 const StringRef checkDesc) {
184 if (v > targetEnv.getLevel().MAX_KERNEL)
185 return op->emitOpError() << "failed level check: " << checkDesc;
186 return success();
187 }
188
189 LogicalResult levelCheckStride(Operation *op, int32_t v,
190 const StringRef checkDesc) {
191 if (v > targetEnv.getLevel().MAX_STRIDE)
192 return op->emitOpError() << "failed level check: " << checkDesc;
193 return success();
194 }
195
196 LogicalResult levelCheckScale(Operation *op, int32_t v,
197 const StringRef checkDesc) {
198 if (v > targetEnv.getLevel().MAX_SCALE)
199 return op->emitOpError() << "failed level check: " << checkDesc;
200 return success();
201 }
202
203 LogicalResult levelCheckListSize(Operation *op, int32_t v,
204 const StringRef checkDesc) {
205 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
206 return op->emitOpError()
207 << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
208 return success();
209 }
210
211 // Perform the Level Rank check on the tensor type.
212 LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
213 const StringRef operandOrResult,
214 int32_t highest_rank) {
215 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
216 if (!type.hasRank())
217 return op->emitOpError() << "failed level check: unranked tensor";
218 if (type.getRank() > highest_rank)
219 return op->emitOpError() << "failed level check: " << operandOrResult
220 << " rank(shape) <= MAX_RANK";
221 } else if (tosa::shapeType shapeType =
222 dyn_cast<tosa::shapeType>(typeToCheck)) {
223 if (shapeType.getRank() > highest_rank)
224 return op->emitOpError()
225 << "failed shape type level check: " << typeToCheck
226 << " exceeds MAX_RANK";
227 }
228 return success();
229 }
230
231 // Perform the Level Rank check on the tensor value.
232 LogicalResult levelCheckRank(Operation *op, const Value &v,
233 const StringRef operandOrResult,
234 int32_t highest_rank) {
235 return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
236 }
237
238 // Perform the Level tensor size check on the tensor type.
239 LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
240 const StringRef operandOrResult);
241
242 // Perform the Level tensor size check on the tensor value.
243 LogicalResult levelCheckSize(Operation *op, const Value &v,
244 const StringRef operandOrResult) {
245 return levelCheckSize(op, v.getType(), operandOrResult);
246 }
247
248 // Level check sizes of all operands and results of the operation.
249 template <typename T>
250 LogicalResult levelCheckSizes(T tosaOp) {
251 auto op = tosaOp.getOperation();
252 for (auto v : op->getOperands()) {
253 if (failed(levelCheckSize(op, v, "operand")))
254 return failure();
255 }
256
257 for (auto v : op->getResults()) {
258 if (failed(levelCheckSize(op, v, "result")))
259 return failure();
260 }
261 return success();
262 }
263
264 // Level check ranks of all operands, attribute and results of the operation.
265 template <typename T>
266 LogicalResult levelCheckRanks(T tosaOp) {
267 auto op = tosaOp.getOperation();
268 const TosaLevel tosaLevel = targetEnv.getLevel();
269 for (auto v : op->getOperands()) {
270 if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
271 return failure();
272 }
273
274 for (auto v : op->getResults()) {
275 if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
276 return failure();
277 }
278 return success();
279 }
280
281 // Level check ranks and sizes.
282 LogicalResult levelCheckRanksAndSizes(Operation *op);
283
284 // Pool Op: level check kernel/stride/pad values
285 template <typename T>
286 LogicalResult levelCheckPool(Operation *op) {
287 if (auto poolOp = dyn_cast<T>(op)) {
288 for (auto k : poolOp.getKernel()) {
289 if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
290 return failure();
291 }
292 }
293 for (auto s : poolOp.getStride()) {
294 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
295 return failure();
296 }
297 }
298 for (auto p : poolOp.getPad()) {
299 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
300 return failure();
301 }
302 }
303 }
304 return success();
305 }
306
307 // Conv Op: level check dilation/stride/pad values
308 template <typename T>
309 LogicalResult levelCheckConv(Operation *op) {
310 if (auto convOp = dyn_cast<T>(op)) {
311
312 for (auto k : convOp.getDilation()) {
313 if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
314 return failure();
315 }
316 }
317 for (auto p : convOp.getPad()) {
318 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
319 return failure();
320 }
321 }
322 for (auto s : convOp.getStride()) {
323 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
324 return failure();
325 }
326 }
327 auto dilation = convOp.getDilation();
328 if (ShapedType weightType =
329 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
330 auto shape = weightType.getShape();
331 if (isa<tosa::Conv2DOp>(op)) {
332 assert(shape.size() == 4);
333 assert(dilation.size() == 2);
334 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
335 "dilation_y * KH <= MAX_KERNEL)")) ||
336 failed(levelCheckKernel(op, dilation[1] * shape[2],
337 "dilation_x * KW <= MAX_KERNEL)")))
338 return failure();
339 } else if (isa<tosa::Conv3DOp>(op)) {
340 assert(shape.size() == 5);
341 assert(dilation.size() == 3);
342 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
343 "dilation_d * KD <= MAX_KERNEL)")) ||
344 failed(levelCheckKernel(op, dilation[1] * shape[2],
345 "dilation_y * KH <= MAX_KERNEL)")) ||
346 failed(levelCheckKernel(op, dilation[2] * shape[3],
347 "dilation_x * KW <= MAX_KERNEL)")))
348 return failure();
349 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
350 assert(shape.size() == 4);
351 assert(dilation.size() == 2);
352 if (failed(levelCheckKernel(op, dilation[0] * shape[0],
353 "dilation_y * KH <= MAX_KERNEL)")) ||
354 failed(levelCheckKernel(op, dilation[1] * shape[1],
355 "dilation_x * KW <= MAX_KERNEL)")))
356 return failure();
357 }
358 }
359 }
360 return success();
361 }
362
363 // FFT op: level check H, W in input shape [N,H,W]
364 template <typename T>
365 LogicalResult levelCheckFFT(Operation *op) {
366 if (isa<T>(op)) {
367 for (auto v : op->getOperands()) {
368 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
369 auto shape = type.getShape();
370 assert(shape.size() == 3);
371 if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
372 failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
373 return failure();
374 }
375 }
376 }
377 }
378 return success();
379 }
380
381 // TransposeConv2d op: level check kH/kW, outpad, and stride
382 LogicalResult levelCheckTransposeConv2d(Operation *op) {
383 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
384 if (ShapedType filterType =
385 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
386 auto shape = filterType.getShape();
387 assert(shape.size() == 4);
388 // level check kernel sizes for kH and KW
389 if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
390 failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
391 return failure();
392 }
393 }
394 for (auto p : transpose.getOutPad()) {
395 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
396 return failure();
397 }
398 }
399 for (auto s : transpose.getStride()) {
400 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
401 return failure();
402 }
403 }
404 }
405 return success();
406 }
407
408 // Resize op: level check max scales
409 LogicalResult levelCheckResize(Operation *op) {
410 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
411 SmallVector<int64_t> scale;
412 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
413 scale)) {
414 return failure();
415 }
416 const int64_t scaleYN = scale[0];
417 const int64_t scaleYD = scale[1];
418 const int64_t scaleXN = scale[2];
419 const int64_t scaleXD = scale[3];
420 if (failed(levelCheckScale(op, scaleYN / scaleYD,
421 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
422 failed(levelCheckScale(op, scaleXN / scaleXD,
423 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
424 return failure();
425 }
426 }
427 return success();
428 }
429
430 // Recursively perform a bottom-up search to determine the maximum nesting
431 // depth, starting from a specific operation and continuing up to the function
432 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
433 // time a new nested `region` is encountered.
434 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
435 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
436 return;
437
438 op = op->getParentOp();
439 if (!op)
440 return;
441
442 depth++;
443 getMaxNestedDepth(op, depth);
444 }
445
446 LogicalResult levelCheckMaxNesting(Operation *op) {
447 int32_t maxNestedDepth = 0;
448 getMaxNestedDepth(op, maxNestedDepth);
449
450 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
451 op->emitOpError() << "failed level check: " << maxNestedDepth
452 << " >= MAX_NESTING";
453 return failure();
454 }
455 return success();
456 }
457
458 LogicalResult levelCheckListSize(Operation *op) {
459 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
460 return levelCheckListSize(op, concat.getInput1().size(), "input1");
461 }
462 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
463 if (failed(levelCheckListSize(op, custom.getInputList().size(),
464 "input_list")) ||
465 failed(levelCheckListSize(op, custom.getOutputList().size(),
466 "output_list"))) {
467 return failure();
468 }
469 }
470 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
471 if (failed(
472 levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
473 failed(levelCheckListSize(op, condIf.getOutputList().size(),
474 "outputs"))) {
475 return failure();
476 }
477 }
478 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
479 if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
480 failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
481 return failure();
482 }
483 }
484 return success();
485 }
486
487 LogicalResult attributeCheckRescale(Operation *op) {
488 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
489 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
490 !targetEnv.allows(Extension::doubleround)) {
491 op->emitOpError()
492 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
493 << "requires extension [doubleround]";
494 return failure();
495 }
496 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
497 !targetEnv.allows(Extension::inexactround)) {
498 op->emitOpError()
499 << "failed attribute check: rounding_mode = INEXACT_ROUND "
500 << "requires extension [inexactround]";
501 return failure();
502 }
503 }
504 return success();
505 }
506
507 LogicalResult CheckVariable(Operation *op);
508 LogicalResult CheckVariableReadOrWrite(Operation *op);
509 bool isValidElementType(Type type, const bool allowUnsigned = false);
510
511 SmallVector<
512 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
513 constCheckers;
515 TosaProfileCompliance profileComp;
516 tosa::TargetEnv targetEnv;
517};
518
519template <>
520LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
521 auto *op = tosaOp.getOperation();
522 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
523 targetEnv.getLevel().MAX_RANK)))
524 return failure();
525
526 // rank(output) = rank(input) - 1
527 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
528 targetEnv.getLevel().MAX_RANK - 1)))
529 return failure();
530
531 return success();
532}
533
534template <>
535LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
536 auto *op = tosaOp.getOperation();
537
538 // Only the condition input has rank limitation.
539 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
540 targetEnv.getLevel().MAX_RANK)))
541 return failure();
542
543 return success();
544}
545
546template <>
547LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
548 auto *op = tosaOp.getOperation();
549 auto variableType = getVariableType(tosaOp);
550 if (failed(levelCheckRank(op, variableType, "variable type",
551 targetEnv.getLevel().MAX_RANK)))
552 return failure();
553
554 return success();
555}
556
557template <>
558LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
559 auto *op = tosaOp.getOperation();
560 auto variableType = getVariableType(tosaOp);
561 if (failed(levelCheckSize(op, variableType, "variable type")))
562 return failure();
563
564 return success();
565}
566
567LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
568#define CHECK_RANKS_AND_SIZES(tosaOp) \
569 if (isa<tosa::tosaOp##Op>(op)) { \
570 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
571 return failure(); \
572 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
573 return failure(); \
574 }
575
576#define CHECK_SIZES(tosaOp) \
577 if (isa<tosa::tosaOp##Op>(op)) { \
578 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
579 return failure(); \
580 }
581
582#define CHECK_RANKS(tosaOp) \
583 if (isa<tosa::tosaOp##Op>(op)) { \
584 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
585 return failure(); \
586 }
587
588 // Tensor Operators
589 CHECK_RANKS_AND_SIZES(ArgMax);
590 // Activation Functions
593 CHECK_RANKS_AND_SIZES(Sigmoid);
595 // Elementwise Binary Operators
597 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
598 CHECK_RANKS_AND_SIZES(BitwiseAnd);
599 CHECK_RANKS_AND_SIZES(BitwiseOr);
600 CHECK_RANKS_AND_SIZES(BitwiseXor);
601 CHECK_RANKS_AND_SIZES(IntDiv);
602 CHECK_RANKS_AND_SIZES(LogicalAnd);
603 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
604 CHECK_RANKS_AND_SIZES(LogicalRightShift);
605 CHECK_RANKS_AND_SIZES(LogicalOr);
606 CHECK_RANKS_AND_SIZES(LogicalXor);
607 CHECK_RANKS_AND_SIZES(Maximum);
608 CHECK_RANKS_AND_SIZES(Minimum);
613 // Elementwise Unary Operators
615 CHECK_RANKS_AND_SIZES(BitwiseNot);
622 CHECK_RANKS_AND_SIZES(LogicalNot);
623 CHECK_RANKS_AND_SIZES(Negate);
624 CHECK_RANKS_AND_SIZES(Reciprocal);
627 // Elementwise Ternary Operators
628 CHECK_RANKS_AND_SIZES(Select);
629 // Comparison Operators
631 CHECK_RANKS_AND_SIZES(Greater);
632 CHECK_RANKS_AND_SIZES(GreaterEqual);
633 // Reduction Operators
634 CHECK_RANKS_AND_SIZES(ReduceAll);
635 CHECK_RANKS_AND_SIZES(ReduceAny);
636 CHECK_RANKS_AND_SIZES(ReduceMax);
637 CHECK_RANKS_AND_SIZES(ReduceMin);
638 CHECK_RANKS_AND_SIZES(ReduceProduct);
639 CHECK_RANKS_AND_SIZES(ReduceSum);
640 // Data Layout Operators
641 CHECK_RANKS_AND_SIZES(Concat);
643 CHECK_RANKS_AND_SIZES(Reshape);
644 CHECK_RANKS_AND_SIZES(Reverse);
647 CHECK_RANKS_AND_SIZES(Transpose);
648 // Type Conversion
650 CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
651 CHECK_RANKS_AND_SIZES(CastToBlockScaled);
652 CHECK_RANKS_AND_SIZES(Rescale);
653 // Data Nodes
655 CHECK_RANKS_AND_SIZES(Identity);
656 // Control Flow Operators
658 // Variable Operators
659 CHECK_RANKS_AND_SIZES(Variable);
660 CHECK_RANKS_AND_SIZES(VariableWrite);
661 CHECK_RANKS_AND_SIZES(VariableRead);
662
663 // For the following operators, check whether the size of each tensor
664 // operand is valid in a given Level.
665
666 // Tensor Operators
667 CHECK_SIZES(AvgPool2d);
668 CHECK_SIZES(Conv2D);
669 CHECK_SIZES(Conv3D);
670 CHECK_SIZES(DepthwiseConv2D);
671 CHECK_SIZES(TransposeConv2D);
672 CHECK_SIZES(FFT2d);
673 CHECK_SIZES(MatMul);
674 CHECK_SIZES(MatmulTBlockScaled);
675 CHECK_SIZES(MaxPool2d);
676 CHECK_SIZES(RFFT2d);
677 // Scatter/Gather Operators
679 CHECK_SIZES(Scatter);
680 // Image Operators
681 CHECK_SIZES(Resize);
682 // Custom Operators
683 CHECK_SIZES(Custom);
684 // Control Flow Operators
685 CHECK_SIZES(While);
686 // Shape Operators
687 CHECK_SIZES(ConstShape);
688
689 // For the following operations, check whether the rank of each operand
690 // is valid given a level.
691
692 // Shape Operators
693 CHECK_RANKS(AddShape);
694 CHECK_RANKS(DivCeilShape);
695 CHECK_RANKS(DivFloorShape);
696 CHECK_RANKS(MulShape);
697 CHECK_RANKS(SubShape);
698
699#undef CHECK_RANKS_AND_SIZES
700#undef CHECK_SIZES
701#undef CHECK_RANKS
702 return success();
703}
704
705// Perform the Level tensor size check on the tensor type.
706LogicalResult TosaValidation::levelCheckSize(Operation *op,
707 const Type &typeToCheck,
708 const StringRef operandOrResult) {
709 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
710 if (!type.hasRank())
711 return op->emitOpError() << "failed level check: unranked tensor";
712 auto shape = type.getShape();
713 for (auto dim : shape) {
714 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
715 const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
716 const TosaSpecificationVersion minRequiredVersion(1, 1);
717 if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
718 dimIsDynamic)
719 // TOSA 1.1 and above supports dynamic dimensions, however, they must be
720 // resolved at backend compile time. Runtime dynamism is not currently
721 // supported. Checking this requirement is met is delegated to backends.
722 return success();
723
724 // When targeting TOSA 1.0 or below, dynamic dims are not supported
725 if (dimIsDynamic)
726 return op->emitOpError() << "failed level check: " << operandOrResult
727 << " shape dimension cannot be dynamic when"
728 << " targeting TOSA specification version 1.0"
729 << " or below";
730 }
731
732 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
733 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
734 int64_t size = element_bytes * type.getNumElements();
735
736 // According to 1.11. Tensor Definitions of Tosa spec, the value of
737 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
738 // defined in 1.7. Levels.
739 // For each tensor, the number of tensor elements multiplied by the
740 // element size in bytes must be representable as a tensor_size_t.
741 const int64_t max_size =
742 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
743 if (size > max_size)
744 return op->emitOpError()
745 << "failed level check: " << operandOrResult
746 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
747 }
748 return success();
749}
750
751LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
752 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
753 // no need to do level checks
754 return success();
755 }
756
757 // check rank and sizes early so later checks can assume shaped operands
758 if (failed(levelCheckRanksAndSizes(op)))
759 return failure();
760
761 // additional level checks from spec 0.70
762 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
763 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
764 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
765 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
766 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
767 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
768 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
769 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
770 return failure();
771 }
772
773 // level check MAX_TENSOR_LIST_SIZE
774 if (failed(levelCheckListSize(op))) {
775 return failure();
776 }
777
778 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
779 if (failed(levelCheckMaxNesting(op))) {
780 return failure();
781 }
782 }
783
784 return success();
785}
786
787LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
788 if (failed(attributeCheckRescale(op)))
789 return failure();
790 return success();
791}
792
793inline bool CompatibleTypes(const mlir::Type &type,
794 const mlir::Type &declaredType) {
795 // for now, simply use type equality comparison
796 return type == declaredType;
797}
798
799LogicalResult TosaValidation::CheckVariable(Operation *op) {
800 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
801 mlir::StringAttr nameAttr = variableOp.getNameAttr();
802
803 if (variablesMap.count(nameAttr))
804 return op->emitOpError() << "name has already been declared";
805
806 auto elementType = variableOp.getType();
807 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
808 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
809 RankedTensorType variableType =
810 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
811
812 variablesMap[nameAttr] = variableType;
813 }
814
815 return success();
816}
817
818LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
819 if (isa<mlir::tosa::VariableReadOp>(op) ||
820 isa<mlir::tosa::VariableWriteOp>(op)) {
821 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
822 if (!variablesMap.count(nameAttr))
823 return op->emitOpError() << "name has not been declared";
824
825 auto varType = variablesMap[nameAttr];
826
827 for (auto v : op->getOperands()) {
828 auto type = v.getType();
829 if (!CompatibleTypes(type, varType))
830 return op->emitOpError() << "operand type does not equal variable type";
831 }
832
833 for (auto v : op->getResults()) {
834 auto type = v.getType();
835 if (!CompatibleTypes(type, varType))
836 return op->emitOpError() << "result type does not equal variable type";
837 }
838 }
839
840 return success();
841}
842
843LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
844 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
845 return failure();
846 return success();
847}
848
849LogicalResult checkErrorIfResize(Operation *op) {
850 auto resize = dyn_cast<tosa::ResizeOp>(op);
851 if (!resize)
852 return success();
853
854 const Value input = resize.getInput();
855 const Value output = resize.getOutput();
856 const RankedTensorType inputType =
857 llvm::dyn_cast<RankedTensorType>(input.getType());
858 const RankedTensorType outputType =
859 llvm::dyn_cast<RankedTensorType>(output.getType());
860
861 if (!inputType || !outputType)
862 return op->emitOpError("expect ranked input/output tensor");
863
864 // Ensure the image size is supported by GPU APIs and that for integer
865 // implementations, position * stride does not overflow int32_t.
866 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
867 const SmallVector<int64_t, 4> sizes = {
868 outputType.getDimSize(1), outputType.getDimSize(2),
869 inputType.getDimSize(1), inputType.getDimSize(2)};
870 const int64_t *maxDim = llvm::max_element(sizes);
871 if (maxDim != sizes.end() && *maxDim >= 16384)
872 return op->emitOpError(
873 "expect input/output height/width dims to be < 16384, ")
874 << "got [OH, OW, IH, IW] = " << sizes;
875 }
876
877 SmallVector<int64_t> scale;
878 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
879 return failure();
880
881 const int64_t scaleYN = scale[0];
882 const int64_t scaleYD = scale[1];
883 const int64_t scaleXN = scale[2];
884 const int64_t scaleXD = scale[3];
885
886 // Ensure scale values don't overflow int32 accumulator
887 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
888 return op->emitOpError(
889 "expect all scale numerator values to be <= (1 << 11), "
890 "got scale_y_n=")
891 << scaleYN << ", scale_x_n=" << scaleXN;
892
893 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
894 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
895 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
896
897 SmallVector<int64_t> offset;
898 SmallVector<int64_t> border;
899 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
900 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
901 return failure();
902
903 const int64_t offsetY = offset[0];
904 const int64_t offsetX = offset[1];
905 // Set a consistent lower limit of 1/16 downscale to simplify
906 // implementations
907 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
908 return op->emitOpError(
909 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
910 << offsetY << "/" << scaleYN;
911 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
912 return op->emitOpError(
913 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
914 << offsetX << "/" << scaleXN;
915
916 const int64_t borderY = border[0];
917 const int64_t borderX = border[1];
918 if (borderY < -16 * scaleYN || borderY >= scaleYN)
919 return op->emitOpError(
920 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
921 << borderY << "/" << scaleYN;
922 if (borderX < -16 * scaleXN || borderX >= scaleXN)
923 return op->emitOpError(
924 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
925 << borderX << "/" << scaleXN;
926
927 // The following section of code is mostly duplicated with ResizeOp::verify().
928 //
929 // In TOSA specification, we do not support broadcast behavior.
930 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
931 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
932 // existing code, we keep the rewrite pattern untouched. So, we need
933 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
934 //
935 // Here is a strict checking to conform TOSA specification.
936 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
937 auto idivCheck = [](const int64_t lhs,
938 const int64_t rhs) -> std::optional<int64_t> {
939 if (lhs % rhs != 0)
940 return std::nullopt;
941 return lhs / rhs;
942 };
943
944 const int64_t oh = outputType.getDimSize(1);
945 const int64_t ow = outputType.getDimSize(2);
946 const int64_t ih = inputType.getDimSize(1);
947 const int64_t iw = inputType.getDimSize(2);
948
949 if (ih != ShapedType::kDynamic) {
950 const std::optional<int64_t> calculatedOutHeightMinusOne =
951 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
952 if (!calculatedOutHeightMinusOne.has_value())
953 return op->emitOpError(
954 "expected (input_height - 1) * scale_y_n - offset_y + "
955 "border_y ")
956 << "to be wholly divisible by scale_y_d, got ((" << ih
957 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
958 << ") / " << scaleYD;
959 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
960 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
961 return op->emitOpError(
962 "calculated output height did not match expected: ")
963 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
964 }
965
966 if (iw != ShapedType::kDynamic) {
967 const std::optional<int64_t> calculatedOutWidthMinusOne =
968 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
969 if (!calculatedOutWidthMinusOne.has_value())
970 return op->emitOpError(
971 "expected (input_width - 1) * scale_x_n - offset_x + "
972 "border_x ")
973 << "to be wholly divisible by scale_x_d, got ((" << iw
974 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
975 << ") / " << scaleXD;
976 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
977 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
978 return op->emitOpError("calculated output width did not match expected: ")
979 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
980 }
981
982 return success();
983}
984
985LogicalResult checkErrorIfMul(Operation *op) {
986 auto mul = dyn_cast<tosa::MulOp>(op);
987 if (!mul)
988 return success();
989
990 // REQUIRE(0 <= shift && shift <= 63);
991 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
992 ElementsAttr shift_elem;
993 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
994 return success();
995 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
996 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
997 if (inputElemType.isInteger(32)) {
998 // 0 <= shift <= 63 for int32_t type
999 if (shift < 0 || shift > 63)
1000 return op->emitOpError()
1001 << "requires 0 <= shift && shift <= 63, but got: " << shift;
1002 } else {
1003 // shift must be 0 for all other types
1004 if (shift != 0)
1005 return op->emitOpError()
1006 << "requires shift = 0 for all input data types that "
1007 "are not int32_t, but got: "
1008 << shift;
1009 }
1010
1011 return success();
1012}
1013
1014LogicalResult checkErrorIfTable(Operation *op) {
1015 auto table = dyn_cast<tosa::TableOp>(op);
1016 if (!table)
1017 return success();
1018
1019 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1020 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1021 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1022
1023 const ShapeAdaptor tableShape(table.getTable().getType());
1024 if (tableShape.hasStaticShape()) {
1025 const auto numElements = tableShape.getNumElements();
1026 if (numElements != tableSize)
1027 return op->emitOpError() << "requires table size of " << tableSize
1028 << ", got " << numElements;
1029 }
1030
1031 return success();
1032}
1033
1034LogicalResult checkErrorIfRescale(Operation *op) {
1035 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1036 if (!rescale)
1037 return success();
1038
1039 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1040 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1041 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1042 !outputType.getElementType().isInteger())
1043 return success();
1044
1045 auto inElemType = inputType.getElementType();
1046 auto outElemType = outputType.getElementType();
1047 auto inWidth = inElemType.getIntOrFloatBitWidth();
1048 auto outWidth = outElemType.getIntOrFloatBitWidth();
1049
1050 bool inputUnsigned = rescale.getInputUnsigned();
1051 bool outputUnsigned = rescale.getOutputUnsigned();
1052
1053 bool scale32 = rescale.getScale32();
1054 auto roundingMode = rescale.getRoundingMode();
1055
1056 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1057 if (scale32 && inWidth == 48)
1058 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1059
1060 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1061 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1062 return op->emitOpError()
1063 << "DOUBLE_ROUND is only allowed with scale32=true.";
1064
1065 // ERROR_IF(input_unsigned && output_unsigned)
1066 if (inputUnsigned && outputUnsigned)
1067 return op->emitOpError() << "input and output cannot be both unsigned.";
1068
1069 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1070 if (outWidth == 32 && inputUnsigned)
1071 return op->emitOpError()
1072 << "i32 output type is not allowed with unsigned input.";
1073
1074 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1075 if (inWidth == 32 && outputUnsigned)
1076 return op->emitOpError()
1077 << "i32 input type is not allowed with unsigned output.";
1078
1079 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1080 if (inWidth == 48 && outputUnsigned)
1081 return op->emitOpError()
1082 << "i48 input type is not allowed with unsigned output.";
1083
1084 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1085 if (inWidth == 48 && inputUnsigned)
1086 return op->emitOpError() << "i48 input type cannot be unsigned.";
1087
1088 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1089 if (inWidth == 32 && inputUnsigned)
1090 return op->emitOpError() << "i32 input type cannot be unsigned.";
1091
1092 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1093 if (outWidth == 32 && outputUnsigned)
1094 return op->emitOpError() << "i32 output type cannot be unsigned.";
1095
1096 return success();
1097}
1098
1099LogicalResult checkErrorIfPad(Operation *op) {
1100 auto pad = dyn_cast<tosa::PadOp>(op);
1101 if (!pad)
1102 return success();
1103
1104 DenseIntElementsAttr paddingAttr;
1105 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1106 // Pad verifier will catch this
1107 return success();
1108
1109 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1110 if (val.getSExtValue() < 0)
1111 return op->emitOpError() << "padding value must all be non-negative, got "
1112 << val.getSExtValue();
1113 }
1114
1115 return success();
1116}
1117
1118static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1119 return llvm::all_of(op->getOperands(), [&](auto operand) {
1120 Region *operandRegion = operand.getParentRegion();
1121 return operandRegion && region->isAncestor(operandRegion);
1122 });
1123}
1124
1125static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1126 bool noLiveInValue = true;
1127 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1128 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1129 noLiveInValue = false;
1130 return WalkResult::interrupt();
1131 }
1132 return WalkResult::advance();
1133 });
1134 return noLiveInValue ? success() : failure();
1135}
1136
1137LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1138 StringRef regionName) {
1139 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1140 return success();
1141 return op->emitOpError()
1142 << "is not conformant to the TOSA specification. It requires the '"
1143 << regionName << "' region is isolated from above.\n";
1144}
1145
1146LogicalResult checkErrorIfCondIf(Operation *op) {
1147 auto ifOp = dyn_cast<tosa::IfOp>(op);
1148 if (!ifOp)
1149 return success();
1150
1151 // Currently the dialect supports declaring cond_if operations that
1152 // have then/else regions that reference values from outside these
1153 // regions. According to the specification, all values used by the
1154 // then/else regions must be explicitly declared within the regions.
1155 // Therefore we must check that the then/else regions are
1156 // "isolated from above", in order to be conformant to the
1157 // specification.
1158 //
1159 // Note: the dialect currently supports two styles of syntax for
1160 // declaring "cond_if" operations. We'll refer to these as follows:
1161 //
1162 // Generic:
1163 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1164 // ^bb0(%arg3, %arg4):
1165 // tosa.yield %arg3
1166 // }, {
1167 // ^bb0(%arg3, %arg4):
1168 // tosa.yield %arg4
1169 // })
1170 //
1171 // Simplified:
1172 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1173 // ^bb0(%arg3, %arg4):
1174 // tosa.yield %arg3
1175 // } else {
1176 // ^bb0(%arg3, %arg4):
1177 // tosa.yield %arg4
1178 // }
1179
1180 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1181 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1182 return failure();
1183 return success();
1184}
1185
1186LogicalResult checkErrorIfWhileLoop(Operation *op) {
1187 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1188 if (!whileOp)
1189 return success();
1190
1191 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1192 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1193 return failure();
1194 return success();
1195}
1196
1197LogicalResult checkErrorIfScatter(Operation *op) {
1198 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1199 if (!scatterOp)
1200 return success();
1201
1202 // for constant indices, check that there are no duplicate values
1203 DenseIntElementsAttr indicesAttr;
1204 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1205 return success();
1206
1207 auto const indicesType =
1208 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1209 if (!indicesType || !indicesType.hasRank()) {
1210 op->emitOpError("expect ranked indices tensor");
1211 return failure();
1212 }
1213
1214 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1215 op->emitOpError("indices values contain duplicates");
1216 return failure();
1217 }
1218
1219 return success();
1220}
1221
1222LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1223 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1224 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1225 failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
1226 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1227 return failure();
1228 return success();
1229}
1230
1231bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1232 if (isa<FloatType>(type)) {
1233 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1234 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1235 Float6E3M2FNType, Float8E8M0FNUType>(type);
1236 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1237 if (intTy.isSignless()) {
1238 switch (intTy.getWidth()) {
1239 case 1:
1240 case 4:
1241 case 8:
1242 case 16:
1243 case 32:
1244 case 48:
1245 case 64:
1246 return true;
1247 }
1248 } else if (allowUnsigned && intTy.isUnsigned()) {
1249 switch (intTy.getWidth()) {
1250 case 8:
1251 case 16:
1252 case 32:
1253 return true;
1254 }
1255 }
1256 } else if (isa<tosa::shapeType>(type))
1257 return true;
1258 else if (isa<tosa::mxint8Type>(type))
1259 return true;
1260 return false;
1261}
1262
1263void TosaValidation::runOnOperation() {
1264 ModuleOp modOp = getOperation();
1265 const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
1266 const auto maybeTargetEnv =
1267 tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
1268 if (failed(maybeTargetEnv))
1269 return signalPassFailure();
1270 targetEnv = *maybeTargetEnv;
1271
1272 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1273 if (!tosaDialect)
1274 return;
1275
1276 modOp.walk([&](Operation *op) {
1277 if (op->getDialect() != tosaDialect)
1278 return;
1279
1280 // validate operator element types:
1281 // - rescale operator is allowed to have ui8/ui16/ui32
1282 // operands/results when strictOpSpecAlignment is false
1283 // - perform valid element type check at the beginning to
1284 // protect rest of code against quantized element types
1285 const bool allowUnsigned =
1286 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1287 for (Value operand : op->getOperands()) {
1288 auto elementTy = getElementTypeOrSelf(operand);
1289 if (!isValidElementType(elementTy, allowUnsigned)) {
1290 op->emitOpError() << "is not profile-aligned: element type "
1291 << elementTy << " is not legal";
1292 return signalPassFailure();
1293 }
1294 }
1295 for (Type resultTy : op->getResultTypes()) {
1296 auto elementTy = getElementTypeOrSelf(resultTy);
1297 if (!isValidElementType(elementTy, allowUnsigned)) {
1298 op->emitOpError() << "is not profile-aligned: element type "
1299 << elementTy << " is not legal";
1300 return signalPassFailure();
1301 }
1302 }
1303
1304 if (strictOpSpecAlignment &&
1305 failed(profileComp.checkProfile(op, targetEnv)))
1306 return signalPassFailure();
1307
1308 if (strictOpSpecAlignment &&
1309 failed(profileComp.checkExtension(op, targetEnv)))
1310 return signalPassFailure();
1311
1312 if (!allowInvalidOpDatatypeCombinations &&
1313 failed(profileComp.checkInvalid(op)))
1314 return signalPassFailure();
1315
1316 // Some uses of TOSA rely on the constant operands of particular
1317 // operations.
1318 if (failed(applyConstantOperandCheck(op)))
1319 signalPassFailure();
1320
1321 // do level checks
1322 if (failed(applyLevelCheck(op)))
1323 signalPassFailure();
1324
1325 // check additional attribute restrictions
1326 if (failed(applyAttributeCheck(op)))
1327 signalPassFailure();
1328
1329 // do variable type checks
1330 if (failed(applyVariableCheck(op)))
1331 signalPassFailure();
1332
1333 // do error if checks
1334 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1335 signalPassFailure();
1336 });
1337}
1338} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:557
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_RANKS(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Value getOperand(unsigned idx)
Definition Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:97
TosaLevel getLevel() const
Definition TargetEnv.h:114
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:64
bool allows(Profile prof) const
Definition TargetEnv.h:124
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:110
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:64
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:42
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:609
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369