MLIR 22.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Validate that TOSA dialect input matches the specification for the given
// requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Matchers.h"
27#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/StringExtras.h"
30
31namespace mlir {
32namespace tosa {
33#define GEN_PASS_DEF_TOSAVALIDATION
34#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
35} // namespace tosa
36} // namespace mlir
37
38using namespace mlir;
39using namespace mlir::tosa;
40
41namespace {
42
43static LogicalResult
44checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
45 for (const auto index : operandIndices) {
46 Attribute attr;
47 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
48 return op->emitOpError("expected compile time resolvable constant, but "
49 "got variable value for operand #")
50 << index;
51 }
52 }
53 return success();
54}
55
56static LogicalResult checkConstantOperandMul(Operation *op,
57 const TargetEnv &env) {
58 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
59 // Check 'shift'
60 return checkConstantOperands(op, {2});
61 }
62 return success();
63}
64
65static LogicalResult checkConstantOperandTable(Operation *op,
66 const TargetEnv &env) {
67 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
68 // Check 'table'
69 return checkConstantOperands(op, {1});
70 }
71 return success();
72}
73
74static LogicalResult checkConstantOperandPad(Operation *op,
75 const TargetEnv &env) {
76 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
77 // Assume this op is zero-padding if padConst is not presented
78 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
79 // Check 'pad_const'
80 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
81 return checkConstantOperands(op, {2});
82 }
83 return success();
84}
85
86static LogicalResult checkConstantOperandRescale(Operation *op,
87 const TargetEnv &env) {
88 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
89 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
90 return checkConstantOperands(op, {1, 2, 3, 4});
91 }
92 return success();
93}
94
95template <typename T>
96static LogicalResult checkConstantOperandConvOps(Operation *op,
97 const TargetEnv &env) {
98 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
99 // Check 'input_zp' and 'weight_zp'
100 return checkConstantOperands(op, {3, 4});
101 }
102 return success();
103}
104
105static LogicalResult checkConstantOperandMatMul(Operation *op,
106 const TargetEnv &env) {
107 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
108 // Check 'A_zp' and 'B_zp'
109 return checkConstantOperands(op, {2, 3});
110 }
111 return success();
112}
113
114static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
115 const TargetEnv &env) {
116 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
117 // Check 'input_zp' and 'output_zp'
118 return checkConstantOperands(op, {1, 2});
119 }
120 return success();
121}
122
123static LogicalResult checkConstantOperandNegate(Operation *op,
124 const TargetEnv &env) {
125 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
126 // Check 'input1_zp' and 'output_zp'
127 return checkConstantOperands(op, {1, 2});
128 }
129 return success();
130}
131
132//===----------------------------------------------------------------------===//
133// TOSA Validation Pass.
134//===----------------------------------------------------------------------===//
135
136struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
137public:
138 explicit TosaValidation() { populateConstantOperandChecks(); }
139
140 explicit TosaValidation(const TosaValidationOptions &options)
141 : TosaValidation() {
142 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
143 this->allowInvalidOpDatatypeCombinations =
144 options.allowInvalidOpDatatypeCombinations;
145 }
146 void runOnOperation() final;
147
148 LogicalResult applyConstantOperandCheck(Operation *op) {
149 for (auto &checker : constCheckers) {
150 if (failed(checker(op, targetEnv)))
151 return failure();
152 }
153 return success();
154 }
155
156 LogicalResult applyLevelCheck(Operation *op);
157 LogicalResult applyAttributeCheck(Operation *op);
158
159 // check variable read/write data types against variable declarations
160 LogicalResult applyVariableCheck(Operation *op);
161
162 // check error if conditions
163 LogicalResult applyErrorIfCheck(Operation *op);
164
165private:
166 void populateConstantOperandChecks() {
167 constCheckers.emplace_back(checkConstantOperandMul);
168 constCheckers.emplace_back(checkConstantOperandTable);
169 constCheckers.emplace_back(checkConstantOperandPad);
170 constCheckers.emplace_back(checkConstantOperandRescale);
171 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
172 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
173 constCheckers.emplace_back(
174 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
175 constCheckers.emplace_back(
176 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
177 constCheckers.emplace_back(checkConstantOperandMatMul);
178 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
179 constCheckers.emplace_back(checkConstantOperandNegate);
180 }
181
182 LogicalResult levelCheckKernel(Operation *op, int32_t v,
183 const StringRef checkDesc) {
184 if (v > targetEnv.getLevel().MAX_KERNEL)
185 return op->emitOpError() << "failed level check: " << checkDesc;
186 return success();
187 }
188
189 LogicalResult levelCheckStride(Operation *op, int32_t v,
190 const StringRef checkDesc) {
191 if (v > targetEnv.getLevel().MAX_STRIDE)
192 return op->emitOpError() << "failed level check: " << checkDesc;
193 return success();
194 }
195
196 LogicalResult levelCheckScale(Operation *op, int32_t v,
197 const StringRef checkDesc) {
198 if (v > targetEnv.getLevel().MAX_SCALE)
199 return op->emitOpError() << "failed level check: " << checkDesc;
200 return success();
201 }
202
203 LogicalResult levelCheckListSize(Operation *op, int32_t v,
204 const StringRef checkDesc) {
205 if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
206 return op->emitOpError()
207 << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
208 return success();
209 }
210
211 // Perform the Level Rank check on the tensor type.
212 LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
213 const StringRef operandOrResult,
214 int32_t highest_rank) {
215 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
216 if (!type.hasRank())
217 return op->emitOpError() << "failed level check: unranked tensor";
218 if (type.getRank() > highest_rank)
219 return op->emitOpError() << "failed level check: " << operandOrResult
220 << " rank(shape) <= MAX_RANK";
221 }
222 return success();
223 }
224
225 // Perform the Level Rank check on the tensor value.
226 LogicalResult levelCheckRank(Operation *op, const Value &v,
227 const StringRef operandOrResult,
228 int32_t highest_rank) {
229 return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
230 }
231
232 // Perform the Level tensor size check on the tensor type.
233 LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
234 const StringRef operandOrResult);
235
236 // Perform the Level tensor size check on the tensor value.
237 LogicalResult levelCheckSize(Operation *op, const Value &v,
238 const StringRef operandOrResult) {
239 return levelCheckSize(op, v.getType(), operandOrResult);
240 }
241
242 // Level check sizes of all operands and results of the operation.
243 template <typename T>
244 LogicalResult levelCheckSizes(T tosaOp) {
245 auto op = tosaOp.getOperation();
246 for (auto v : op->getOperands()) {
247 if (failed(levelCheckSize(op, v, "operand")))
248 return failure();
249 }
250
251 for (auto v : op->getResults()) {
252 if (failed(levelCheckSize(op, v, "result")))
253 return failure();
254 }
255 return success();
256 }
257
258 // Level check ranks of all operands, attribute and results of the operation.
259 template <typename T>
260 LogicalResult levelCheckRanks(T tosaOp) {
261 auto op = tosaOp.getOperation();
262 const TosaLevel tosaLevel = targetEnv.getLevel();
263 for (auto v : op->getOperands()) {
264 if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
265 return failure();
266 }
267
268 for (auto v : op->getResults()) {
269 if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
270 return failure();
271 }
272 return success();
273 }
274
275 // Level check ranks and sizes.
276 LogicalResult levelCheckRanksAndSizes(Operation *op);
277
278 // Pool Op: level check kernel/stride/pad values
279 template <typename T>
280 LogicalResult levelCheckPool(Operation *op) {
281 if (auto poolOp = dyn_cast<T>(op)) {
282 for (auto k : poolOp.getKernel()) {
283 if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
284 return failure();
285 }
286 }
287 for (auto s : poolOp.getStride()) {
288 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
289 return failure();
290 }
291 }
292 for (auto p : poolOp.getPad()) {
293 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
294 return failure();
295 }
296 }
297 }
298 return success();
299 }
300
301 // Conv Op: level check dilation/stride/pad values
302 template <typename T>
303 LogicalResult levelCheckConv(Operation *op) {
304 if (auto convOp = dyn_cast<T>(op)) {
305
306 for (auto k : convOp.getDilation()) {
307 if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
308 return failure();
309 }
310 }
311 for (auto p : convOp.getPad()) {
312 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
313 return failure();
314 }
315 }
316 for (auto s : convOp.getStride()) {
317 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
318 return failure();
319 }
320 }
321 auto dilation = convOp.getDilation();
322 if (ShapedType weightType =
323 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
324 auto shape = weightType.getShape();
325 if (isa<tosa::Conv2DOp>(op)) {
326 assert(shape.size() == 4);
327 assert(dilation.size() == 2);
328 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
329 "dilation_y * KH <= MAX_KERNEL)")) ||
330 failed(levelCheckKernel(op, dilation[1] * shape[2],
331 "dilation_x * KW <= MAX_KERNEL)")))
332 return failure();
333 } else if (isa<tosa::Conv3DOp>(op)) {
334 assert(shape.size() == 5);
335 assert(dilation.size() == 3);
336 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
337 "dilation_d * KD <= MAX_KERNEL)")) ||
338 failed(levelCheckKernel(op, dilation[1] * shape[2],
339 "dilation_y * KH <= MAX_KERNEL)")) ||
340 failed(levelCheckKernel(op, dilation[2] * shape[3],
341 "dilation_x * KW <= MAX_KERNEL)")))
342 return failure();
343 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
344 assert(shape.size() == 4);
345 assert(dilation.size() == 2);
346 if (failed(levelCheckKernel(op, dilation[0] * shape[0],
347 "dilation_y * KH <= MAX_KERNEL)")) ||
348 failed(levelCheckKernel(op, dilation[1] * shape[1],
349 "dilation_x * KW <= MAX_KERNEL)")))
350 return failure();
351 }
352 }
353 }
354 return success();
355 }
356
357 // FFT op: level check H, W in input shape [N,H,W]
358 template <typename T>
359 LogicalResult levelCheckFFT(Operation *op) {
360 if (isa<T>(op)) {
361 for (auto v : op->getOperands()) {
362 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
363 auto shape = type.getShape();
364 assert(shape.size() == 3);
365 if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
366 failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
367 return failure();
368 }
369 }
370 }
371 }
372 return success();
373 }
374
375 // TransposeConv2d op: level check kH/kW, outpad, and stride
376 LogicalResult levelCheckTransposeConv2d(Operation *op) {
377 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
378 if (ShapedType filterType =
379 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
380 auto shape = filterType.getShape();
381 assert(shape.size() == 4);
382 // level check kernel sizes for kH and KW
383 if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
384 failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
385 return failure();
386 }
387 }
388 for (auto p : transpose.getOutPad()) {
389 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
390 return failure();
391 }
392 }
393 for (auto s : transpose.getStride()) {
394 if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
395 return failure();
396 }
397 }
398 }
399 return success();
400 }
401
402 // Resize op: level check max scales
403 LogicalResult levelCheckResize(Operation *op) {
404 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
405 SmallVector<int64_t> scale;
406 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
407 scale)) {
408 return failure();
409 }
410 const int64_t scaleYN = scale[0];
411 const int64_t scaleYD = scale[1];
412 const int64_t scaleXN = scale[2];
413 const int64_t scaleXD = scale[3];
414 if (failed(levelCheckScale(op, scaleYN / scaleYD,
415 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
416 failed(levelCheckScale(op, scaleXN / scaleXD,
417 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
418 return failure();
419 }
420 }
421 return success();
422 }
423
424 // Recursively perform a bottom-up search to determine the maximum nesting
425 // depth, starting from a specific operation and continuing up to the function
426 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
427 // time a new nested `region` is encountered.
428 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
429 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
430 return;
431
432 op = op->getParentOp();
433 if (!op)
434 return;
435
436 depth++;
437 getMaxNestedDepth(op, depth);
438 }
439
440 LogicalResult levelCheckMaxNesting(Operation *op) {
441 int32_t maxNestedDepth = 0;
442 getMaxNestedDepth(op, maxNestedDepth);
443
444 if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
445 op->emitOpError() << "failed level check: " << maxNestedDepth
446 << " >= MAX_NESTING";
447 return failure();
448 }
449 return success();
450 }
451
452 LogicalResult levelCheckListSize(Operation *op) {
453 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454 return levelCheckListSize(op, concat.getInput1().size(), "input1");
455 }
456 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
457 if (failed(levelCheckListSize(op, custom.getInputList().size(),
458 "input_list")) ||
459 failed(levelCheckListSize(op, custom.getOutputList().size(),
460 "output_list"))) {
461 return failure();
462 }
463 }
464 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
465 if (failed(
466 levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
467 failed(levelCheckListSize(op, condIf.getOutputList().size(),
468 "outputs"))) {
469 return failure();
470 }
471 }
472 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
473 if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
474 failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
475 return failure();
476 }
477 }
478 return success();
479 }
480
481 LogicalResult attributeCheckRescale(Operation *op) {
482 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
483 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
484 !targetEnv.allows(Extension::doubleround)) {
485 op->emitOpError()
486 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
487 << "requires extension [doubleround]";
488 return failure();
489 }
490 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
491 !targetEnv.allows(Extension::inexactround)) {
492 op->emitOpError()
493 << "failed attribute check: rounding_mode = INEXACT_ROUND "
494 << "requires extension [inexactround]";
495 return failure();
496 }
497 }
498 return success();
499 }
500
501 LogicalResult CheckVariable(Operation *op);
502 LogicalResult CheckVariableReadOrWrite(Operation *op);
503 bool isValidElementType(Type type, const bool allowUnsigned = false);
504
505 SmallVector<
506 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
507 constCheckers;
509 TosaProfileCompliance profileComp;
510 tosa::TargetEnv targetEnv;
511};
512
513template <>
514LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
515 auto *op = tosaOp.getOperation();
516 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
517 targetEnv.getLevel().MAX_RANK)))
518 return failure();
519
520 // rank(output) = rank(input) - 1
521 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
522 targetEnv.getLevel().MAX_RANK - 1)))
523 return failure();
524
525 return success();
526}
527
528template <>
529LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
530 auto *op = tosaOp.getOperation();
531
532 // Only the condition input has rank limitation.
533 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
534 targetEnv.getLevel().MAX_RANK)))
535 return failure();
536
537 return success();
538}
539
540template <>
541LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
542 auto *op = tosaOp.getOperation();
543 auto variableType = getVariableType(tosaOp);
544 if (failed(levelCheckRank(op, variableType, "variable type",
545 targetEnv.getLevel().MAX_RANK)))
546 return failure();
547
548 return success();
549}
550
551template <>
552LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
553 auto *op = tosaOp.getOperation();
554 auto variableType = getVariableType(tosaOp);
555 if (failed(levelCheckSize(op, variableType, "variable type")))
556 return failure();
557
558 return success();
559}
560
561LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
562#define CHECK_RANKS_AND_SIZES(tosaOp) \
563 if (isa<tosa::tosaOp##Op>(op)) { \
564 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
565 return failure(); \
566 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
567 return failure(); \
568 }
569
570#define CHECK_SIZES(tosaOp) \
571 if (isa<tosa::tosaOp##Op>(op)) { \
572 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
573 return failure(); \
574 }
575
576 // Tensor Operators
577 CHECK_RANKS_AND_SIZES(ArgMax);
578 // Activation Functions
581 CHECK_RANKS_AND_SIZES(Sigmoid);
583 // Elementwise Binary Operators
585 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
586 CHECK_RANKS_AND_SIZES(BitwiseAnd);
587 CHECK_RANKS_AND_SIZES(BitwiseOr);
588 CHECK_RANKS_AND_SIZES(BitwiseXor);
589 CHECK_RANKS_AND_SIZES(IntDiv);
590 CHECK_RANKS_AND_SIZES(LogicalAnd);
591 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
592 CHECK_RANKS_AND_SIZES(LogicalRightShift);
593 CHECK_RANKS_AND_SIZES(LogicalOr);
594 CHECK_RANKS_AND_SIZES(LogicalXor);
595 CHECK_RANKS_AND_SIZES(Maximum);
596 CHECK_RANKS_AND_SIZES(Minimum);
601 // Elementwise Unary Operators
603 CHECK_RANKS_AND_SIZES(BitwiseNot);
610 CHECK_RANKS_AND_SIZES(LogicalNot);
611 CHECK_RANKS_AND_SIZES(Negate);
612 CHECK_RANKS_AND_SIZES(Reciprocal);
615 // Elementwise Ternary Operators
616 CHECK_RANKS_AND_SIZES(Select);
617 // Comparison Operators
619 CHECK_RANKS_AND_SIZES(Greater);
620 CHECK_RANKS_AND_SIZES(GreaterEqual);
621 // Reduction Operators
622 CHECK_RANKS_AND_SIZES(ReduceAll);
623 CHECK_RANKS_AND_SIZES(ReduceAny);
624 CHECK_RANKS_AND_SIZES(ReduceMax);
625 CHECK_RANKS_AND_SIZES(ReduceMin);
626 CHECK_RANKS_AND_SIZES(ReduceProduct);
627 CHECK_RANKS_AND_SIZES(ReduceSum);
628 // Data Layout Operators
629 CHECK_RANKS_AND_SIZES(Concat);
631 CHECK_RANKS_AND_SIZES(Reshape);
632 CHECK_RANKS_AND_SIZES(Reverse);
635 CHECK_RANKS_AND_SIZES(Transpose);
636 // Type Conversion
638 CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
639 CHECK_RANKS_AND_SIZES(CastToBlockScaled);
640 CHECK_RANKS_AND_SIZES(Rescale);
641 // Control Flow Operators
643 // Variable Operators
644 CHECK_RANKS_AND_SIZES(Variable);
645 CHECK_RANKS_AND_SIZES(VariableWrite);
646 CHECK_RANKS_AND_SIZES(VariableRead);
647 // Data Nodes
649 CHECK_RANKS_AND_SIZES(Identity);
650
651 // For the following operators, check whether the size of each tensor
652 // operand is valid in a given Level.
653
654 // Tensor Operators
655 CHECK_SIZES(AvgPool2d);
656 CHECK_SIZES(Conv2D);
657 CHECK_SIZES(Conv3D);
658 CHECK_SIZES(DepthwiseConv2D);
659 CHECK_SIZES(TransposeConv2D);
660 CHECK_SIZES(FFT2d);
661 CHECK_SIZES(MatMul);
662 CHECK_SIZES(MatmulTBlockScaled);
663 CHECK_SIZES(MaxPool2d);
664 CHECK_SIZES(RFFT2d);
665 // Scatter/Gather Operators
667 CHECK_SIZES(Scatter);
668 // Image Operators
669 CHECK_SIZES(Resize);
670 // Custom Operators
671 CHECK_SIZES(Custom);
672 // Control Flow Operators
673 CHECK_SIZES(While);
674 // Shape Operators
675 CHECK_SIZES(ConstShape);
676
677#undef CHECK_RANKS_AND_SIZES
678#undef CHECK_SIZES
679 return success();
680}
681
682// Perform the Level tensor size check on the tensor type.
683LogicalResult TosaValidation::levelCheckSize(Operation *op,
684 const Type &typeToCheck,
685 const StringRef operandOrResult) {
686 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
687 if (!type.hasRank())
688 return op->emitOpError() << "failed level check: unranked tensor";
689 auto shape = type.getShape();
690 for (auto dim : shape) {
691 if (mlir::ShapedType::isDynamic(dim))
692 return op->emitOpError() << "failed level check: " << operandOrResult
693 << " shape dimension cannot be dynamic";
694 }
695
696 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
697 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
698 int64_t size = element_bytes * type.getNumElements();
699
700 // According to 1.11. Tensor Definitions of Tosa spec, the value of
701 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
702 // defined in 1.7. Levels.
703 // For each tensor, the number of tensor elements multiplied by the
704 // element size in bytes must be representable as a tensor_size_t.
705 const int64_t max_size =
706 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
707 if (size > max_size)
708 return op->emitOpError()
709 << "failed level check: " << operandOrResult
710 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
711 }
712 return success();
713}
714
715LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
716 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
717 // no need to do level checks
718 return success();
719 }
720
721 // check rank and sizes early so later checks can assume shaped operands
722 if (failed(levelCheckRanksAndSizes(op)))
723 return failure();
724
725 // additional level checks from spec 0.70
726 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
727 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
728 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
729 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
730 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
731 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
732 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
733 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
734 return failure();
737 // level check MAX_TENSOR_LIST_SIZE
738 if (failed(levelCheckListSize(op))) {
739 return failure();
741
742 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
743 if (failed(levelCheckMaxNesting(op))) {
744 return failure();
746 }
747
748 return success();
749}
750
751LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
752 if (failed(attributeCheckRescale(op)))
753 return failure();
754 return success();
755}
756
757inline bool CompatibleTypes(const mlir::Type &type,
758 const mlir::Type &declaredType) {
759 // for now, simply use type equality comparison
760 return type == declaredType;
761}
762
763LogicalResult TosaValidation::CheckVariable(Operation *op) {
764 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
765 mlir::StringAttr nameAttr = variableOp.getNameAttr();
766
767 if (variablesMap.count(nameAttr))
768 return op->emitOpError() << "name has already been declared";
769
770 auto elementType = variableOp.getType();
771 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
772 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
773 RankedTensorType variableType =
774 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
775
776 variablesMap[nameAttr] = variableType;
777 }
779 return success();
780}
781
782LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
783 if (isa<mlir::tosa::VariableReadOp>(op) ||
784 isa<mlir::tosa::VariableWriteOp>(op)) {
785 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
786 if (!variablesMap.count(nameAttr))
787 return op->emitOpError() << "name has not been declared";
788
789 auto varType = variablesMap[nameAttr];
790
791 for (auto v : op->getOperands()) {
792 auto type = v.getType();
793 if (!CompatibleTypes(type, varType))
794 return op->emitOpError() << "operand type does not equal variable type";
795 }
797 for (auto v : op->getResults()) {
798 auto type = v.getType();
799 if (!CompatibleTypes(type, varType))
800 return op->emitOpError() << "result type does not equal variable type";
801 }
802 }
803
804 return success();
805}
806
807LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
808 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
809 return failure();
810 return success();
811}
812
813LogicalResult checkErrorIfResize(Operation *op) {
814 auto resize = dyn_cast<tosa::ResizeOp>(op);
815 if (!resize)
816 return success();
817
818 const Value input = resize.getInput();
819 const Value output = resize.getOutput();
820 const RankedTensorType inputType =
821 llvm::dyn_cast<RankedTensorType>(input.getType());
822 const RankedTensorType outputType =
823 llvm::dyn_cast<RankedTensorType>(output.getType());
824
825 if (!inputType || !outputType)
826 return op->emitOpError("expect ranked input/output tensor");
827
828 // Ensure the image size is supported by GPU APIs and that for integer
829 // implementations, position * stride does not overflow int32_t.
830 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
831 const SmallVector<int64_t, 4> sizes = {
832 outputType.getDimSize(1), outputType.getDimSize(2),
833 inputType.getDimSize(1), inputType.getDimSize(2)};
834 const int64_t *maxDim = llvm::max_element(sizes);
835 if (maxDim != sizes.end() && *maxDim >= 16384)
836 return op->emitOpError(
837 "expect input/output height/width dims to be < 16384, ")
838 << "got [OH, OW, IH, IW] = " << sizes;
839 }
840
841 SmallVector<int64_t> scale;
842 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
843 return failure();
844
845 const int64_t scaleYN = scale[0];
846 const int64_t scaleYD = scale[1];
847 const int64_t scaleXN = scale[2];
848 const int64_t scaleXD = scale[3];
849
850 // Ensure scale values don't overflow int32 accumulator
851 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
852 return op->emitOpError(
853 "expect all scale numerator values to be <= (1 << 11), "
854 "got scale_y_n=")
855 << scaleYN << ", scale_x_n=" << scaleXN;
856
857 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
858 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
859 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
860
861 SmallVector<int64_t> offset;
862 SmallVector<int64_t> border;
863 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
864 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
865 return failure();
866
867 const int64_t offsetY = offset[0];
868 const int64_t offsetX = offset[1];
869 // Set a consistent lower limit of 1/16 downscale to simplify
870 // implementations
871 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
872 return op->emitOpError(
873 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
874 << offsetY << "/" << scaleYN;
875 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
876 return op->emitOpError(
877 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
878 << offsetX << "/" << scaleXN;
879
880 const int64_t borderY = border[0];
881 const int64_t borderX = border[1];
882 if (borderY < -16 * scaleYN || borderY >= scaleYN)
883 return op->emitOpError(
884 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
885 << borderY << "/" << scaleYN;
886 if (borderX < -16 * scaleXN || borderX >= scaleXN)
887 return op->emitOpError(
888 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
889 << borderX << "/" << scaleXN;
890
891 // The following section of code is mostly duplicated with ResizeOp::verify().
892 //
893 // In TOSA specification, we do not support broadcast behavior.
894 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
895 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
896 // existing code, we keep the rewrite pattern untouched. So, we need
897 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
898 //
899 // Here is a strict checking to conform TOSA specification.
900 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
901 auto idivCheck = [](const int64_t lhs,
902 const int64_t rhs) -> std::optional<int64_t> {
903 if (lhs % rhs != 0)
904 return std::nullopt;
905 return lhs / rhs;
906 };
907
908 const int64_t oh = outputType.getDimSize(1);
909 const int64_t ow = outputType.getDimSize(2);
910 const int64_t ih = inputType.getDimSize(1);
911 const int64_t iw = inputType.getDimSize(2);
912
913 if (ih != ShapedType::kDynamic) {
914 const std::optional<int64_t> calculatedOutHeightMinusOne =
915 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
916 if (!calculatedOutHeightMinusOne.has_value())
917 return op->emitOpError(
918 "expected (input_height - 1) * scale_y_n - offset_y + "
919 "border_y ")
920 << "to be wholly divisible by scale_y_d, got ((" << ih
921 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
922 << ") / " << scaleYD;
923 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
924 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
925 return op->emitOpError(
926 "calculated output height did not match expected: ")
927 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
928 }
929
930 if (iw != ShapedType::kDynamic) {
931 const std::optional<int64_t> calculatedOutWidthMinusOne =
932 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
933 if (!calculatedOutWidthMinusOne.has_value())
934 return op->emitOpError(
935 "expected (input_width - 1) * scale_x_n - offset_x + "
936 "border_x ")
937 << "to be wholly divisible by scale_x_d, got ((" << iw
938 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
939 << ") / " << scaleXD;
940 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
941 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
942 return op->emitOpError("calculated output width did not match expected: ")
943 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
944 }
945
946 return success();
947}
948
949LogicalResult checkErrorIfMul(Operation *op) {
950 auto mul = dyn_cast<tosa::MulOp>(op);
951 if (!mul)
952 return success();
953
954 // REQUIRE(0 <= shift && shift <= 63);
955 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
956 ElementsAttr shift_elem;
957 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
958 return success();
959 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
960 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
961 if (inputElemType.isInteger(32)) {
962 // 0 <= shift <= 63 for int32_t type
963 if (shift < 0 || shift > 63)
964 return op->emitOpError()
965 << "requires 0 <= shift && shift <= 63, but got: " << shift;
966 } else {
967 // shift must be 0 for all other types
968 if (shift != 0)
969 return op->emitOpError()
970 << "requires shift = 0 for all input data types that "
971 "are not int32_t, but got: "
972 << shift;
973 }
974
975 return success();
976}
977
978LogicalResult checkErrorIfTable(Operation *op) {
979 auto table = dyn_cast<tosa::TableOp>(op);
980 if (!table)
981 return success();
982
983 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
984 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
985 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
986
987 const ShapeAdaptor tableShape(table.getTable().getType());
988 if (tableShape.hasStaticShape()) {
989 const auto numElements = tableShape.getNumElements();
990 if (numElements != tableSize)
991 return op->emitOpError() << "requires table size of " << tableSize
992 << ", got " << numElements;
993 }
994
995 return success();
996}
997
998LogicalResult checkErrorIfRescale(Operation *op) {
999 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1000 if (!rescale)
1001 return success();
1002
1003 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1004 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1005 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1006 !outputType.getElementType().isInteger())
1007 return success();
1008
1009 auto inElemType = inputType.getElementType();
1010 auto outElemType = outputType.getElementType();
1011 auto inWidth = inElemType.getIntOrFloatBitWidth();
1012 auto outWidth = outElemType.getIntOrFloatBitWidth();
1013
1014 bool inputUnsigned = rescale.getInputUnsigned();
1015 bool outputUnsigned = rescale.getOutputUnsigned();
1016
1017 bool scale32 = rescale.getScale32();
1018 auto roundingMode = rescale.getRoundingMode();
1019
1020 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1021 if (scale32 && inWidth == 48)
1022 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1023
1024 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1025 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1026 return op->emitOpError()
1027 << "DOUBLE_ROUND is only allowed with scale32=true.";
1028
1029 // ERROR_IF(input_unsigned && output_unsigned)
1030 if (inputUnsigned && outputUnsigned)
1031 return op->emitOpError() << "input and output cannot be both unsigned.";
1032
1033 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1034 if (outWidth == 32 && inputUnsigned)
1035 return op->emitOpError()
1036 << "i32 output type is not allowed with unsigned input.";
1037
1038 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1039 if (inWidth == 32 && outputUnsigned)
1040 return op->emitOpError()
1041 << "i32 input type is not allowed with unsigned output.";
1042
1043 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1044 if (inWidth == 48 && outputUnsigned)
1045 return op->emitOpError()
1046 << "i48 input type is not allowed with unsigned output.";
1047
1048 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1049 if (inWidth == 48 && inputUnsigned)
1050 return op->emitOpError() << "i48 input type cannot be unsigned.";
1051
1052 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1053 if (inWidth == 32 && inputUnsigned)
1054 return op->emitOpError() << "i32 input type cannot be unsigned.";
1055
1056 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1057 if (outWidth == 32 && outputUnsigned)
1058 return op->emitOpError() << "i32 output type cannot be unsigned.";
1059
1060 return success();
1061}
1062
1063LogicalResult checkErrorIfPad(Operation *op) {
1064 auto pad = dyn_cast<tosa::PadOp>(op);
1065 if (!pad)
1066 return success();
1067
1068 DenseIntElementsAttr paddingAttr;
1069 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1070 // Pad verifier will catch this
1071 return success();
1072
1073 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1074 if (val.getSExtValue() < 0)
1075 return op->emitOpError() << "padding value must all be non-negative, got "
1076 << val.getSExtValue();
1077 }
1078
1079 return success();
1080}
1081
1082static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1083 return llvm::all_of(op->getOperands(), [&](auto operand) {
1084 Region *operandRegion = operand.getParentRegion();
1085 return operandRegion && region->isAncestor(operandRegion);
1086 });
1087}
1088
1089static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1090 bool noLiveInValue = true;
1091 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1092 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1093 noLiveInValue = false;
1094 return WalkResult::interrupt();
1095 }
1096 return WalkResult::advance();
1097 });
1098 return noLiveInValue ? success() : failure();
1099}
1100
1101LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1102 StringRef regionName) {
1103 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1104 return success();
1105 return op->emitOpError()
1106 << "is not conformant to the TOSA specification. It requires the '"
1107 << regionName << "' region is isolated from above.\n";
1108}
1109
1110LogicalResult checkErrorIfCondIf(Operation *op) {
1111 auto ifOp = dyn_cast<tosa::IfOp>(op);
1112 if (!ifOp)
1113 return success();
1114
1115 // Currently the dialect supports declaring cond_if operations that
1116 // have then/else regions that reference values from outside these
1117 // regions. According to the specification, all values used by the
1118 // then/else regions must be explicitly declared within the regions.
1119 // Therefore we must check that the then/else regions are
1120 // "isolated from above", in order to be conformant to the
1121 // specification.
1122 //
1123 // Note: the dialect currently supports two styles of syntax for
1124 // declaring "cond_if" operations. We'll refer to these as follows:
1125 //
1126 // Generic:
1127 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1128 // ^bb0(%arg3, %arg4):
1129 // tosa.yield %arg3
1130 // }, {
1131 // ^bb0(%arg3, %arg4):
1132 // tosa.yield %arg4
1133 // })
1134 //
1135 // Simplified:
1136 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1137 // ^bb0(%arg3, %arg4):
1138 // tosa.yield %arg3
1139 // } else {
1140 // ^bb0(%arg3, %arg4):
1141 // tosa.yield %arg4
1142 // }
1143
1144 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1145 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1146 return failure();
1147 return success();
1148}
1149
1150LogicalResult checkErrorIfWhileLoop(Operation *op) {
1151 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1152 if (!whileOp)
1153 return success();
1154
1155 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1156 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1157 return failure();
1158 return success();
1159}
1160
1161LogicalResult checkErrorIfScatter(Operation *op) {
1162 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1163 if (!scatterOp)
1164 return success();
1165
1166 // for constant indices, check that there are no duplicate values
1167 DenseIntElementsAttr indicesAttr;
1168 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1169 return success();
1170
1171 auto const indicesType =
1172 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1173 if (!indicesType || !indicesType.hasRank()) {
1174 op->emitOpError("expect ranked indices tensor");
1175 return failure();
1176 }
1177
1178 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1179 op->emitOpError("indices values contain duplicates");
1180 return failure();
1181 }
1182
1183 return success();
1184}
1185
1186LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1187 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1188 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1189 failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
1190 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1191 return failure();
1192 return success();
1193}
1194
1195bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1196 if (isa<FloatType>(type)) {
1197 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1198 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1199 Float6E3M2FNType, Float8E8M0FNUType>(type);
1200 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1201 if (intTy.isSignless()) {
1202 switch (intTy.getWidth()) {
1203 case 1:
1204 case 4:
1205 case 8:
1206 case 16:
1207 case 32:
1208 case 48:
1209 case 64:
1210 return true;
1211 }
1212 } else if (allowUnsigned && intTy.isUnsigned()) {
1213 switch (intTy.getWidth()) {
1214 case 8:
1215 case 16:
1216 case 32:
1217 return true;
1218 }
1219 }
1220 } else if (isa<tosa::shapeType>(type))
1221 return true;
1222 else if (isa<tosa::mxint8Type>(type))
1223 return true;
1224 return false;
1225}
1226
1227void TosaValidation::runOnOperation() {
1228 ModuleOp modOp = getOperation();
1229 const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
1230 const auto maybeTargetEnv =
1231 tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
1232 if (failed(maybeTargetEnv))
1233 return signalPassFailure();
1234 targetEnv = *maybeTargetEnv;
1235
1236 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1237 if (!tosaDialect)
1238 return;
1239
1240 modOp.walk([&](Operation *op) {
1241 if (op->getDialect() != tosaDialect)
1242 return;
1243
1244 // validate operator element types:
1245 // - rescale operator is allowed to have ui8/ui16/ui32
1246 // operands/results when strictOpSpecAlignment is false
1247 // - perform valid element type check at the beginning to
1248 // protect rest of code against quantized element types
1249 const bool allowUnsigned =
1250 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1251 for (Value operand : op->getOperands()) {
1252 auto elementTy = getElementTypeOrSelf(operand);
1253 if (!isValidElementType(elementTy, allowUnsigned)) {
1254 op->emitOpError() << "is not profile-aligned: element type "
1255 << elementTy << " is not legal";
1256 return signalPassFailure();
1257 }
1258 }
1259 for (Type resultTy : op->getResultTypes()) {
1260 auto elementTy = getElementTypeOrSelf(resultTy);
1261 if (!isValidElementType(elementTy, allowUnsigned)) {
1262 op->emitOpError() << "is not profile-aligned: element type "
1263 << elementTy << " is not legal";
1264 return signalPassFailure();
1265 }
1266 }
1267
1268 if (strictOpSpecAlignment &&
1269 failed(profileComp.checkProfile(op, targetEnv)))
1270 return signalPassFailure();
1271
1272 if (strictOpSpecAlignment &&
1273 failed(profileComp.checkExtension(op, targetEnv)))
1274 return signalPassFailure();
1275
1276 if (!allowInvalidOpDatatypeCombinations &&
1277 failed(profileComp.checkInvalid(op)))
1278 return signalPassFailure();
1279
1280 // Some uses of TOSA rely on the constant operands of particular
1281 // operations.
1282 if (failed(applyConstantOperandCheck(op)))
1283 signalPassFailure();
1284
1285 // do level checks
1286 if (failed(applyLevelCheck(op)))
1287 signalPassFailure();
1288
1289 // check additional attribute restrictions
1290 if (failed(applyAttributeCheck(op)))
1291 signalPassFailure();
1292
1293 // do variable type checks
1294 if (failed(applyVariableCheck(op)))
1295 signalPassFailure();
1296
1297 // do error if checks
1298 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1299 signalPassFailure();
1300 });
1301}
1302} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:557
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Value getOperand(unsigned idx)
Definition Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:97
TosaLevel getLevel() const
Definition TargetEnv.h:114
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:63
bool allows(Profile prof) const
Definition TargetEnv.h:124
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:42
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:609
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369