MLIR 23.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Validate if TOSA dialect input matches with the specification for given
10// requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19#include <type_traits>
20
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/Matchers.h"
28#include "mlir/Pass/Pass.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/Support/FormatVariadic.h"
33
34namespace mlir {
35namespace tosa {
36#define GEN_PASS_DEF_TOSAVALIDATION
37#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
38} // namespace tosa
39} // namespace mlir
40
41using namespace mlir;
42using namespace mlir::tosa;
43
44namespace {
45
46static LogicalResult
47checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
48 for (const auto index : operandIndices) {
49 Attribute attr;
50 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
51 return op->emitOpError("expected compile time resolvable constant, but "
52 "got variable value for operand #")
53 << index;
54 }
55 }
56 return success();
57}
58
59static LogicalResult checkConstantOperandMul(Operation *op,
60 const TargetEnv &env) {
61 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
62 // Check 'shift'
63 return checkConstantOperands(op, {2});
64 }
65 return success();
66}
67
68static LogicalResult checkConstantOperandTable(Operation *op,
69 const TargetEnv &env) {
70 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
71 // Check 'table'
72 return checkConstantOperands(op, {1});
73 }
74 return success();
75}
76
77static LogicalResult checkConstantOperandPad(Operation *op,
78 const TargetEnv &env) {
79 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
80 // Assume this op is zero-padding if padConst is not presented
81 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
82 // Check 'pad_const'
83 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
84 return checkConstantOperands(op, {2});
85 }
86 return success();
87}
88
89static LogicalResult checkConstantOperandRescale(Operation *op,
90 const TargetEnv &env) {
91 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
92 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
93 return checkConstantOperands(op, {1, 2, 3, 4});
94 }
95 return success();
96}
97
98template <typename T>
99static LogicalResult checkConstantOperandConvOps(Operation *op,
100 const TargetEnv &env) {
101 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
102 // Check 'input_zp' and 'weight_zp'
103 return checkConstantOperands(op, {3, 4});
104 }
105 return success();
106}
107
108static LogicalResult checkConstantOperandMatMul(Operation *op,
109 const TargetEnv &env) {
110 if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
111 // Check 'A_zp' and 'B_zp'
112 return checkConstantOperands(op, {2, 3});
113 }
114 return success();
115}
116
117static LogicalResult
118checkConstantOperandRowGatherBlockScaled(Operation *op, const TargetEnv &env) {
119 if (!env.allows(Extension::dynamic) &&
120 isa<tosa::RowGatherBlockScaledOp>(op)) {
121 auto rowGatherOp = cast<tosa::RowGatherBlockScaledOp>(op);
122 const unsigned rowCountIndex = rowGatherOp.getValues().size() + 1;
123 return checkConstantOperands(op, {rowCountIndex});
124 }
125 return success();
126}
127
128static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
129 const TargetEnv &env) {
130 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
131 // Check 'input_zp' and 'output_zp'
132 return checkConstantOperands(op, {1, 2});
133 }
134 return success();
135}
136
137static LogicalResult
138checkConstantOperandAvgPool2dAdaptive(Operation *op, const TargetEnv &env) {
139 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dAdaptiveOp>(op)) {
140 // Check 'input_zp' and 'output_zp'.
141 // Note: 'kernel', 'stride', and 'pad' (operands 3, 4, 5) are not checked
142 // as they are tosa.shape types.
143 return checkConstantOperands(op, {1, 2});
144 }
145 return success();
146}
147
148static LogicalResult checkConstantOperandNegate(Operation *op,
149 const TargetEnv &env) {
150 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
151 // Check 'input1_zp' and 'output_zp'
152 return checkConstantOperands(op, {1, 2});
153 }
154 return success();
155}
156
157static LogicalResult checkConstantOperandSilceShape(Operation *op,
158 const TargetEnv &env) {
159 if (!env.allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
160 // Check 'start' and 'size'
161 return checkConstantOperands(op, {1, 2});
162 }
163 return success();
164}
165
166//===----------------------------------------------------------------------===//
167// TOSA Validation Pass.
168//===----------------------------------------------------------------------===//
169
/// Pass that validates TOSA operations against a target environment:
/// constant-operand requirements, level limits (ranks, tensor sizes,
/// kernel/stride/pad/scale bounds, list sizes, nesting depth), attribute
/// constraints, variable declarations, and ERROR_IF conditions.
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }

  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->strictOpSpecAlignment = options.strictOpSpecAlignment;
    this->allowInvalidOpDatatypeCombinations =
        options.allowInvalidOpDatatypeCombinations;
  }
  void runOnOperation() final;

  // Run every registered constant-operand checker against `op`; fails on the
  // first checker that reports a violation.
  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op, targetEnv)))
        return failure();
    }
    return success();
  }

  LogicalResult applyFunctionSignatureCheck(func::FuncOp op);
  LogicalResult applyLevelCheck(Operation *op);
  LogicalResult applyAttributeCheck(Operation *op);

  // check variable read/write data types against variable declarations
  LogicalResult applyVariableCheck(Operation *op);

  // check error if conditions
  LogicalResult applyErrorIfCheck(Operation *op);

