MLIR 23.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Validate if TOSA dialect input matches with the specification for given
10// requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Matchers.h"
27#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/StringExtras.h"
31#include "llvm/Support/FormatVariadic.h"
32
33namespace mlir {
34namespace tosa {
35#define GEN_PASS_DEF_TOSAVALIDATION
36#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
37} // namespace tosa
38} // namespace mlir
39
40using namespace mlir;
41using namespace mlir::tosa;
42
43namespace {
44
45static LogicalResult
46checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
47 for (const auto index : operandIndices) {
48 Attribute attr;
49 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
50 return op->emitOpError("expected compile time resolvable constant, but "
51 "got variable value for operand #")
52 << index;
53 }
54 }
55 return success();
56}
57
58static LogicalResult checkConstantOperandMul(Operation *op,
59 const TargetEnv &env) {
60 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
61 // Check 'shift'
62 return checkConstantOperands(op, {2});
63 }
64 return success();
65}
66
67static LogicalResult checkConstantOperandTable(Operation *op,
68 const TargetEnv &env) {
69 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
70 // Check 'table'
71 return checkConstantOperands(op, {1});
72 }
73 return success();
74}
75
76static LogicalResult checkConstantOperandPad(Operation *op,
77 const TargetEnv &env) {
78 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
79 // Assume this op is zero-padding if padConst is not presented
80 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
81 // Check 'pad_const'
82 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
83 return checkConstantOperands(op, {2});
84 }
85 return success();
86}
87
88static LogicalResult checkConstantOperandRescale(Operation *op,
89 const TargetEnv &env) {
90 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
91 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
92 return checkConstantOperands(op, {1, 2, 3, 4});
93 }
94 return success();
95}
96
97template <typename T>
98static LogicalResult checkConstantOperandConvOps(Operation *op,
99 const TargetEnv &env) {
100 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
101 // Check 'input_zp' and 'weight_zp'
102 return checkConstantOperands(op, {3, 4});
103 }
104 return success();
105}
106
107static LogicalResult checkConstantOperandMatMul(Operation *op,
108 const TargetEnv &env) {
109 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
110 // Check 'A_zp' and 'B_zp'
111 return checkConstantOperands(op, {2, 3});
112 }
113 return success();
114}
115
116static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
117 const TargetEnv &env) {
118 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
119 // Check 'input_zp' and 'output_zp'
120 return checkConstantOperands(op, {1, 2});
121 }
122 return success();
123}
124
125static LogicalResult checkConstantOperandNegate(Operation *op,
126 const TargetEnv &env) {
127 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
128 // Check 'input1_zp' and 'output_zp'
129 return checkConstantOperands(op, {1, 2});
130 }
131 return success();
132}
133
134static LogicalResult checkConstantOperandSilceShape(Operation *op,
135 const TargetEnv &env) {
136 if (!env.allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
137 // Check 'start' and 'size'
138 return checkConstantOperands(op, {1, 2});
139 }
140 return success();
141}
142
143//===----------------------------------------------------------------------===//
144// TOSA Validation Pass.
145//===----------------------------------------------------------------------===//
146
147struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
148public:
149 explicit TosaValidation() { populateConstantOperandChecks(); }
150
151 explicit TosaValidation(const TosaValidationOptions &options)
152 : TosaValidation() {
153 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
154 this->allowInvalidOpDatatypeCombinations =
155 options.allowInvalidOpDatatypeCombinations;
156 }
157 void runOnOperation() final;
158
159 LogicalResult applyConstantOperandCheck(Operation *op) {
160 for (auto &checker : constCheckers) {
161 if (failed(checker(op, targetEnv)))
162 return failure();
163 }
164 return success();
165 }
166
167 LogicalResult applyLevelCheck(Operation *op);
168 LogicalResult applyAttributeCheck(Operation *op);
169
170 // check variable read/write data types against variable declarations
171 LogicalResult applyVariableCheck(Operation *op);
172
173 // check error if conditions
174 LogicalResult applyErrorIfCheck(Operation *op);
175
176private:
177 void populateConstantOperandChecks() {
178 constCheckers.emplace_back(checkConstantOperandMul);
179 constCheckers.emplace_back(checkConstantOperandTable);
180 constCheckers.emplace_back(checkConstantOperandPad);
181 constCheckers.emplace_back(checkConstantOperandRescale);
182 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
183 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
184 constCheckers.emplace_back(
185 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
186 constCheckers.emplace_back(
187 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
188 constCheckers.emplace_back(checkConstantOperandMatMul);
189 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
190 constCheckers.emplace_back(checkConstantOperandNegate);
191 constCheckers.emplace_back(checkConstantOperandSilceShape);
192 }
193
194 LogicalResult levelCheck(Operation *op, const int32_t calculatedValue,
195 const int32_t maxLevel, const StringRef inputName,
196 const StringRef levelName) {
197 if (calculatedValue > maxLevel)
198 return op->emitOpError()
199 << "failed level check: " << inputName << " <= " << levelName
200 << " (" << maxLevel << "), got " << calculatedValue;
201 return success();
202 }
203
204 LogicalResult levelCheckKernel(Operation *op, int32_t v,
205 const StringRef inputName) {
206 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
207 "MAX_KERNEL");
208 }
209
210 LogicalResult levelCheckStride(Operation *op, int32_t v,
211 const StringRef inputName) {
212 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
213 "MAX_STRIDE");
214 }
215
216 LogicalResult levelCheckScale(Operation *op, int32_t v,
217 const StringRef inputName) {
218 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
219 "MAX_SCALE");
220 }
221
222 LogicalResult levelCheckListSize(Operation *op, int32_t v,
223 const StringRef inputName) {
224 const std::string inputDesc =
225 llvm::formatv("length(tensor_list_shape({0}))", inputName);
226 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
227 inputDesc, "MAX_TENSOR_LIST_SIZE");
228 }
229
230 // Perform the Level Rank check on the tensor type.
231 LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
232 const StringRef operandOrResult,
233 int32_t highest_rank) {
234 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
235 if (!type.hasRank())
236 return op->emitOpError() << "failed level check: unranked tensor";
237 if (type.getRank() > highest_rank)
238 return op->emitOpError() << "failed level check: " << operandOrResult
239 << " rank(shape) <= MAX_RANK";
240 }
241 return success();
242 }
243
244 // Perform the Level Rank check on the tensor value.
245 LogicalResult levelCheckRank(Operation *op, const Value &v,
246 const StringRef operandOrResult,
247 int32_t highest_rank) {
248 return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
249 }
250
251 // Perform the Level tensor size check on the tensor type.
252 LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
253 const StringRef operandOrResult);
254
255 // Perform the Level tensor size check on the tensor value.
256 LogicalResult levelCheckSize(Operation *op, const Value &v,
257 const StringRef operandOrResult) {
258 return levelCheckSize(op, v.getType(), operandOrResult);
259 }
260
261 // Perform the Level shape length check on a value.
262 LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
263 const StringRef operandOrResult) {
264 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
265 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
266 return op->emitOpError()
267 << "failed shape type level check: " << typeToCheck
268 << " exceeds MAX_SHAPE_LEN";
269 }
270 return success();
271 }
272
273 // Level check sizes of all operands and results of the operation.
274 template <typename T>
275 LogicalResult levelCheckSizes(T tosaOp) {
276 auto op = tosaOp.getOperation();
277 for (auto v : op->getOperands()) {
278 if (failed(levelCheckSize(op, v, "operand")))
279 return failure();
280 }
281
282 for (auto v : op->getResults()) {
283 if (failed(levelCheckSize(op, v, "result")))
284 return failure();
285 }
286 return success();
287 }
288
289 // Level check ranks of all operands, attribute and results of the operation.
290 template <typename T>
291 LogicalResult levelCheckRanks(T tosaOp) {
292 auto op = tosaOp.getOperation();
293 const TosaLevel tosaLevel = targetEnv.getLevel();
294 for (auto v : op->getOperands()) {
295 if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
296 return failure();
297 }
298
299 for (auto v : op->getResults()) {
300 if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
301 return failure();
302 }
303 return success();
304 }
305 // Level check shape lengths of all operands and results of an operation that
306 // are tosa.shape type.
307 template <typename T>
308 LogicalResult levelCheckShapeLengths(T tosaOp) {
309 for (const auto &v : tosaOp->getOperands()) {
310 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
311 return failure();
312 }
313 for (const auto &v : tosaOp->getResults()) {
314 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
315 return failure();
316 }
317
318 return success();
319 }
320
321 // Level check ranks and sizes.
322 LogicalResult levelCheckRanksAndSizes(Operation *op);
323
324 // Pool Op: level check kernel/stride/pad values
325 template <typename T>
326 LogicalResult levelCheckPool(Operation *op) {
327 if (auto poolOp = dyn_cast<T>(op)) {
328 for (auto k : poolOp.getKernel()) {
329 if (failed(levelCheckKernel(op, k, "kernel"))) {
330 return failure();
331 }
332 }
333 for (auto s : poolOp.getStride()) {
334 if (failed(levelCheckStride(op, s, "stride"))) {
335 return failure();
336 }
337 }
338 for (auto p : poolOp.getPad()) {
339 if (failed(levelCheckKernel(op, p, "pad"))) {
340 return failure();
341 }
342 }
343 }
344 return success();
345 }
346
347 // Conv Op: level check dilation/stride/pad values
348 template <typename T>
349 LogicalResult levelCheckConv(Operation *op) {
350 if (auto convOp = dyn_cast<T>(op)) {
351
352 for (auto k : convOp.getDilation()) {
353 if (failed(levelCheckKernel(op, k, "dilation"))) {
354 return failure();
355 }
356 }
357 for (auto p : convOp.getPad()) {
358 if (failed(levelCheckKernel(op, p, "pad"))) {
359 return failure();
360 }
361 }
362 for (auto s : convOp.getStride()) {
363 if (failed(levelCheckStride(op, s, "stride"))) {
364 return failure();
365 }
366 }
367 auto dilation = convOp.getDilation();
368 if (ShapedType weightType =
369 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
370 auto shape = weightType.getShape();
371 if (isa<tosa::Conv2DOp>(op)) {
372 assert(shape.size() == 4);
373 assert(dilation.size() == 2);
374 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
375 "dilation_y * KH")) ||
376 failed(levelCheckKernel(op, dilation[1] * shape[2],
377 "dilation_x * KW")))
378 return failure();
379 } else if (isa<tosa::Conv3DOp>(op)) {
380 assert(shape.size() == 5);
381 assert(dilation.size() == 3);
382 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
383 "dilation_d * KD")) ||
384 failed(levelCheckKernel(op, dilation[1] * shape[2],
385 "dilation_y * KH")) ||
386 failed(levelCheckKernel(op, dilation[2] * shape[3],
387 "dilation_x * KW")))
388 return failure();
389 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
390 assert(shape.size() == 4);
391 assert(dilation.size() == 2);
392 if (failed(levelCheckKernel(op, dilation[0] * shape[0],
393 "dilation_y * KH")) ||
394 failed(levelCheckKernel(op, dilation[1] * shape[1],
395 "dilation_x * KW")))
396 return failure();
397 }
398 }
399 }
400 return success();
401 }
402
403 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
404 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
405 if (!convOp)
406 return success();
407
408 SmallVector<int64_t> padValues;
409 if (tosa::getConstShapeValues(convOp.getPad().getDefiningOp(), padValues)) {
410 for (const auto p : padValues)
411 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
412 return failure();
413 }
414
415 SmallVector<int64_t> strideValues;
416 if (tosa::getConstShapeValues(convOp.getStride().getDefiningOp(),
417 strideValues)) {
418 for (const auto s : strideValues)
419 if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
420 return failure();
421 }
422
423 SmallVector<int64_t> dilationValues;
424 if (tosa::getConstShapeValues(convOp.getDilation().getDefiningOp(),
425 dilationValues)) {
426 int64_t KH = ShapedType::kDynamic;
427 int64_t KW = ShapedType::kDynamic;
428 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
429 KH = weightDataShape.getDimSize(1);
430 KW = weightDataShape.getDimSize(2);
431 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
432 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
433 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
434
435 if (!ShapedType::isDynamic(KH) &&
436 failed(levelCheckKernel(op, dilationValues[0] * KH,
437 "dilation_y * KH <= MAX_KERNEL)")))
438 return failure();
439
440 if (!ShapedType::isDynamic(KW) &&
441 failed(levelCheckKernel(op, dilationValues[1] * KW,
442 "dilation_x * KW <= MAX_KERNEL)")))
443 return failure();
444 }
445
446 return success();
447 }
448
449 // FFT op: level check H, W in input shape [N,H,W]
450 template <typename T>
451 LogicalResult levelCheckFFT(Operation *op) {
452 if (isa<T>(op)) {
453 for (auto v : op->getOperands()) {
454 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
455 auto shape = type.getShape();
456 assert(shape.size() == 3);
457 if (failed(levelCheckKernel(op, shape[1], "H")) ||
458 failed(levelCheckKernel(op, shape[2], "W"))) {
459 return failure();
460 }
461 }
462 }
463 }
464 return success();
465 }
466
467 // TransposeConv2d op: level check kH/kW, outpad, and stride
468 LogicalResult levelCheckTransposeConv2d(Operation *op) {
469 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
470 if (ShapedType filterType =
471 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
472 auto shape = filterType.getShape();
473 assert(shape.size() == 4);
474 // level check kernel sizes for kH and KW
475 if (failed(levelCheckKernel(op, shape[1], "KH")) ||
476 failed(levelCheckKernel(op, shape[2], "KW"))) {
477 return failure();
478 }
479 }
480 for (auto p : transpose.getOutPad()) {
481 if (failed(levelCheckKernel(op, p, "pad"))) {
482 return failure();
483 }
484 }
485 for (auto s : transpose.getStride()) {
486 if (failed(levelCheckStride(op, s, "stride"))) {
487 return failure();
488 }
489 }
490 }
491 return success();
492 }
493
494 // Resize op: level check max scales
495 LogicalResult levelCheckResize(Operation *op) {
496 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
497 SmallVector<int64_t> scale;
498 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
499 scale)) {
500 return failure();
501 }
502 const int64_t scaleYN = scale[0];
503 const int64_t scaleYD = scale[1];
504 const int64_t scaleXN = scale[2];
505 const int64_t scaleXD = scale[3];
506 if (failed(
507 levelCheckScale(op, scaleYN / scaleYD, "scale_y_n/scale_y_d")) ||
508 failed(
509 levelCheckScale(op, scaleXN / scaleXD, "scale_x_n/scale_x_d"))) {
510 return failure();
511 }
512 }
513 return success();
514 }
515
516 // Recursively perform a bottom-up search to determine the maximum nesting
517 // depth, starting from a specific operation and continuing up to the function
518 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
519 // time a new nested `region` is encountered.
520 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
521 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
522 return;
523
524 op = op->getParentOp();
525 if (!op)
526 return;
527
528 depth++;
529 getMaxNestedDepth(op, depth);
530 }
531
532 LogicalResult levelCheckMaxNesting(Operation *op) {
533 int32_t maxNestedDepth = 0;
534 getMaxNestedDepth(op, maxNestedDepth);
535
536 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
537 if (maxNestedDepth >= maxNestingLevel)
538 return op->emitOpError()
539 << "failed level check: tosa_nesting_depth < MAX_NESTING" << " ("
540 << maxNestingLevel << "), got " << maxNestedDepth;
541 return success();
542 }
543
544 LogicalResult levelCheckListSize(Operation *op) {
545 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
546 return levelCheckListSize(op, concat.getInput1().size(), "input1");
547 }
548 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
549 if (failed(levelCheckListSize(op, custom.getInputList().size(),
550 "input_list")) ||
551 failed(levelCheckListSize(op, custom.getOutputList().size(),
552 "output_list"))) {
553 return failure();
554 }
555 }
556 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
557 if (failed(
558 levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
559 failed(levelCheckListSize(op, condIf.getOutputList().size(),
560 "outputs"))) {
561 return failure();
562 }
563 }
564 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
565 if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
566 failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
567 return failure();
568 }
569 }
570 if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
571 return levelCheckListSize(op, concat_shape.getInput().size(), "input");
572 return success();
573 }
574
575 LogicalResult attributeCheckRescale(Operation *op) {
576 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
577 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
578 !targetEnv.allows(Extension::doubleround)) {
579 op->emitOpError()
580 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
581 << "requires extension [doubleround]";
582 return failure();
583 }
584 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
585 !targetEnv.allows(Extension::inexactround)) {
586 op->emitOpError()
587 << "failed attribute check: rounding_mode = INEXACT_ROUND "
588 << "requires extension [inexactround]";
589 return failure();
590 }
591 }
592 return success();
593 }
594
595 LogicalResult CheckVariable(Operation *op);
596 LogicalResult CheckVariableReadOrWrite(Operation *op);
597 bool isValidElementType(Type type, const bool allowUnsigned = false);
598
599 SmallVector<
600 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
601 constCheckers;
603 TosaProfileCompliance profileComp;
604 tosa::TargetEnv targetEnv;
605};
606
607template <>
608LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
609 auto *op = tosaOp.getOperation();
610 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
611 targetEnv.getLevel().MAX_RANK)))
612 return failure();
613
614 // rank(output) = rank(input) - 1
615 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
616 targetEnv.getLevel().MAX_RANK - 1)))
617 return failure();
618
619 return success();
620}
621
622template <>
623LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
624 auto *op = tosaOp.getOperation();
625
626 // Only the condition input has rank limitation.
627 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
628 targetEnv.getLevel().MAX_RANK)))
629 return failure();
630
631 return success();
632}
633
634template <>
635LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
636 auto *op = tosaOp.getOperation();
637 auto variableType = getVariableType(tosaOp);
638 if (failed(levelCheckRank(op, variableType, "variable type",
639 targetEnv.getLevel().MAX_RANK)))
640 return failure();
641
642 return success();
643}
644
645template <>
646LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
647 auto *op = tosaOp.getOperation();
648 auto variableType = getVariableType(tosaOp);
649 if (failed(levelCheckSize(op, variableType, "variable type")))
650 return failure();
651
652 return success();
653}
654
655LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
656#define CHECK_RANKS_AND_SIZES(tosaOp) \
657 if (isa<tosa::tosaOp##Op>(op)) { \
658 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
659 return failure(); \
660 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
661 return failure(); \
662 }
663
664#define CHECK_SIZES(tosaOp) \
665 if (isa<tosa::tosaOp##Op>(op)) { \
666 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
667 return failure(); \
668 }
669
670#define CHECK_SHAPE_LEN(tosaOp) \
671 if (isa<tosa::tosaOp##Op>(op)) { \
672 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
673 return failure(); \
674 }
675
676 // Tensor Operators
677 CHECK_RANKS_AND_SIZES(ArgMax);
678 // Activation Functions
681 CHECK_RANKS_AND_SIZES(Sigmoid);
683 // Elementwise Binary Operators
685 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
686 CHECK_RANKS_AND_SIZES(BitwiseAnd);
687 CHECK_RANKS_AND_SIZES(BitwiseOr);
688 CHECK_RANKS_AND_SIZES(BitwiseXor);
689 CHECK_RANKS_AND_SIZES(IntDiv);
690 CHECK_RANKS_AND_SIZES(LogicalAnd);
691 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
692 CHECK_RANKS_AND_SIZES(LogicalRightShift);
693 CHECK_RANKS_AND_SIZES(LogicalOr);
694 CHECK_RANKS_AND_SIZES(LogicalXor);
695 CHECK_RANKS_AND_SIZES(Maximum);
696 CHECK_RANKS_AND_SIZES(Minimum);
701 // Elementwise Unary Operators
703 CHECK_RANKS_AND_SIZES(BitwiseNot);
710 CHECK_RANKS_AND_SIZES(LogicalNot);
711 CHECK_RANKS_AND_SIZES(Negate);
712 CHECK_RANKS_AND_SIZES(Reciprocal);
715 // Elementwise Ternary Operators
716 CHECK_RANKS_AND_SIZES(Select);
717 // Comparison Operators
719 CHECK_RANKS_AND_SIZES(Greater);
720 CHECK_RANKS_AND_SIZES(GreaterEqual);
721 // Reduction Operators
722 CHECK_RANKS_AND_SIZES(ReduceAll);
723 CHECK_RANKS_AND_SIZES(ReduceAny);
724 CHECK_RANKS_AND_SIZES(ReduceMax);
725 CHECK_RANKS_AND_SIZES(ReduceMin);
726 CHECK_RANKS_AND_SIZES(ReduceProduct);
727 CHECK_RANKS_AND_SIZES(ReduceSum);
728 // Data Layout Operators
729 CHECK_RANKS_AND_SIZES(Concat);
731 CHECK_RANKS_AND_SIZES(Reshape);
732 CHECK_RANKS_AND_SIZES(Reverse);
735 CHECK_RANKS_AND_SIZES(Transpose);
736 // Type Conversion
738 CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
739 CHECK_RANKS_AND_SIZES(CastToBlockScaled);
740 CHECK_RANKS_AND_SIZES(Rescale);
741 // Data Nodes
743 CHECK_RANKS_AND_SIZES(Identity);
744 // Control Flow Operators
746 // Variable Operators
747 CHECK_RANKS_AND_SIZES(Variable);
748 CHECK_RANKS_AND_SIZES(VariableWrite);
749 CHECK_RANKS_AND_SIZES(VariableRead);
750 // Shape Operators
752
753 // For the following operators, check whether the size of each tensor
754 // operand is valid in a given Level.
755
756 // Tensor Operators
757 CHECK_SIZES(AvgPool2d);
758 CHECK_SIZES(Conv2D);
759 CHECK_SIZES(Conv2DBlockScaled);
760 CHECK_SIZES(Conv3D);
761 CHECK_SIZES(DepthwiseConv2D);
762 CHECK_SIZES(TransposeConv2D);
763 CHECK_SIZES(FFT2d);
764 CHECK_SIZES(MatMul);
765 CHECK_SIZES(MatmulTBlockScaled);
766 CHECK_SIZES(MaxPool2d);
767 CHECK_SIZES(RFFT2d);
768 // Scatter/Gather Operators
770 CHECK_SIZES(Scatter);
771 // Image Operators
772 CHECK_SIZES(Resize);
773 // Custom Operators
774 CHECK_SIZES(Custom);
775 // Control Flow Operators
776 CHECK_SIZES(While);
777 // Shape Operators
778 CHECK_SIZES(ConstShape);
779
780 // For the following operations, check whether the shape length of each
781 // operand is valid given a level.
782
783 // Shape Operators
784 CHECK_SHAPE_LEN(AddShape);
785 CHECK_SHAPE_LEN(AssertEqualShape);
786 CHECK_SHAPE_LEN(ConcatShape);
787 CHECK_SHAPE_LEN(DivCeilShape);
788 CHECK_SHAPE_LEN(DivFloorShape);
789 CHECK_SHAPE_LEN(Exp2Shape);
790 CHECK_SHAPE_LEN(Log2CeilShape);
791 CHECK_SHAPE_LEN(Log2FloorShape);
792 CHECK_SHAPE_LEN(MaxShape);
793 CHECK_SHAPE_LEN(MinShape);
794 CHECK_SHAPE_LEN(ModShape);
795 CHECK_SHAPE_LEN(MulShape);
796 CHECK_SHAPE_LEN(SliceShape);
797 CHECK_SHAPE_LEN(SubShape);
798
799#undef CHECK_RANKS_AND_SIZES
800#undef CHECK_SIZES
801#undef CHECK_SHAPE_LEN
802 return success();
803}
804
805// Perform the Level tensor size check on the tensor type.
806LogicalResult TosaValidation::levelCheckSize(Operation *op,
807 const Type &typeToCheck,
808 const StringRef operandOrResult) {
809 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
810 if (!type.hasRank())
811 return op->emitOpError() << "failed level check: unranked tensor";
812 auto shape = type.getShape();
813 for (auto dim : shape) {
814 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
815 const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
816 const TosaSpecificationVersion minRequiredVersion(1, 1);
817 if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
818 dimIsDynamic)
819 // TOSA 1.1 and above supports dynamic dimensions, however, they must be
820 // resolved at backend compile time. Runtime dynamism is not currently
821 // supported. Checking this requirement is met is delegated to backends.
822 return success();
823
824 // When targeting TOSA 1.0 or below, dynamic dims are not supported
825 if (dimIsDynamic)
826 return op->emitOpError() << "failed level check: " << operandOrResult
827 << " shape dimension cannot be dynamic when"
828 << " targeting TOSA specification version 1.0"
829 << " or below";
830 }
831
832 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
833 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
834 int64_t size = element_bytes * type.getNumElements();
835
836 // According to 1.11. Tensor Definitions of Tosa spec, the value of
837 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
838 // defined in 1.7. Levels.
839 // For each tensor, the number of tensor elements multiplied by the
840 // element size in bytes must be representable as a tensor_size_t.
841 const int64_t max_size =
842 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
843 if (size > max_size)
844 return op->emitOpError()
845 << "failed level check: " << operandOrResult
846 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
847 }
848 return success();
849}
850
851LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
852 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
853 // no need to do level checks
854 return success();
855 }
856
857 // check rank and sizes early so later checks can assume shaped operands
858 if (failed(levelCheckRanksAndSizes(op)))
859 return failure();
860
861 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
862 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
863 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
864 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
865 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
866 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
867 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
868 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
869 failed(levelCheckConv2DBlockScaled(op))) {
870 return failure();
871 }
872
873 // level check MAX_TENSOR_LIST_SIZE
874 if (failed(levelCheckListSize(op))) {
875 return failure();
876 }
877
878 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
879 if (failed(levelCheckMaxNesting(op))) {
880 return failure();
881 }
882 }
883
884 return success();
885}
886
887LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
888 if (failed(attributeCheckRescale(op)))
889 return failure();
890 return success();
891}
892
893inline bool CompatibleTypes(const mlir::Type &type,
894 const mlir::Type &declaredType) {
895 // for now, simply use type equality comparison
896 return type == declaredType;
897}
898
899LogicalResult TosaValidation::CheckVariable(Operation *op) {
900 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
901 mlir::StringAttr nameAttr = variableOp.getNameAttr();
902
903 if (variablesMap.count(nameAttr))
904 return op->emitOpError() << "name has already been declared";
905
906 auto elementType = variableOp.getType();
907 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
908 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
909 RankedTensorType variableType =
910 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
911
912 variablesMap[nameAttr] = variableType;
913 }
914
915 return success();
916}
917
918LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
919 if (isa<mlir::tosa::VariableReadOp>(op) ||
920 isa<mlir::tosa::VariableWriteOp>(op)) {
921 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
922 if (!variablesMap.count(nameAttr))
923 return op->emitOpError() << "name has not been declared";
924
925 auto varType = variablesMap[nameAttr];
926
927 for (auto v : op->getOperands()) {
928 auto type = v.getType();
929 if (!CompatibleTypes(type, varType))
930 return op->emitOpError() << "operand type does not equal variable type";
931 }
932
933 for (auto v : op->getResults()) {
934 auto type = v.getType();
935 if (!CompatibleTypes(type, varType))
936 return op->emitOpError() << "result type does not equal variable type";
937 }
938 }
939
940 return success();
941}
942
943LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
944 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
945 return failure();
946 return success();
947}
948
949LogicalResult checkErrorIfResize(Operation *op) {
950 auto resize = dyn_cast<tosa::ResizeOp>(op);
951 if (!resize)
952 return success();
953
954 const Value input = resize.getInput();
955 const Value output = resize.getOutput();
956 const RankedTensorType inputType =
957 llvm::dyn_cast<RankedTensorType>(input.getType());
958 const RankedTensorType outputType =
959 llvm::dyn_cast<RankedTensorType>(output.getType());
960
961 if (!inputType || !outputType)
962 return op->emitOpError("expect ranked input/output tensor");
963
964 // Ensure the image size is supported by GPU APIs and that for integer
965 // implementations, position * stride does not overflow int32_t.
966 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
967 const SmallVector<int64_t, 4> sizes = {
968 outputType.getDimSize(1), outputType.getDimSize(2),
969 inputType.getDimSize(1), inputType.getDimSize(2)};
970 const int64_t *maxDim = llvm::max_element(sizes);
971 if (maxDim != sizes.end() && *maxDim >= 16384)
972 return op->emitOpError(
973 "expect input/output height/width dims to be < 16384, ")
974 << "got [OH, OW, IH, IW] = " << sizes;
975 }
976
977 SmallVector<int64_t> scale;
978 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
979 return failure();
980
981 const int64_t scaleYN = scale[0];
982 const int64_t scaleYD = scale[1];
983 const int64_t scaleXN = scale[2];
984 const int64_t scaleXD = scale[3];
985
986 // Ensure scale values don't overflow int32 accumulator
987 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
988 return op->emitOpError(
989 "expect all scale numerator values to be <= (1 << 11), "
990 "got scale_y_n=")
991 << scaleYN << ", scale_x_n=" << scaleXN;
992
993 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
994 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
995 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
996
997 SmallVector<int64_t> offset;
998 SmallVector<int64_t> border;
999 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
1000 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
1001 return failure();
1002
1003 const int64_t offsetY = offset[0];
1004 const int64_t offsetX = offset[1];
1005 // Set a consistent lower limit of 1/16 downscale to simplify
1006 // implementations
1007 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1008 return op->emitOpError(
1009 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1010 << offsetY << "/" << scaleYN;
1011 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1012 return op->emitOpError(
1013 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1014 << offsetX << "/" << scaleXN;
1015
1016 const int64_t borderY = border[0];
1017 const int64_t borderX = border[1];
1018 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1019 return op->emitOpError(
1020 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1021 << borderY << "/" << scaleYN;
1022 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1023 return op->emitOpError(
1024 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1025 << borderX << "/" << scaleXN;
1026
1027 // The following section of code is mostly duplicated with ResizeOp::verify().
1028 //
1029 // In TOSA specification, we do not support broadcast behavior.
1030 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
1031 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
1032 // existing code, we keep the rewrite pattern untouched. So, we need
1033 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
1034 //
1035 // Here is a strict checking to conform TOSA specification.
1036 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
1037 auto idivCheck = [](const int64_t lhs,
1038 const int64_t rhs) -> std::optional<int64_t> {
1039 if (lhs % rhs != 0)
1040 return std::nullopt;
1041 return lhs / rhs;
1042 };
1043
1044 const int64_t oh = outputType.getDimSize(1);
1045 const int64_t ow = outputType.getDimSize(2);
1046 const int64_t ih = inputType.getDimSize(1);
1047 const int64_t iw = inputType.getDimSize(2);
1048
1049 if (ih != ShapedType::kDynamic) {
1050 const std::optional<int64_t> calculatedOutHeightMinusOne =
1051 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1052 if (!calculatedOutHeightMinusOne.has_value())
1053 return op->emitOpError(
1054 "expected (input_height - 1) * scale_y_n - offset_y + "
1055 "border_y ")
1056 << "to be wholly divisible by scale_y_d, got ((" << ih
1057 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1058 << ") / " << scaleYD;
1059 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1060 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1061 return op->emitOpError(
1062 "calculated output height did not match expected: ")
1063 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1064 }
1065
1066 if (iw != ShapedType::kDynamic) {
1067 const std::optional<int64_t> calculatedOutWidthMinusOne =
1068 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1069 if (!calculatedOutWidthMinusOne.has_value())
1070 return op->emitOpError(
1071 "expected (input_width - 1) * scale_x_n - offset_x + "
1072 "border_x ")
1073 << "to be wholly divisible by scale_x_d, got ((" << iw
1074 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1075 << ") / " << scaleXD;
1076 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1077 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1078 return op->emitOpError("calculated output width did not match expected: ")
1079 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1080 }
1081
1082 return success();
1083}
1084
1085LogicalResult checkErrorIfMul(Operation *op) {
1086 auto mul = dyn_cast<tosa::MulOp>(op);
1087 if (!mul)
1088 return success();
1089
1090 // REQUIRE(0 <= shift && shift <= 63);
1091 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1092 ElementsAttr shift_elem;
1093 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
1094 return success();
1095 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1096 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1097 if (inputElemType.isInteger(32)) {
1098 // 0 <= shift <= 63 for int32_t type
1099 if (shift < 0 || shift > 63)
1100 return op->emitOpError()
1101 << "requires 0 <= shift && shift <= 63, but got: " << shift;
1102 } else {
1103 // shift must be 0 for all other types
1104 if (shift != 0)
1105 return op->emitOpError()
1106 << "requires shift = 0 for all input data types that "
1107 "are not int32_t, but got: "
1108 << shift;
1109 }
1110
1111 return success();
1112}
1113
1114LogicalResult checkErrorIfTable(Operation *op) {
1115 auto table = dyn_cast<tosa::TableOp>(op);
1116 if (!table)
1117 return success();
1118
1119 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1120 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1121 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1122
1123 const ShapeAdaptor tableShape(table.getTable().getType());
1124 if (tableShape.hasStaticShape()) {
1125 const auto numElements = tableShape.getNumElements();
1126 if (numElements != tableSize)
1127 return op->emitOpError() << "requires table size of " << tableSize
1128 << ", got " << numElements;
1129 }
1130
1131 return success();
1132}
1133
1134LogicalResult checkErrorIfRescale(Operation *op) {
1135 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1136 if (!rescale)
1137 return success();
1138
1139 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1140 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1141 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1142 !outputType.getElementType().isInteger())
1143 return success();
1144
1145 auto inElemType = inputType.getElementType();
1146 auto outElemType = outputType.getElementType();
1147 auto inWidth = inElemType.getIntOrFloatBitWidth();
1148 auto outWidth = outElemType.getIntOrFloatBitWidth();
1149
1150 bool inputUnsigned = rescale.getInputUnsigned();
1151 bool outputUnsigned = rescale.getOutputUnsigned();
1152
1153 bool scale32 = rescale.getScale32();
1154 auto roundingMode = rescale.getRoundingMode();
1155
1156 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1157 if (scale32 && inWidth == 48)
1158 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1159
1160 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1161 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1162 return op->emitOpError()
1163 << "DOUBLE_ROUND is only allowed with scale32=true.";
1164
1165 // ERROR_IF(input_unsigned && output_unsigned)
1166 if (inputUnsigned && outputUnsigned)
1167 return op->emitOpError() << "input and output cannot be both unsigned.";
1168
1169 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1170 if (outWidth == 32 && inputUnsigned)
1171 return op->emitOpError()
1172 << "i32 output type is not allowed with unsigned input.";
1173
1174 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1175 if (inWidth == 32 && outputUnsigned)
1176 return op->emitOpError()
1177 << "i32 input type is not allowed with unsigned output.";
1178
1179 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1180 if (inWidth == 48 && outputUnsigned)
1181 return op->emitOpError()
1182 << "i48 input type is not allowed with unsigned output.";
1183
1184 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1185 if (inWidth == 48 && inputUnsigned)
1186 return op->emitOpError() << "i48 input type cannot be unsigned.";
1187
1188 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1189 if (inWidth == 32 && inputUnsigned)
1190 return op->emitOpError() << "i32 input type cannot be unsigned.";
1191
1192 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1193 if (outWidth == 32 && outputUnsigned)
1194 return op->emitOpError() << "i32 output type cannot be unsigned.";
1195
1196 return success();
1197}
1198
1199LogicalResult checkErrorIfPad(Operation *op) {
1200 auto pad = dyn_cast<tosa::PadOp>(op);
1201 if (!pad)
1202 return success();
1203
1204 DenseIntElementsAttr paddingAttr;
1205 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1206 // Pad verifier will catch this
1207 return success();
1208
1209 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1210 if (val.getSExtValue() < 0)
1211 return op->emitOpError() << "padding value must all be non-negative, got "
1212 << val.getSExtValue();
1213 }
1214
1215 return success();
1216}
1217
1218LogicalResult checkErrorIfReshape(Operation *op) {
1219 auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
1220 if (!reshapeOp)
1221 return success();
1222
1223 SmallVector<int64_t> shapeValues;
1224 if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
1225 shapeValues))
1226 return success();
1227
1228 if (llvm::is_contained(shapeValues, kInferableDimSize))
1229 return op->emitOpError("shape input contains inferable dimension (")
1231 << ") "
1232 "which does not conform to the TOSA specification";
1233
1234 return success();
1235}
1236
1237LogicalResult checkErrorIfSlice(Operation *op) {
1238 auto sliceOp = dyn_cast<tosa::SliceOp>(op);
1239 if (!sliceOp)
1240 return success();
1241
1242 SmallVector<int64_t> startValues;
1243 SmallVector<int64_t> sizeValues;
1244 const bool hasStartValues = tosa::getConstShapeValues(
1245 sliceOp.getStart().getDefiningOp(), startValues);
1246 const bool hasSizeValues =
1247 tosa::getConstShapeValues(sliceOp.getSize().getDefiningOp(), sizeValues);
1248
1249 if (hasStartValues && llvm::is_contained(startValues, kInferableDimSize))
1250 return op->emitOpError("start input contains inferable dimension (")
1252 << ") which does not conform to the TOSA specification";
1253 if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDimSize))
1254 return op->emitOpError("size input contains inferable dimension (")
1256 << ") which "
1257 "does not conform to the TOSA specification";
1258
1259 return success();
1260}
1261
1262static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1263 return llvm::all_of(op->getOperands(), [&](auto operand) {
1264 Region *operandRegion = operand.getParentRegion();
1265 return operandRegion && region->isAncestor(operandRegion);
1266 });
1267}
1268
1269static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1270 bool noLiveInValue = true;
1271 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1272 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1273 noLiveInValue = false;
1274 return WalkResult::interrupt();
1275 }
1276 return WalkResult::advance();
1277 });
1278 return noLiveInValue ? success() : failure();
1279}
1280
1281LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1282 StringRef regionName) {
1283 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1284 return success();
1285 return op->emitOpError()
1286 << "is not conformant to the TOSA specification. It requires the '"
1287 << regionName << "' region is isolated from above.\n";
1288}
1289
1290LogicalResult checkErrorIfCondIf(Operation *op) {
1291 auto ifOp = dyn_cast<tosa::IfOp>(op);
1292 if (!ifOp)
1293 return success();
1294
1295 // Currently the dialect supports declaring cond_if operations that
1296 // have then/else regions that reference values from outside these
1297 // regions. According to the specification, all values used by the
1298 // then/else regions must be explicitly declared within the regions.
1299 // Therefore we must check that the then/else regions are
1300 // "isolated from above", in order to be conformant to the
1301 // specification.
1302 //
1303 // Note: the dialect currently supports two styles of syntax for
1304 // declaring "cond_if" operations. We'll refer to these as follows:
1305 //
1306 // Generic:
1307 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1308 // ^bb0(%arg3, %arg4):
1309 // tosa.yield %arg3
1310 // }, {
1311 // ^bb0(%arg3, %arg4):
1312 // tosa.yield %arg4
1313 // })
1314 //
1315 // Simplified:
1316 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1317 // ^bb0(%arg3, %arg4):
1318 // tosa.yield %arg3
1319 // } else {
1320 // ^bb0(%arg3, %arg4):
1321 // tosa.yield %arg4
1322 // }
1323
1324 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1325 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1326 return failure();
1327 return success();
1328}
1329
1330LogicalResult checkErrorIfWhileLoop(Operation *op) {
1331 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1332 if (!whileOp)
1333 return success();
1334
1335 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1336 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1337 return failure();
1338 return success();
1339}
1340
1341LogicalResult checkErrorIfScatter(Operation *op) {
1342 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1343 if (!scatterOp)
1344 return success();
1345
1346 // for constant indices, check that there are no duplicate values
1347 DenseIntElementsAttr indicesAttr;
1348 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1349 return success();
1350
1351 auto const indicesType =
1352 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1353 if (!indicesType || !indicesType.hasRank()) {
1354 op->emitOpError("expect ranked indices tensor");
1355 return failure();
1356 }
1357
1358 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1359 op->emitOpError("indices values contain duplicates");
1360 return failure();
1361 }
1362
1363 return success();
1364}
1365
1366LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1367 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1368 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1369 failed(checkErrorIfPad(op)) || failed(checkErrorIfReshape(op)) ||
1370 failed(checkErrorIfSlice(op)) || failed(checkErrorIfCondIf(op)) ||
1371 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1372 return failure();
1373 return success();
1374}
1375
1376bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1377 if (isa<FloatType>(type)) {
1378 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1379 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1380 Float6E3M2FNType, Float8E8M0FNUType>(type);
1381 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1382 if (intTy.isSignless()) {
1383 switch (intTy.getWidth()) {
1384 case 1:
1385 case 4:
1386 case 8:
1387 case 16:
1388 case 32:
1389 case 48:
1390 case 64:
1391 return true;
1392 }
1393 } else if (allowUnsigned && intTy.isUnsigned()) {
1394 switch (intTy.getWidth()) {
1395 case 8:
1396 case 16:
1397 case 32:
1398 return true;
1399 }
1400 }
1401 } else if (isa<tosa::shapeType>(type))
1402 return true;
1403 else if (isa<tosa::mxint8Type>(type))
1404 return true;
1405 return false;
1406}
1407
1408void TosaValidation::runOnOperation() {
1409 ModuleOp modOp = getOperation();
1410 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1411 if (!tosaDialect)
1412 return;
1413
1414 const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
1415 const auto maybeTargetEnv =
1416 tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
1417 if (failed(maybeTargetEnv))
1418 return signalPassFailure();
1419 targetEnv = *maybeTargetEnv;
1420
1421 modOp.walk([&](Operation *op) {
1422 if (op->getDialect() != tosaDialect)
1423 return;
1424
1425 // validate operator element types:
1426 // - rescale operator is allowed to have ui8/ui16/ui32
1427 // operands/results when strictOpSpecAlignment is false
1428 // - perform valid element type check at the beginning to
1429 // protect rest of code against quantized element types
1430 const bool allowUnsigned =
1431 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1432 for (Value operand : op->getOperands()) {
1433 auto elementTy = getElementTypeOrSelf(operand);
1434 if (!isValidElementType(elementTy, allowUnsigned)) {
1435 op->emitOpError() << "is not profile-aligned: element type "
1436 << elementTy << " is not legal";
1437 return signalPassFailure();
1438 }
1439 }
1440 for (Type resultTy : op->getResultTypes()) {
1441 auto elementTy = getElementTypeOrSelf(resultTy);
1442 if (!isValidElementType(elementTy, allowUnsigned)) {
1443 op->emitOpError() << "is not profile-aligned: element type "
1444 << elementTy << " is not legal";
1445 return signalPassFailure();
1446 }
1447 }
1448
1449 if (strictOpSpecAlignment &&
1450 failed(profileComp.checkProfile(op, targetEnv)))
1451 return signalPassFailure();
1452
1453 if (strictOpSpecAlignment &&
1454 failed(profileComp.checkExtension(op, targetEnv)))
1455 return signalPassFailure();
1456
1457 if (!allowInvalidOpDatatypeCombinations &&
1458 failed(profileComp.checkInvalid(op)))
1459 return signalPassFailure();
1460
1461 // Some uses of TOSA rely on the constant operands of particular
1462 // operations.
1463 if (failed(applyConstantOperandCheck(op)))
1464 signalPassFailure();
1465
1466 // do level checks
1467 if (failed(applyLevelCheck(op)))
1468 signalPassFailure();
1469
1470 // check additional attribute restrictions
1471 if (failed(applyAttributeCheck(op)))
1472 signalPassFailure();
1473
1474 // do variable type checks
1475 if (failed(applyVariableCheck(op)))
1476 signalPassFailure();
1477
1478 // do error if checks
1479 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1480 signalPassFailure();
1481 });
1482}
1483} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:568
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:241
Value getOperand(unsigned idx)
Definition Operation.h:379
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:563
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:296
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:100
TosaLevel getLevel() const
Definition TargetEnv.h:117
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:91
bool allows(Profile prof) const
Definition TargetEnv.h:127
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:113
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:67
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:45
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:620
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369