private:
  // Register the per-op constant-operand checkers run by
  // applyConstantOperandCheck(). Order determines check order.
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandMul);
    constCheckers.emplace_back(checkConstantOperandTable);
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandRescale);
    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
    constCheckers.emplace_back(
        checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
    constCheckers.emplace_back(
        checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
    constCheckers.emplace_back(checkConstantOperandMatMul);
    constCheckers.emplace_back(checkConstantOperandRowGatherBlockScaled);
    constCheckers.emplace_back(checkConstantOperandAvgPool2d);
    constCheckers.emplace_back(checkConstantOperandAvgPool2dAdaptive);
    constCheckers.emplace_back(checkConstantOperandNegate);
    constCheckers.emplace_back(checkConstantOperandSilceShape);
  }

  // Shared helper: emit a diagnostic and fail when `calculatedValue` exceeds
  // the level limit `maxLevel`; `inputName`/`levelName` name the operand and
  // the limit in the message.
  LogicalResult levelCheck(Operation *op, const int32_t calculatedValue,
                           const int32_t maxLevel, const StringRef inputName,
                           const StringRef levelName) {
    if (calculatedValue > maxLevel)
      return op->emitOpError()
             << "failed level check: " << inputName << " <= " << levelName
             << " (" << maxLevel << "), got " << calculatedValue;
    return success();
  }

  // Convenience wrappers around levelCheck() for the individual level limits.
  LogicalResult levelCheckKernel(Operation *op, int32_t v,
                                 const StringRef inputName) {
    return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
                      "MAX_KERNEL");
  }

  LogicalResult levelCheckStride(Operation *op, int32_t v,
                                 const StringRef inputName) {
    return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
                      "MAX_STRIDE");
  }

  LogicalResult levelCheckScale(Operation *op, int32_t v,
                                const StringRef inputName) {
    return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
                      "MAX_SCALE");
  }

  LogicalResult levelCheckListSize(Operation *op, int32_t v,
                                   const StringRef inputName) {
    const std::string inputDesc =
        llvm::formatv("length(tensor_list_shape({0}))", inputName);
    return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
                      inputDesc, "MAX_TENSOR_LIST_SIZE");
  }

  // Perform the Level Rank check on the tensor type.
  // Non-shaped types pass trivially; unranked tensors always fail.
  LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
                               const StringRef operandOrResult,
                               int32_t highest_rank) {
    if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
      if (!type.hasRank())
        return op->emitOpError() << "failed level check: unranked tensor";
      if (type.getRank() > highest_rank)
        return op->emitOpError() << "failed level check: " << operandOrResult
                                 << " rank(shape) <= MAX_RANK";
    }
    return success();
  }

  // Perform the Level Rank check on the tensor value.
  LogicalResult levelCheckRank(Operation *op, const Value &v,
                               const StringRef operandOrResult,
                               int32_t highest_rank) {
    return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
  }

  // Perform the Level tensor size check on the tensor type.
  LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
                               const StringRef operandOrResult);

  // Perform the Level tensor size check on the tensor value.
  LogicalResult levelCheckSize(Operation *op, const Value &v,
                               const StringRef operandOrResult) {
    return levelCheckSize(op, v.getType(), operandOrResult);
  }

  // Perform the Level shape length check on a value.
  // Only applies to tosa.shape-typed values; others pass trivially.
  LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
                                      const StringRef operandOrResult) {
    if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
      if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
        return op->emitOpError()
               << "failed shape type level check: " << typeToCheck
               << " exceeds MAX_SHAPE_LEN";
    }
    return success();
  }

  // Level check sizes of all operands and results of the operation.
  template <typename T>
  LogicalResult levelCheckSizes(T tosaOp) {
    auto op = tosaOp.getOperation();
    for (auto v : op->getOperands()) {
      if (failed(levelCheckSize(op, v, "operand")))
        return failure();
    }

    for (auto v : op->getResults()) {
      if (failed(levelCheckSize(op, v, "result")))
        return failure();
    }
    return success();
  }

  // Level check ranks of all operands, attribute and results of the operation.
  // Specialized below for ops with non-uniform rank rules (ArgMax, If,
  // Variable).
  template <typename T>
  LogicalResult levelCheckRanks(T tosaOp) {
    auto op = tosaOp.getOperation();
    const TosaLevel tosaLevel = targetEnv.getLevel();
    for (auto v : op->getOperands()) {
      if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
        return failure();
    }

    for (auto v : op->getResults()) {
      if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
        return failure();
    }
    return success();
  }
  // Level check shape lengths of all operands and results of an operation that
  // are tosa.shape type.
  template <typename T>
  LogicalResult levelCheckShapeLengths(T tosaOp) {
    for (const auto &v : tosaOp->getOperands()) {
      if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
        return failure();
    }
    for (const auto &v : tosaOp->getResults()) {
      if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
        return failure();
    }

    return success();
  }

  // Level check ranks and sizes.
  LogicalResult levelCheckRanksAndSizes(Operation *op);

  // Pool Op: level check kernel/stride/pad values
  template <typename T>
  LogicalResult levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (failed(levelCheckKernel(op, k, "kernel"))) {
          return failure();
        }
      }
      for (auto s : poolOp.getStride()) {
        if (failed(levelCheckStride(op, s, "stride"))) {
          return failure();
        }
      }
      // Note: pad values are bounded by MAX_KERNEL, hence levelCheckKernel.
      for (auto p : poolOp.getPad()) {
        if (failed(levelCheckKernel(op, p, "pad"))) {
          return failure();
        }
      }
    }
    return success();
  }

  // True only for the adaptive pooling ops, which carry kernel/stride/pad as
  // tosa.shape operands rather than attributes.
  template <typename T>
  static constexpr bool IsSupportedAdaptivePoolOp =
      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
      std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;

  // Adaptive Pool Op: level check kernel/stride/pad values when they are
  // resolvable to compile-time shape constants; dynamic values are skipped.
  template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
                                                int>::type = 0>
  LogicalResult levelCheckAdaptivePool(Operation *op) {
    auto poolOp = dyn_cast<T>(op);
    if (!poolOp)
      return success();

    SmallVector<int64_t> kernelValues;
    if (tosa::getConstShapeValues(poolOp.getKernel().getDefiningOp(),
                                  kernelValues)) {
      for (const auto k : kernelValues)
        if (failed(levelCheckKernel(op, k, "kernel")))
          return failure();
    }

    SmallVector<int64_t> strideValues;
    if (tosa::getConstShapeValues(poolOp.getStride().getDefiningOp(),
                                  strideValues)) {
      for (const auto s : strideValues)
        if (failed(levelCheckStride(op, s, "stride")))
          return failure();
    }

    SmallVector<int64_t> padValues;
    if (tosa::getConstShapeValues(poolOp.getPad().getDefiningOp(), padValues)) {
      for (const auto p : padValues)
        if (failed(levelCheckKernel(op, p, "pad")))
          return failure();
    }

    return success();
  }

  // Conv Op: level check dilation/stride/pad values
  template <typename T>
  LogicalResult levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {

      for (auto k : convOp.getDilation()) {
        if (failed(levelCheckKernel(op, k, "dilation"))) {
          return failure();
        }
      }
      for (auto p : convOp.getPad()) {
        if (failed(levelCheckKernel(op, p, "pad"))) {
          return failure();
        }
      }
      for (auto s : convOp.getStride()) {
        if (failed(levelCheckStride(op, s, "stride"))) {
          return failure();
        }
      }
      // Also bound dilation * kernel extent; the kernel dims' position in the
      // weight shape differs per conv flavor.
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(op->getOperand(1).getType())) {
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
                                      "dilation_y * KH")) ||
              failed(levelCheckKernel(op, dilation[1] * shape[2],
                                      "dilation_x * KW")))
            return failure();
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
                                      "dilation_d * KD")) ||
              failed(levelCheckKernel(op, dilation[1] * shape[2],
                                      "dilation_y * KH")) ||
              failed(levelCheckKernel(op, dilation[2] * shape[3],
                                      "dilation_x * KW")))
            return failure();
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (failed(levelCheckKernel(op, dilation[0] * shape[0],
                                      "dilation_y * KH")) ||
              failed(levelCheckKernel(op, dilation[1] * shape[1],
                                      "dilation_x * KW")))
            return failure();
        }
      }
    }
    return success();
  }

  // conv2d_block_scaled: level check pad/stride and dilation * kernel extent.
  // Kernel dims are taken from weight_data, falling back to weight_scale when
  // dynamic; fully dynamic kernel dims are skipped.
  LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
    auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
    if (!convOp)
      return success();

    SmallVector<int64_t> padValues;
    if (tosa::getConstShapeValues(convOp.getPad().getDefiningOp(), padValues)) {
      for (const auto p : padValues)
        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
          return failure();
    }

    SmallVector<int64_t> strideValues;
    if (tosa::getConstShapeValues(convOp.getStride().getDefiningOp(),
                                  strideValues)) {
      for (const auto s : strideValues)
        if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
          return failure();
    }

    SmallVector<int64_t> dilationValues;
    if (tosa::getConstShapeValues(convOp.getDilation().getDefiningOp(),
                                  dilationValues)) {
      int64_t KH = ShapedType::kDynamic;
      int64_t KW = ShapedType::kDynamic;
      const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
      KH = weightDataShape.getDimSize(1);
      KW = weightDataShape.getDimSize(2);
      const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
      KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
      KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;

      // NOTE(review): the trailing ')' in the two messages below looks like a
      // typo in the diagnostic text — confirm before changing.
      if (!ShapedType::isDynamic(KH) &&
          failed(levelCheckKernel(op, dilationValues[0] * KH,
                                  "dilation_y * KH <= MAX_KERNEL)")))
        return failure();

      if (!ShapedType::isDynamic(KW) &&
          failed(levelCheckKernel(op, dilationValues[1] * KW,
                                  "dilation_x * KW <= MAX_KERNEL)")))
        return failure();
    }

    return success();
  }

  // FFT op: level check H, W in input shape [N,H,W]
  template <typename T>
  LogicalResult levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (failed(levelCheckKernel(op, shape[1], "H")) ||
              failed(levelCheckKernel(op, shape[2], "W"))) {
            return failure();
          }
        }
      }
    }
    return success();
  }

  // TransposeConv2d op: level check kH/kW, outpad, and stride
  LogicalResult levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getWeight().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // level check kernel sizes for kH and KW
        if (failed(levelCheckKernel(op, shape[1], "KH")) ||
            failed(levelCheckKernel(op, shape[2], "KW"))) {
          return failure();
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (failed(levelCheckKernel(op, p, "pad"))) {
          return failure();
        }
      }
      for (auto s : transpose.getStride()) {
        if (failed(levelCheckStride(op, s, "stride"))) {
          return failure();
        }
      }
    }
    return success();
  }

  // Resize op: level check max scales
  LogicalResult levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      SmallVector<int64_t> scale;
      if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                     scale)) {
        return failure();
      }
      const int64_t scaleYN = scale[0];
      const int64_t scaleYD = scale[1];
      const int64_t scaleXN = scale[2];
      const int64_t scaleXD = scale[3];
      if (failed(
              levelCheckScale(op, scaleYN / scaleYD, "scale_y_n/scale_y_d")) ||
          failed(
              levelCheckScale(op, scaleXN / scaleXD, "scale_x_n/scale_x_d"))) {
        return failure();
      }
    }
    return success();
  }

  // Recursively perform a bottom-up search to determine the maximum nesting
  // depth, starting from a specific operation and continuing up to the function
  // or module scope. Tosa nesting_depth starts at 0 and increments by one each
  // time a new nested `region` is encountered.
  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
    if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
      return;

    op = op->getParentOp();
    if (!op)
      return;

    depth++;
    getMaxNestedDepth(op, depth);
  }

  LogicalResult levelCheckMaxNesting(Operation *op) {
    int32_t maxNestedDepth = 0;
    getMaxNestedDepth(op, maxNestedDepth);

    const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
    if (maxNestedDepth >= maxNestingLevel)
      return op->emitOpError()
             << "failed level check: tosa_nesting_depth < MAX_NESTING" << " ("
             << maxNestingLevel << "), got " << maxNestedDepth;
    return success();
  }

  // Level check MAX_TENSOR_LIST_SIZE for ops with tensor-list operands or
  // results (concat, custom, cond_if, while, concat_shape).
  LogicalResult levelCheckListSize(Operation *op) {
    if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
      return levelCheckListSize(op, concat.getInput1().size(), "input1");
    }
    if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
      if (failed(levelCheckListSize(op, custom.getInputList().size(),
                                    "input_list")) ||
          failed(levelCheckListSize(op, custom.getOutputList().size(),
                                    "output_list"))) {
        return failure();
      }
    }
    if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
      if (failed(
              levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
          failed(levelCheckListSize(op, condIf.getOutputList().size(),
                                    "outputs"))) {
        return failure();
      }
    }
    if (auto w = dyn_cast<tosa::WhileOp>(op)) {
      if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
          failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
        return failure();
      }
    }
    if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
      return levelCheckListSize(op, concat_shape.getInput().size(), "input");
    return success();
  }

  // tosa.rescale: the DOUBLE_ROUND / INEXACT_ROUND rounding modes each
  // require their corresponding extension to be enabled.
  LogicalResult attributeCheckRescale(Operation *op) {
    if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
      if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
          !targetEnv.allows(Extension::doubleround)) {
        op->emitOpError()
            << "failed attribute check: rounding_mode = DOUBLE_ROUND "
            << "requires extension [doubleround]";
        return failure();
      }
      if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
          !targetEnv.allows(Extension::inexactround)) {
        op->emitOpError()
            << "failed attribute check: rounding_mode = INEXACT_ROUND "
            << "requires extension [inexactround]";
        return failure();
      }
    }
    return success();
  }

  LogicalResult CheckVariable(Operation *op);
  LogicalResult CheckVariableReadOrWrite(Operation *op);
  bool isValidElementType(Type type, const bool allowUnsigned = false);

  // Constant-operand checkers registered by populateConstantOperandChecks().
  SmallVector<
      std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
      constCheckers;
  // NOTE(review): the variablesMap member used by CheckVariable /
  // CheckVariableReadOrWrite is declared elsewhere in this class — confirm.
  TosaProfileCompliance profileComp;
  tosa::TargetEnv targetEnv;
};
670
671template <>
672LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
673 auto *op = tosaOp.getOperation();
674 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
675 targetEnv.getLevel().MAX_RANK)))
676 return failure();
677
678 // rank(output) = rank(input) - 1
679 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
680 targetEnv.getLevel().MAX_RANK - 1)))
681 return failure();
682
683 return success();
684}
685
686template <>
687LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
688 auto *op = tosaOp.getOperation();
689
690 // Only the condition input has rank limitation.
691 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
692 targetEnv.getLevel().MAX_RANK)))
693 return failure();
694
695 return success();
696}
697
698template <>
699LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
700 auto *op = tosaOp.getOperation();
701 auto variableType = getVariableType(tosaOp);
702 if (failed(levelCheckRank(op, variableType, "variable type",
703 targetEnv.getLevel().MAX_RANK)))
704 return failure();
705
706 return success();
707}
708
709template <>
710LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
711 auto *op = tosaOp.getOperation();
712 auto variableType = getVariableType(tosaOp);
713 if (failed(levelCheckSize(op, variableType, "variable type")))
714 return failure();
715
716 return success();
717}
718
// Dispatch the rank / size / shape-length level checks to the matching TOSA
// op kind. Each CHECK_* macro tests the concrete op type and forwards to the
// corresponding templated checker, failing fast on the first violation.
LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
// Rank check plus tensor-size check for ops whose ranks are level-limited.
#define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op))))                   \
      return failure();                                                        \
    if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op))))                   \
      return failure();                                                        \
  }

// Tensor-size check only, for ops whose ranks are constrained elsewhere.
#define CHECK_SIZES(tosaOp)                                                    \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op))))                   \
      return failure();                                                        \
  }

// Shape-length check for ops operating on tosa.shape values.
#define CHECK_SHAPE_LEN(tosaOp)                                                \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op))))            \
      return failure();                                                        \
  }

  // Tensor Operators
  CHECK_RANKS_AND_SIZES(ArgMax);
  // Activation Functions
  CHECK_RANKS_AND_SIZES(Sigmoid);
  // Elementwise Binary Operators
  CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
  CHECK_RANKS_AND_SIZES(BitwiseAnd);
  CHECK_RANKS_AND_SIZES(BitwiseOr);
  CHECK_RANKS_AND_SIZES(BitwiseXor);
  CHECK_RANKS_AND_SIZES(IntDiv);
  CHECK_RANKS_AND_SIZES(LogicalAnd);
  CHECK_RANKS_AND_SIZES(LogicalLeftShift);
  CHECK_RANKS_AND_SIZES(LogicalRightShift);
  CHECK_RANKS_AND_SIZES(LogicalOr);
  CHECK_RANKS_AND_SIZES(LogicalXor);
  CHECK_RANKS_AND_SIZES(Maximum);
  CHECK_RANKS_AND_SIZES(Minimum);
  // Elementwise Unary Operators
  CHECK_RANKS_AND_SIZES(BitwiseNot);
  CHECK_RANKS_AND_SIZES(LogicalNot);
  CHECK_RANKS_AND_SIZES(Negate);
  CHECK_RANKS_AND_SIZES(Reciprocal);
  // Elementwise Ternary Operators
  CHECK_RANKS_AND_SIZES(Select);
  // Comparison Operators
  CHECK_RANKS_AND_SIZES(Greater);
  CHECK_RANKS_AND_SIZES(GreaterEqual);
  // Reduction Operators
  CHECK_RANKS_AND_SIZES(ReduceAll);
  CHECK_RANKS_AND_SIZES(ReduceAny);
  CHECK_RANKS_AND_SIZES(ReduceMax);
  CHECK_RANKS_AND_SIZES(ReduceMin);
  CHECK_RANKS_AND_SIZES(ReduceProduct);
  CHECK_RANKS_AND_SIZES(ReduceSum);
  // Data Layout Operators
  CHECK_RANKS_AND_SIZES(Concat);
  CHECK_RANKS_AND_SIZES(Reshape);
  CHECK_RANKS_AND_SIZES(ReshapeBlockScaled);
  CHECK_RANKS_AND_SIZES(Reverse);
  CHECK_RANKS_AND_SIZES(Transpose);
  // Type Conversion
  CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
  CHECK_RANKS_AND_SIZES(CastToBlockScaled);
  CHECK_RANKS_AND_SIZES(Rescale);
  // Data Nodes
  CHECK_RANKS_AND_SIZES(Identity);
  // Control Flow Operators
  // Variable Operators
  CHECK_RANKS_AND_SIZES(Variable);
  CHECK_RANKS_AND_SIZES(VariableWrite);
  CHECK_RANKS_AND_SIZES(VariableRead);
  // Shape Operators

  // For the following operators, check whether the size of each tensor
  // operand is valid in a given Level.

  // Tensor Operators
  CHECK_SIZES(AvgPool2d);
  CHECK_SIZES(AvgPool2dAdaptive);
  CHECK_SIZES(Conv2D);
  CHECK_SIZES(Conv2DBlockScaled);
  CHECK_SIZES(Conv3D);
  CHECK_SIZES(DepthwiseConv2D);
  CHECK_SIZES(TransposeConv2D);
  CHECK_SIZES(FFT2d);
  CHECK_SIZES(MatMul);
  CHECK_SIZES(MatmulTBlockScaled);
  CHECK_SIZES(MaxPool2d);
  CHECK_SIZES(MaxPool2dAdaptive);
  CHECK_SIZES(RFFT2d);
  // Scatter/Gather Operators
  CHECK_SIZES(Scatter);
  // Image Operators
  CHECK_SIZES(Resize);
  // Custom Operators
  CHECK_SIZES(Custom);
  // Control Flow Operators
  CHECK_SIZES(While);
  // Shape Operators
  CHECK_SIZES(ConstShape);

  // For the following operations, check whether the shape length of each
  // operand is valid given a level.

  // Shape Operators
  CHECK_SHAPE_LEN(AddShape);
  CHECK_SHAPE_LEN(AssertEqualShape);
  CHECK_SHAPE_LEN(ConcatShape);
  CHECK_SHAPE_LEN(DivCeilShape);
  CHECK_SHAPE_LEN(DivFloorShape);
  CHECK_SHAPE_LEN(Exp2Shape);
  CHECK_SHAPE_LEN(Log2CeilShape);
  CHECK_SHAPE_LEN(Log2FloorShape);
  CHECK_SHAPE_LEN(MaxShape);
  CHECK_SHAPE_LEN(MinShape);
  CHECK_SHAPE_LEN(ModShape);
  CHECK_SHAPE_LEN(MulShape);
  CHECK_SHAPE_LEN(SliceShape);
  CHECK_SHAPE_LEN(SubShape);

#undef CHECK_RANKS_AND_SIZES
#undef CHECK_SIZES
#undef CHECK_SHAPE_LEN
  return success();
}
871
872// Perform the Level tensor size check on the tensor type.
873LogicalResult TosaValidation::levelCheckSize(Operation *op,
874 const Type &typeToCheck,
875 const StringRef operandOrResult) {
876 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
877 if (!type.hasRank())
878 return op->emitOpError() << "failed level check: unranked tensor";
879 auto shape = type.getShape();
880 for (auto dim : shape) {
881 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
882 const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
883 const TosaSpecificationVersion minRequiredVersion(1, 1, true);
884 if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
885 dimIsDynamic)
886 // TOSA 1.1 and above supports dynamic dimensions, however, they must be
887 // resolved at backend compile time. Runtime dynamism is not currently
888 // supported. Checking this requirement is met is delegated to backends.
889 return success();
890
891 // When targeting TOSA 1.0 or below, dynamic dims are not supported
892 if (dimIsDynamic)
893 return op->emitOpError() << "failed level check: " << operandOrResult
894 << " shape dimension cannot be dynamic when"
895 << " targeting TOSA specification version 1.0"
896 << " or below";
897 }
898
899 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
900 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
901 int64_t size = element_bytes * type.getNumElements();
902
903 // According to 1.11. Tensor Definitions of Tosa spec, the value of
904 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
905 // defined in 1.7. Levels.
906 // For each tensor, the number of tensor elements multiplied by the
907 // element size in bytes must be representable as a tensor_size_t.
908 const int64_t max_size =
909 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
910 if (size > max_size)
911 return op->emitOpError()
912 << "failed level check: " << operandOrResult
913 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
914 }
915 return success();
916}
917
918LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
919 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
920 // no need to do level checks
921 return success();
922 }
923
924 // check rank and sizes early so later checks can assume shaped operands
925 if (failed(levelCheckRanksAndSizes(op)))
926 return failure();
927
928 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
929 failed(levelCheckAdaptivePool<tosa::AvgPool2dAdaptiveOp>(op)) ||
930 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
931 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
932 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
933 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
934 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
935 failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
936 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
937 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
938 failed(levelCheckConv2DBlockScaled(op))) {
939 return failure();
940 }
941
942 // level check MAX_TENSOR_LIST_SIZE
943 if (failed(levelCheckListSize(op))) {
944 return failure();
945 }
946
947 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
948 if (failed(levelCheckMaxNesting(op))) {
949 return failure();
950 }
951 }
952
953 return success();
954}
955
956LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
957 if (failed(attributeCheckRescale(op)))
958 return failure();
959 return success();
960}
961
962inline bool CompatibleTypes(const mlir::Type &type,
963 const mlir::Type &declaredType) {
964 // for now, simply use type equality comparison
965 return type == declaredType;
966}
967
968LogicalResult TosaValidation::CheckVariable(Operation *op) {
969 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
970 mlir::StringAttr nameAttr = variableOp.getNameAttr();
971
972 if (variablesMap.count(nameAttr))
973 return op->emitOpError() << "name has already been declared";
974
975 auto elementType = variableOp.getType();
976 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
977 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
978 RankedTensorType variableType =
979 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
980
981 variablesMap[nameAttr] = variableType;
982 }
983
984 return success();
985}
986
987LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
988 if (isa<mlir::tosa::VariableReadOp>(op) ||
989 isa<mlir::tosa::VariableWriteOp>(op)) {
990 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
991 if (!variablesMap.count(nameAttr))
992 return op->emitOpError() << "name has not been declared";
993
994 auto varType = variablesMap[nameAttr];
995
996 for (auto v : op->getOperands()) {
997 auto type = v.getType();
998 if (!CompatibleTypes(type, varType))
999 return op->emitOpError() << "operand type does not equal variable type";
1000 }
1002 for (auto v : op->getResults()) {
1003 auto type = v.getType();
1004 if (!CompatibleTypes(type, varType))
1005 return op->emitOpError() << "result type does not equal variable type";
1007 }
1008
1009 return success();
1010}
1011
1012LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
1013 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
1014 return failure();
1015 return success();
1016}
1018LogicalResult checkErrorIfResize(Operation *op) {
1019 auto resize = dyn_cast<tosa::ResizeOp>(op);
1020 if (!resize)
1021 return success();
1022
1023 const Value input = resize.getInput();
1024 const Value output = resize.getOutput();
1025 const RankedTensorType inputType =
1026 llvm::dyn_cast<RankedTensorType>(input.getType());
1027 const RankedTensorType outputType =
1028 llvm::dyn_cast<RankedTensorType>(output.getType());
1029
1030 if (!inputType || !outputType)
1031 return op->emitOpError("expect ranked input/output tensor");
1032
1033 // Ensure the image size is supported by GPU APIs and that for integer
1034 // implementations, position * stride does not overflow int32_t.
1035 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1036 const SmallVector<int64_t, 4> sizes = {
1037 outputType.getDimSize(1), outputType.getDimSize(2),
1038 inputType.getDimSize(1), inputType.getDimSize(2)};
1039 const int64_t *maxDim = llvm::max_element(sizes);
1040 if (maxDim != sizes.end() && *maxDim >= 16384)
1041 return op->emitOpError(
1042 "expect input/output height/width dims to be < 16384, ")
1043 << "got [OH, OW, IH, IW] = " << sizes;
1044 }
1045
1047 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
1048 return failure();
1049
1050 const int64_t scaleYN = scale[0];
1051 const int64_t scaleYD = scale[1];
1052 const int64_t scaleXN = scale[2];
1053 const int64_t scaleXD = scale[3];
1054
1055 // Ensure scale values don't overflow int32 accumulator
1056 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
1057 return op->emitOpError(
1058 "expect all scale numerator values to be <= (1 << 11), "
1059 "got scale_y_n=")
1060 << scaleYN << ", scale_x_n=" << scaleXN;
1062 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
1063 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
1064 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
1068 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
1069 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
1070 return failure();
1071
1072 const int64_t offsetY = offset[0];
1073 const int64_t offsetX = offset[1];
1074 // Set a consistent lower limit of 1/16 downscale to simplify
1075 // implementations
1076 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1077 return op->emitOpError(
1078 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1079 << offsetY << "/" << scaleYN;
1080 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1081 return op->emitOpError(
1082 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1083 << offsetX << "/" << scaleXN;
1084
1085 const int64_t borderY = border[0];
1086 const int64_t borderX = border[1];
1087 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1088 return op->emitOpError(
1089 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1090 << borderY << "/" << scaleYN;
1091 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1092 return op->emitOpError(
1093 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1094 << borderX << "/" << scaleXN;
1095
1096 // The following section of code is mostly duplicated with ResizeOp::verify().
1097 //
1098 // In TOSA specification, we do not support broadcast behavior.
1099 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
1100 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
1101 // existing code, we keep the rewrite pattern untouched. So, we need
1102 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
1103 //
1104 // Here is a strict checking to conform TOSA specification.
1105 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
1106 auto idivCheck = [](const int64_t lhs,
1107 const int64_t rhs) -> std::optional<int64_t> {
1108 if (lhs % rhs != 0)
1109 return std::nullopt;
1110 return lhs / rhs;
1111 };
1112
1113 const int64_t oh = outputType.getDimSize(1);
1114 const int64_t ow = outputType.getDimSize(2);
1115 const int64_t ih = inputType.getDimSize(1);
1116 const int64_t iw = inputType.getDimSize(2);
1117
1118 if (ih != ShapedType::kDynamic) {
1119 const std::optional<int64_t> calculatedOutHeightMinusOne =
1120 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1121 if (!calculatedOutHeightMinusOne.has_value())
1122 return op->emitOpError(
1123 "expected (input_height - 1) * scale_y_n - offset_y + "
1124 "border_y ")
1125 << "to be wholly divisible by scale_y_d, got ((" << ih
1126 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1127 << ") / " << scaleYD;
1128 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1129 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1130 return op->emitOpError(
1131 "calculated output height did not match expected: ")
1132 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1133 }
1134
1135 if (iw != ShapedType::kDynamic) {
1136 const std::optional<int64_t> calculatedOutWidthMinusOne =
1137 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1138 if (!calculatedOutWidthMinusOne.has_value())
1139 return op->emitOpError(
1140 "expected (input_width - 1) * scale_x_n - offset_x + "
1141 "border_x ")
1142 << "to be wholly divisible by scale_x_d, got ((" << iw
1143 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1144 << ") / " << scaleXD;
1145 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1146 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1147 return op->emitOpError("calculated output width did not match expected: ")
1148 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1149 }
1150
1151 return success();
1152}
1153
1154LogicalResult checkErrorIfMul(Operation *op) {
1155 auto mul = dyn_cast<tosa::MulOp>(op);
1156 if (!mul)
1157 return success();
1158
1159 // REQUIRE(0 <= shift && shift <= 63);
1160 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1161 ElementsAttr shift_elem;
1162 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
1163 return success();
1164 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1165 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1166 if (inputElemType.isInteger(32)) {
1167 // 0 <= shift <= 63 for int32_t type
1168 if (shift < 0 || shift > 63)
1169 return op->emitOpError()
1170 << "requires 0 <= shift && shift <= 63, but got: " << shift;
1171 } else {
1172 // shift must be 0 for all other types
1173 if (shift != 0)
1174 return op->emitOpError()
1175 << "requires shift = 0 for all input data types that "
1176 "are not int32_t, but got: "
1177 << shift;
1178 }
1179
1180 return success();
1181}
1182
1183LogicalResult checkErrorIfTable(Operation *op) {
1184 auto table = dyn_cast<tosa::TableOp>(op);
1185 if (!table)
1186 return success();
1187
1188 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1189 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1190 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1191
1192 const ShapeAdaptor tableShape(table.getTable().getType());
1193 if (tableShape.hasStaticShape()) {
1194 const auto numElements = tableShape.getNumElements();
1195 if (numElements != tableSize)
1196 return op->emitOpError() << "requires table size of " << tableSize
1197 << ", got " << numElements;
1198 }
1199
1200 return success();
1201}
1202
1203LogicalResult checkErrorIfRescale(Operation *op) {
1204 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1205 if (!rescale)
1206 return success();
1207
1208 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1209 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1210 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1211 !outputType.getElementType().isInteger())
1212 return success();
1213
1214 auto inElemType = inputType.getElementType();
1215 auto outElemType = outputType.getElementType();
1216 auto inWidth = inElemType.getIntOrFloatBitWidth();
1217 auto outWidth = outElemType.getIntOrFloatBitWidth();
1218
1219 bool inputUnsigned = rescale.getInputUnsigned();
1220 bool outputUnsigned = rescale.getOutputUnsigned();
1221
1222 bool scale32 = rescale.getScale32();
1223 auto roundingMode = rescale.getRoundingMode();
1224
1225 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1226 if (scale32 && inWidth == 48)
1227 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1228
1229 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1230 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1231 return op->emitOpError()
1232 << "DOUBLE_ROUND is only allowed with scale32=true.";
1233
1234 // ERROR_IF(input_unsigned && output_unsigned)
1235 if (inputUnsigned && outputUnsigned)
1236 return op->emitOpError() << "input and output cannot be both unsigned.";
1237
1238 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1239 if (outWidth == 32 && inputUnsigned)
1240 return op->emitOpError()
1241 << "i32 output type is not allowed with unsigned input.";
1242
1243 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1244 if (inWidth == 32 && outputUnsigned)
1245 return op->emitOpError()
1246 << "i32 input type is not allowed with unsigned output.";
1247
1248 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1249 if (inWidth == 48 && outputUnsigned)
1250 return op->emitOpError()
1251 << "i48 input type is not allowed with unsigned output.";
1252
1253 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1254 if (inWidth == 48 && inputUnsigned)
1255 return op->emitOpError() << "i48 input type cannot be unsigned.";
1256
1257 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1258 if (inWidth == 32 && inputUnsigned)
1259 return op->emitOpError() << "i32 input type cannot be unsigned.";
1260
1261 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1262 if (outWidth == 32 && outputUnsigned)
1263 return op->emitOpError() << "i32 output type cannot be unsigned.";
1264
1265 return success();
1266}
1267
1268LogicalResult checkErrorIfPad(Operation *op) {
1269 auto pad = dyn_cast<tosa::PadOp>(op);
1270 if (!pad)
1271 return success();
1272
1273 DenseIntElementsAttr paddingAttr;
1274 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1275 // Pad verifier will catch this
1276 return success();
1277
1278 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1279 if (val.getSExtValue() < 0)
1280 return op->emitOpError() << "padding value must all be non-negative, got "
1281 << val.getSExtValue();
1282 }
1283
1284 return success();
1285}
1286
1287LogicalResult checkErrorIfReshape(Operation *op) {
1288 auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
1289 if (!reshapeOp)
1290 return success();
1291
1292 SmallVector<int64_t> shapeValues;
1293 if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
1294 shapeValues))
1295 return success();
1296
1297 if (llvm::is_contained(shapeValues, kInferableDimSize))
1298 return op->emitOpError("shape input contains inferable dimension (")
1300 << ") "
1301 "which does not conform to the TOSA specification";
1302
1303 return success();
1304}
1305
1306LogicalResult checkErrorIfSlice(Operation *op) {
1307 auto sliceOp = dyn_cast<tosa::SliceOp>(op);
1308 if (!sliceOp)
1309 return success();
1310
1311 SmallVector<int64_t> startValues;
1312 SmallVector<int64_t> sizeValues;
1313 const bool hasStartValues = tosa::getConstShapeValues(
1314 sliceOp.getStart().getDefiningOp(), startValues);
1315 const bool hasSizeValues =
1316 tosa::getConstShapeValues(sliceOp.getSize().getDefiningOp(), sizeValues);
1317
1318 if (hasStartValues && llvm::is_contained(startValues, kInferableDimSize))
1319 return op->emitOpError("start input contains inferable dimension (")
1321 << ") which does not conform to the TOSA specification";
1322 if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDimSize))
1323 return op->emitOpError("size input contains inferable dimension (")
1325 << ") which "
1326 "does not conform to the TOSA specification";
1327
1328 return success();
1329}
1330
1331static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1332 return llvm::all_of(op->getOperands(), [&](auto operand) {
1333 Region *operandRegion = operand.getParentRegion();
1334 return operandRegion && region->isAncestor(operandRegion);
1335 });
1336}
1337
1338static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1339 bool noLiveInValue = true;
1340 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1341 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1342 noLiveInValue = false;
1343 return WalkResult::interrupt();
1344 }
1345 return WalkResult::advance();
1346 });
1347 return noLiveInValue ? success() : failure();
1348}
1349
1350LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1351 StringRef regionName) {
1352 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1353 return success();
1354 return op->emitOpError()
1355 << "is not conformant to the TOSA specification. It requires the '"
1356 << regionName << "' region is isolated from above.\n";
1357}
1358
1359LogicalResult checkErrorIfCondIf(Operation *op) {
1360 auto ifOp = dyn_cast<tosa::IfOp>(op);
1361 if (!ifOp)
1362 return success();
1363
1364 // Currently the dialect supports declaring cond_if operations that
1365 // have then/else regions that reference values from outside these
1366 // regions. According to the specification, all values used by the
1367 // then/else regions must be explicitly declared within the regions.
1368 // Therefore we must check that the then/else regions are
1369 // "isolated from above", in order to be conformant to the
1370 // specification.
1371 //
1372 // Note: the dialect currently supports two styles of syntax for
1373 // declaring "cond_if" operations. We'll refer to these as follows:
1374 //
1375 // Generic:
1376 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1377 // ^bb0(%arg3, %arg4):
1378 // tosa.yield %arg3
1379 // }, {
1380 // ^bb0(%arg3, %arg4):
1381 // tosa.yield %arg4
1382 // })
1383 //
1384 // Simplified:
1385 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1386 // ^bb0(%arg3, %arg4):
1387 // tosa.yield %arg3
1388 // } else {
1389 // ^bb0(%arg3, %arg4):
1390 // tosa.yield %arg4
1391 // }
1392
1393 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1394 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1395 return failure();
1396 return success();
1397}
1398
1399LogicalResult checkErrorIfWhileLoop(Operation *op) {
1400 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1401 if (!whileOp)
1402 return success();
1403
1404 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1405 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1406 return failure();
1407 return success();
1408}
1409
1410LogicalResult checkErrorIfScatter(Operation *op) {
1411 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1412 if (!scatterOp)
1413 return success();
1414
1415 // for constant indices, check that there are no duplicate values
1416 DenseIntElementsAttr indicesAttr;
1417 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1418 return success();
1419
1420 auto const indicesType =
1421 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1422 if (!indicesType || !indicesType.hasRank()) {
1423 op->emitOpError("expect ranked indices tensor");
1424 return failure();
1425 }
1426
1427 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1428 op->emitOpError("indices values contain duplicates");
1429 return failure();
1430 }
1431
1432 return success();
1433}
1434
1435LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1436 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1437 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1438 failed(checkErrorIfPad(op)) || failed(checkErrorIfReshape(op)) ||
1439 failed(checkErrorIfSlice(op)) || failed(checkErrorIfCondIf(op)) ||
1440 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1441 return failure();
1442 return success();
1443}
1444
1445LogicalResult TosaValidation::applyFunctionSignatureCheck(func::FuncOp op) {
1446 const auto isShapeType = [](Type type) { return isa<tosa::shapeType>(type); };
1447 if (llvm::any_of(op.getArgumentTypes(), isShapeType))
1448 return op.emitOpError()
1449 << "Function argument types must be a tensor type to be TOSA "
1450 "compliant, got !tosa.shape type";
1451 if (llvm::any_of(op.getResultTypes(), isShapeType))
1452 return op.emitOpError()
1453 << "Function return types must be a tensor type to be TOSA "
1454 "compliant, got !tosa.shape type";
1455 return success();
1456}
1457
1458bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1459 if (isa<FloatType>(type)) {
1460 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1461 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1462 Float6E3M2FNType, Float8E8M0FNUType>(type);
1463 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1464 if (intTy.isSignless()) {
1465 switch (intTy.getWidth()) {
1466 case 1:
1467 case 4:
1468 case 8:
1469 case 16:
1470 case 32:
1471 case 48:
1472 case 64:
1473 return true;
1474 }
1475 } else if (allowUnsigned && intTy.isUnsigned()) {
1476 switch (intTy.getWidth()) {
1477 case 8:
1478 case 16:
1479 case 32:
1480 return true;
1481 }
1482 }
1483 } else if (isa<tosa::shapeType>(type))
1484 return true;
1485 else if (isa<tosa::mxint8Type>(type))
1486 return true;
1487 return false;
1488}
1489
// Pass entry point: resolves the target environment, validates function
// signatures, then walks every TOSA op applying element-type, profile,
// extension, level, attribute, variable and ERROR_IF checks.
//
// Note on failure handling: checks that use `return signalPassFailure()`
// abort the walk callback for that op, while the later checks call
// `signalPassFailure()` without returning so the remaining validations
// still run and report additional diagnostics.
void TosaValidation::runOnOperation() {
  ModuleOp modOp = getOperation();
  // Nothing to do if the TOSA dialect was never loaded into this context.
  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
  if (!tosaDialect)
    return;

  // Resolve the target environment (profiles/extensions/level) from the
  // module attributes, falling back to a default.
  const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
  const auto maybeTargetEnv =
      tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
  if (failed(maybeTargetEnv))
    return signalPassFailure();
  targetEnv = *maybeTargetEnv;

  // Function boundaries must only carry tensor types (no !tosa.shape).
  const auto functions = modOp.getOps<func::FuncOp>();
  if (llvm::any_of(functions, [&](func::FuncOp func) {
        return failed(applyFunctionSignatureCheck(func));
      }))
    return signalPassFailure();

  modOp.walk([&](Operation *op) {
    // Only TOSA ops are validated; other dialects are skipped.
    if (op->getDialect() != tosaDialect)
      return;

    // validate operator element types:
    // - rescale operator is allowed to have ui8/ui16/ui32
    //   operands/results when strictOpSpecAlignment is false
    // - perform valid element type check at the beginning to
    //   protect rest of code against quantized element types
    const bool allowUnsigned =
        !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy, allowUnsigned)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy, allowUnsigned)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    // Profile/extension conformance is only enforced under strict
    // spec alignment.
    if (strictOpSpecAlignment &&
        failed(profileComp.checkProfile(op, targetEnv)))
      return signalPassFailure();

    if (strictOpSpecAlignment &&
        failed(profileComp.checkExtension(op, targetEnv)))
      return signalPassFailure();

    if (!allowInvalidOpDatatypeCombinations &&
        failed(profileComp.checkInvalid(op)))
      return signalPassFailure();

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    if (failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // do level checks
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // check additional attribute restrictions
    if (failed(applyAttributeCheck(op)))
      signalPassFailure();

    // do variable type checks
    if (failed(applyVariableCheck(op)))
      signalPassFailure();

    // do error if checks
    if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
      signalPassFailure();
  });
}
1571} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:578
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
Value getOperand(unsigned idx)
Definition Operation.h:376
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:296
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:117
TosaLevel getLevel() const
Definition TargetEnv.h:134
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:92
bool allows(Profile prof) const
Definition TargetEnv.h:144
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:130
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:67
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:45
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:630
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369