MLIR 23.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1//===- TosaValidation.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Validate if TOSA dialect input matches with the specification for given
10// requirements.
11//
12//===----------------------------------------------------------------------===//
13
17
18#include <string>
19#include <type_traits>
20
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/Matchers.h"
28#include "mlir/Pass/Pass.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/Support/FormatVariadic.h"
33
34namespace mlir {
35namespace tosa {
36#define GEN_PASS_DEF_TOSAVALIDATION
37#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
38} // namespace tosa
39} // namespace mlir
40
41using namespace mlir;
42using namespace mlir::tosa;
43
44namespace {
45
46static LogicalResult
47checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
48 for (const auto index : operandIndices) {
49 Attribute attr;
50 if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
51 return op->emitOpError("expected compile time resolvable constant, but "
52 "got variable value for operand #")
53 << index;
54 }
55 }
56 return success();
57}
58
59static LogicalResult checkConstantOperandMul(Operation *op,
60 const TargetEnv &env) {
61 if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
62 // Check 'shift'
63 return checkConstantOperands(op, {2});
64 }
65 return success();
66}
67
68static LogicalResult checkConstantOperandTable(Operation *op,
69 const TargetEnv &env) {
70 if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
71 // Check 'table'
72 return checkConstantOperands(op, {1});
73 }
74 return success();
75}
76
77static LogicalResult checkConstantOperandPad(Operation *op,
78 const TargetEnv &env) {
79 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
80 // Assume this op is zero-padding if padConst is not presented
81 if (!env.allows(Extension::dynamic) && padOp.getPadConst())
82 // Check 'pad_const'
83 // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
84 return checkConstantOperands(op, {2});
85 }
86 return success();
87}
88
89static LogicalResult checkConstantOperandRescale(Operation *op,
90 const TargetEnv &env) {
91 if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
92 // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
93 return checkConstantOperands(op, {1, 2, 3, 4});
94 }
95 return success();
96}
97
98template <typename T>
99static LogicalResult checkConstantOperandConvOps(Operation *op,
100 const TargetEnv &env) {
101 if (!env.allows(Extension::dynamic) && isa<T>(op)) {
102 // Check 'input_zp' and 'weight_zp'
103 return checkConstantOperands(op, {3, 4});
104 }
105 return success();
106}
107
108static LogicalResult checkConstantOperandMatMul(Operation *op,
109 const TargetEnv &env) {
110 if (!env.allows(Extension::dynamic) &&
111 isa<tosa::MatMulOp, tosa::MatMulTOp>(op)) {
112 // Check 'A_zp' and 'B_zp'
113 return checkConstantOperands(op, {2, 3});
114 }
115 return success();
116}
117
118static LogicalResult
119checkConstantOperandRowGatherBlockScaled(Operation *op, const TargetEnv &env) {
120 if (!env.allows(Extension::dynamic) &&
121 isa<tosa::RowGatherBlockScaledOp>(op)) {
122 auto rowGatherOp = cast<tosa::RowGatherBlockScaledOp>(op);
123 const unsigned rowCountIndex = rowGatherOp.getValues().size() + 1;
124 return checkConstantOperands(op, {rowCountIndex});
125 }
126 return success();
127}
128
129static LogicalResult checkConstantOperandRowGather(Operation *op,
130 const TargetEnv &env) {
131 if (!env.allows(Extension::dynamic) && isa<tosa::RowGatherOp>(op)) {
132 // Check 'row_count'
133 return checkConstantOperands(op, {2});
134 }
135 return success();
136}
137
138static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
139 const TargetEnv &env) {
140 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
141 // Check 'input_zp' and 'output_zp'
142 return checkConstantOperands(op, {1, 2});
143 }
144 return success();
145}
146
147static LogicalResult
148checkConstantOperandAvgPool2dAdaptive(Operation *op, const TargetEnv &env) {
149 if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dAdaptiveOp>(op)) {
150 // Check 'input_zp' and 'output_zp'.
151 // Note: 'kernel', 'stride', and 'pad' (operands 3, 4, 5) are not checked
152 // as they are tosa.shape types.
153 return checkConstantOperands(op, {1, 2});
154 }
155 return success();
156}
157
158static LogicalResult checkConstantOperandNegate(Operation *op,
159 const TargetEnv &env) {
160 if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
161 // Check 'input1_zp' and 'output_zp'
162 return checkConstantOperands(op, {1, 2});
163 }
164 return success();
165}
166
167static LogicalResult checkConstantOperandSilceShape(Operation *op,
168 const TargetEnv &env) {
169 if (!env.allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
170 // Check 'start' and 'size'
171 return checkConstantOperands(op, {1, 2});
172 }
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// TOSA Validation Pass.
178//===----------------------------------------------------------------------===//
179
180struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
181public:
182 explicit TosaValidation() { populateConstantOperandChecks(); }
183
184 explicit TosaValidation(const TosaValidationOptions &options)
185 : TosaValidation() {
186 this->strictOpSpecAlignment = options.strictOpSpecAlignment;
187 this->allowInvalidOpDatatypeCombinations =
188 options.allowInvalidOpDatatypeCombinations;
189 }
190 void runOnOperation() final;
191
192 LogicalResult applyConstantOperandCheck(Operation *op) {
193 for (auto &checker : constCheckers) {
194 if (failed(checker(op, targetEnv)))
195 return failure();
196 }
197 return success();
198 }
199
200 LogicalResult applyFunctionSignatureCheck(func::FuncOp op);
201 LogicalResult applyLevelCheck(Operation *op);
202 LogicalResult applyAttributeCheck(Operation *op);
203
204 // check variable read/write data types against variable declarations
205 LogicalResult applyVariableCheck(Operation *op);
206
207 // check error if conditions
208 LogicalResult applyErrorIfCheck(Operation *op);
209
210private:
211 void populateConstantOperandChecks() {
212 constCheckers.emplace_back(checkConstantOperandMul);
213 constCheckers.emplace_back(checkConstantOperandTable);
214 constCheckers.emplace_back(checkConstantOperandPad);
215 constCheckers.emplace_back(checkConstantOperandRescale);
216 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
217 constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
218 constCheckers.emplace_back(
219 checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
220 constCheckers.emplace_back(
221 checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
222 constCheckers.emplace_back(checkConstantOperandMatMul);
223 constCheckers.emplace_back(checkConstantOperandRowGather);
224 constCheckers.emplace_back(checkConstantOperandRowGatherBlockScaled);
225 constCheckers.emplace_back(checkConstantOperandAvgPool2d);
226 constCheckers.emplace_back(checkConstantOperandAvgPool2dAdaptive);
227 constCheckers.emplace_back(checkConstantOperandNegate);
228 constCheckers.emplace_back(checkConstantOperandSilceShape);
229 }
230
231 LogicalResult levelCheck(Operation *op, const int32_t calculatedValue,
232 const int32_t maxLevel, const StringRef inputName,
233 const StringRef levelName) {
234 if (calculatedValue > maxLevel)
235 return op->emitOpError()
236 << "failed level check: " << inputName << " <= " << levelName
237 << " (" << maxLevel << "), got " << calculatedValue;
238 return success();
239 }
240
241 LogicalResult levelCheckKernel(Operation *op, int32_t v,
242 const StringRef inputName) {
243 return levelCheck(op, v, targetEnv.getLevel().MAX_KERNEL, inputName,
244 "MAX_KERNEL");
245 }
246
247 LogicalResult levelCheckStride(Operation *op, int32_t v,
248 const StringRef inputName) {
249 return levelCheck(op, v, targetEnv.getLevel().MAX_STRIDE, inputName,
250 "MAX_STRIDE");
251 }
252
253 LogicalResult levelCheckScale(Operation *op, int32_t v,
254 const StringRef inputName) {
255 return levelCheck(op, v, targetEnv.getLevel().MAX_SCALE, inputName,
256 "MAX_SCALE");
257 }
258
259 LogicalResult levelCheckListSize(Operation *op, int32_t v,
260 const StringRef inputName) {
261 const std::string inputDesc =
262 llvm::formatv("length(tensor_list_shape({0}))", inputName);
263 return levelCheck(op, v, targetEnv.getLevel().MAX_TENSOR_LIST_SIZE,
264 inputDesc, "MAX_TENSOR_LIST_SIZE");
265 }
266
267 // Perform the Level Rank check on the tensor type.
268 LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
269 const StringRef operandOrResult,
270 int32_t highest_rank) {
271 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
272 if (!type.hasRank())
273 return op->emitOpError() << "failed level check: unranked tensor";
274 if (type.getRank() > highest_rank)
275 return op->emitOpError() << "failed level check: " << operandOrResult
276 << " rank(shape) <= MAX_RANK";
277 }
278 return success();
279 }
280
281 // Perform the Level Rank check on the tensor value.
282 LogicalResult levelCheckRank(Operation *op, const Value &v,
283 const StringRef operandOrResult,
284 int32_t highest_rank) {
285 return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
286 }
287
288 // Perform the Level tensor size check on the tensor type.
289 LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
290 const StringRef operandOrResult);
291
292 // Perform the Level tensor size check on the tensor value.
293 LogicalResult levelCheckSize(Operation *op, const Value &v,
294 const StringRef operandOrResult) {
295 return levelCheckSize(op, v.getType(), operandOrResult);
296 }
297
298 // Perform the Level shape length check on a value.
299 LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
300 const StringRef operandOrResult) {
301 if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
302 if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
303 return op->emitOpError()
304 << "failed shape type level check: " << typeToCheck
305 << " exceeds MAX_SHAPE_LEN";
306 }
307 return success();
308 }
309
310 // Level check sizes of all operands and results of the operation.
311 template <typename T>
312 LogicalResult levelCheckSizes(T tosaOp) {
313 auto op = tosaOp.getOperation();
314 for (auto v : op->getOperands()) {
315 if (failed(levelCheckSize(op, v, "operand")))
316 return failure();
317 }
318
319 for (auto v : op->getResults()) {
320 if (failed(levelCheckSize(op, v, "result")))
321 return failure();
322 }
323 return success();
324 }
325
326 // Level check ranks of all operands, attribute and results of the operation.
327 template <typename T>
328 LogicalResult levelCheckRanks(T tosaOp) {
329 auto op = tosaOp.getOperation();
330 const TosaLevel tosaLevel = targetEnv.getLevel();
331 for (auto v : op->getOperands()) {
332 if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
333 return failure();
334 }
335
336 for (auto v : op->getResults()) {
337 if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
338 return failure();
339 }
340 return success();
341 }
342 // Level check shape lengths of all operands and results of an operation that
343 // are tosa.shape type.
344 template <typename T>
345 LogicalResult levelCheckShapeLengths(T tosaOp) {
346 for (const auto &v : tosaOp->getOperands()) {
347 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
348 return failure();
349 }
350 for (const auto &v : tosaOp->getResults()) {
351 if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
352 return failure();
353 }
354
355 return success();
356 }
357
358 // Level check ranks and sizes.
359 LogicalResult levelCheckRanksAndSizes(Operation *op);
360
361 // Pool Op: level check kernel/stride/pad values
362 template <typename T>
363 LogicalResult levelCheckPool(Operation *op) {
364 if (auto poolOp = dyn_cast<T>(op)) {
365 for (auto k : poolOp.getKernel()) {
366 if (failed(levelCheckKernel(op, k, "kernel"))) {
367 return failure();
368 }
369 }
370 for (auto s : poolOp.getStride()) {
371 if (failed(levelCheckStride(op, s, "stride"))) {
372 return failure();
373 }
374 }
375 for (auto p : poolOp.getPad()) {
376 if (failed(levelCheckKernel(op, p, "pad"))) {
377 return failure();
378 }
379 }
380 }
381 return success();
382 }
383
384 template <typename T>
385 static constexpr bool IsSupportedAdaptivePoolOp =
386 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
387 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
388
389 template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
390 int>::type = 0>
391 LogicalResult levelCheckAdaptivePool(Operation *op) {
392 auto poolOp = dyn_cast<T>(op);
393 if (!poolOp)
394 return success();
395
396 SmallVector<int64_t> kernelValues;
397 if (tosa::getConstShapeValues(poolOp.getKernel().getDefiningOp(),
398 kernelValues)) {
399 for (const auto k : kernelValues)
400 if (failed(levelCheckKernel(op, k, "kernel")))
401 return failure();
402 }
403
404 SmallVector<int64_t> strideValues;
405 if (tosa::getConstShapeValues(poolOp.getStride().getDefiningOp(),
406 strideValues)) {
407 for (const auto s : strideValues)
408 if (failed(levelCheckStride(op, s, "stride")))
409 return failure();
410 }
411
412 SmallVector<int64_t> padValues;
413 if (tosa::getConstShapeValues(poolOp.getPad().getDefiningOp(), padValues)) {
414 for (const auto p : padValues)
415 if (failed(levelCheckKernel(op, p, "pad")))
416 return failure();
417 }
418
419 return success();
420 }
421
422 // Conv Op: level check dilation/stride/pad values
423 template <typename T>
424 LogicalResult levelCheckConv(Operation *op) {
425 if (auto convOp = dyn_cast<T>(op)) {
426
427 for (auto k : convOp.getDilation()) {
428 if (failed(levelCheckKernel(op, k, "dilation"))) {
429 return failure();
430 }
431 }
432 for (auto p : convOp.getPad()) {
433 if (failed(levelCheckKernel(op, p, "pad"))) {
434 return failure();
435 }
436 }
437 for (auto s : convOp.getStride()) {
438 if (failed(levelCheckStride(op, s, "stride"))) {
439 return failure();
440 }
441 }
442 auto dilation = convOp.getDilation();
443 if (ShapedType weightType =
444 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
445 auto shape = weightType.getShape();
446 if (isa<tosa::Conv2DOp>(op)) {
447 assert(shape.size() == 4);
448 assert(dilation.size() == 2);
449 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
450 "dilation_y * KH")) ||
451 failed(levelCheckKernel(op, dilation[1] * shape[2],
452 "dilation_x * KW")))
453 return failure();
454 } else if (isa<tosa::Conv3DOp>(op)) {
455 assert(shape.size() == 5);
456 assert(dilation.size() == 3);
457 if (failed(levelCheckKernel(op, dilation[0] * shape[1],
458 "dilation_d * KD")) ||
459 failed(levelCheckKernel(op, dilation[1] * shape[2],
460 "dilation_y * KH")) ||
461 failed(levelCheckKernel(op, dilation[2] * shape[3],
462 "dilation_x * KW")))
463 return failure();
464 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
465 assert(shape.size() == 4);
466 assert(dilation.size() == 2);
467 if (failed(levelCheckKernel(op, dilation[0] * shape[0],
468 "dilation_y * KH")) ||
469 failed(levelCheckKernel(op, dilation[1] * shape[1],
470 "dilation_x * KW")))
471 return failure();
472 }
473 }
474 }
475 return success();
476 }
477
478 LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
479 auto convOp = dyn_cast<Conv2DBlockScaledOp>(op);
480 if (!convOp)
481 return success();
482
483 SmallVector<int64_t> padValues;
484 if (tosa::getConstShapeValues(convOp.getPad().getDefiningOp(), padValues)) {
485 for (const auto p : padValues)
486 if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
487 return failure();
488 }
489
490 SmallVector<int64_t> strideValues;
491 if (tosa::getConstShapeValues(convOp.getStride().getDefiningOp(),
492 strideValues)) {
493 for (const auto s : strideValues)
494 if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
495 return failure();
496 }
497
498 SmallVector<int64_t> dilationValues;
499 if (tosa::getConstShapeValues(convOp.getDilation().getDefiningOp(),
500 dilationValues)) {
501 int64_t KH = ShapedType::kDynamic;
502 int64_t KW = ShapedType::kDynamic;
503 const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
504 KH = weightDataShape.getDimSize(1);
505 KW = weightDataShape.getDimSize(2);
506 const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
507 KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
508 KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
509
510 if (!ShapedType::isDynamic(KH) &&
511 failed(levelCheckKernel(op, dilationValues[0] * KH,
512 "dilation_y * KH <= MAX_KERNEL)")))
513 return failure();
514
515 if (!ShapedType::isDynamic(KW) &&
516 failed(levelCheckKernel(op, dilationValues[1] * KW,
517 "dilation_x * KW <= MAX_KERNEL)")))
518 return failure();
519 }
520
521 return success();
522 }
523
524 // FFT op: level check H, W in input shape [N,H,W]
525 template <typename T>
526 LogicalResult levelCheckFFT(Operation *op) {
527 if (isa<T>(op)) {
528 for (auto v : op->getOperands()) {
529 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
530 auto shape = type.getShape();
531 assert(shape.size() == 3);
532 if (failed(levelCheckKernel(op, shape[1], "H")) ||
533 failed(levelCheckKernel(op, shape[2], "W"))) {
534 return failure();
535 }
536 }
537 }
538 }
539 return success();
540 }
541
542 // TransposeConv2d op: level check kH/kW, outpad, and stride
543 LogicalResult levelCheckTransposeConv2d(Operation *op) {
544 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
545 if (ShapedType filterType =
546 dyn_cast<ShapedType>(transpose.getWeight().getType())) {
547 auto shape = filterType.getShape();
548 assert(shape.size() == 4);
549 // level check kernel sizes for kH and KW
550 if (failed(levelCheckKernel(op, shape[1], "KH")) ||
551 failed(levelCheckKernel(op, shape[2], "KW"))) {
552 return failure();
553 }
554 }
555 for (auto p : transpose.getOutPad()) {
556 if (failed(levelCheckKernel(op, p, "pad"))) {
557 return failure();
558 }
559 }
560 for (auto s : transpose.getStride()) {
561 if (failed(levelCheckStride(op, s, "stride"))) {
562 return failure();
563 }
564 }
565 }
566 return success();
567 }
568
569 // Resize op: level check max scales
570 LogicalResult levelCheckResize(Operation *op) {
571 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
572 SmallVector<int64_t> scale;
573 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
574 scale)) {
575 return failure();
576 }
577 const int64_t scaleYN = scale[0];
578 const int64_t scaleYD = scale[1];
579 const int64_t scaleXN = scale[2];
580 const int64_t scaleXD = scale[3];
581 if (failed(
582 levelCheckScale(op, scaleYN / scaleYD, "scale_y_n/scale_y_d")) ||
583 failed(
584 levelCheckScale(op, scaleXN / scaleXD, "scale_x_n/scale_x_d"))) {
585 return failure();
586 }
587 }
588 return success();
589 }
590
591 // Recursively perform a bottom-up search to determine the maximum nesting
592 // depth, starting from a specific operation and continuing up to the function
593 // or module scope. Tosa nesting_depth starts at 0 and increments by one each
594 // time a new nested `region` is encountered.
595 static void getMaxNestedDepth(Operation *op, int32_t &depth) {
596 if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
597 return;
598
599 op = op->getParentOp();
600 if (!op)
601 return;
602
603 depth++;
604 getMaxNestedDepth(op, depth);
605 }
606
607 LogicalResult levelCheckMaxNesting(Operation *op) {
608 int32_t maxNestedDepth = 0;
609 getMaxNestedDepth(op, maxNestedDepth);
610
611 const int32_t maxNestingLevel = targetEnv.getLevel().MAX_NESTING;
612 if (maxNestedDepth >= maxNestingLevel)
613 return op->emitOpError()
614 << "failed level check: tosa_nesting_depth < MAX_NESTING" << " ("
615 << maxNestingLevel << "), got " << maxNestedDepth;
616 return success();
617 }
618
619 LogicalResult levelCheckListSize(Operation *op) {
620 if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
621 return levelCheckListSize(op, concat.getInput1().size(), "input1");
622 }
623 if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
624 if (failed(levelCheckListSize(op, custom.getInputList().size(),
625 "input_list")) ||
626 failed(levelCheckListSize(op, custom.getOutputList().size(),
627 "output_list"))) {
628 return failure();
629 }
630 }
631 if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
632 if (failed(
633 levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
634 failed(levelCheckListSize(op, condIf.getOutputList().size(),
635 "outputs"))) {
636 return failure();
637 }
638 }
639 if (auto w = dyn_cast<tosa::WhileOp>(op)) {
640 if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
641 failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
642 return failure();
643 }
644 }
645 if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
646 return levelCheckListSize(op, concat_shape.getInput().size(), "input");
647 return success();
648 }
649
650 LogicalResult attributeCheckRescale(Operation *op) {
651 if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
652 if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
653 !targetEnv.allows(Extension::doubleround)) {
654 op->emitOpError()
655 << "failed attribute check: rounding_mode = DOUBLE_ROUND "
656 << "requires extension [doubleround]";
657 return failure();
658 }
659 if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
660 !targetEnv.allows(Extension::inexactround)) {
661 op->emitOpError()
662 << "failed attribute check: rounding_mode = INEXACT_ROUND "
663 << "requires extension [inexactround]";
664 return failure();
665 }
666 }
667 return success();
668 }
669
670 LogicalResult CheckVariable(Operation *op);
671 LogicalResult CheckVariableReadOrWrite(Operation *op);
672 bool isValidElementType(Type type, const bool allowUnsigned = false);
673
674 SmallVector<
675 std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
676 constCheckers;
678 TosaProfileCompliance profileComp;
679 tosa::TargetEnv targetEnv;
680};
681
682template <>
683LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
684 auto *op = tosaOp.getOperation();
685 if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
686 targetEnv.getLevel().MAX_RANK)))
687 return failure();
688
689 // rank(output) = rank(input) - 1
690 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
691 targetEnv.getLevel().MAX_RANK - 1)))
692 return failure();
693
694 return success();
695}
696
697template <>
698LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
699 auto *op = tosaOp.getOperation();
700
701 // Only the condition input has rank limitation.
702 if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
703 targetEnv.getLevel().MAX_RANK)))
704 return failure();
705
706 return success();
707}
708
709template <>
710LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
711 auto *op = tosaOp.getOperation();
712 auto variableType = getVariableType(tosaOp);
713 if (failed(levelCheckRank(op, variableType, "variable type",
714 targetEnv.getLevel().MAX_RANK)))
715 return failure();
716
717 return success();
718}
719
720template <>
721LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
722 auto *op = tosaOp.getOperation();
723 auto variableType = getVariableType(tosaOp);
724 if (failed(levelCheckSize(op, variableType, "variable type")))
725 return failure();
726
727 return success();
728}
729
730LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
731#define CHECK_RANKS_AND_SIZES(tosaOp) \
732 if (isa<tosa::tosaOp##Op>(op)) { \
733 if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
734 return failure(); \
735 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
736 return failure(); \
737 }
738
739#define CHECK_SIZES(tosaOp) \
740 if (isa<tosa::tosaOp##Op>(op)) { \
741 if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
742 return failure(); \
743 }
744
745#define CHECK_SHAPE_LEN(tosaOp) \
746 if (isa<tosa::tosaOp##Op>(op)) { \
747 if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
748 return failure(); \
749 }
750
751 // Tensor Operators
752 CHECK_RANKS_AND_SIZES(ArgMax);
753 // Activation Functions
756 CHECK_RANKS_AND_SIZES(Sigmoid);
758 // Elementwise Binary Operators
760 CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
761 CHECK_RANKS_AND_SIZES(BitwiseAnd);
762 CHECK_RANKS_AND_SIZES(BitwiseOr);
763 CHECK_RANKS_AND_SIZES(BitwiseXor);
764 CHECK_RANKS_AND_SIZES(IntDiv);
765 CHECK_RANKS_AND_SIZES(LogicalAnd);
766 CHECK_RANKS_AND_SIZES(LogicalLeftShift);
767 CHECK_RANKS_AND_SIZES(LogicalRightShift);
768 CHECK_RANKS_AND_SIZES(LogicalOr);
769 CHECK_RANKS_AND_SIZES(LogicalXor);
770 CHECK_RANKS_AND_SIZES(Maximum);
771 CHECK_RANKS_AND_SIZES(Minimum);
776 // Elementwise Unary Operators
778 CHECK_RANKS_AND_SIZES(BitwiseNot);
785 CHECK_RANKS_AND_SIZES(LogicalNot);
786 CHECK_RANKS_AND_SIZES(Negate);
787 CHECK_RANKS_AND_SIZES(Reciprocal);
790 // Elementwise Ternary Operators
791 CHECK_RANKS_AND_SIZES(Select);
792 // Comparison Operators
794 CHECK_RANKS_AND_SIZES(Greater);
795 CHECK_RANKS_AND_SIZES(GreaterEqual);
796 // Reduction Operators
797 CHECK_RANKS_AND_SIZES(ReduceAll);
798 CHECK_RANKS_AND_SIZES(ReduceAny);
799 CHECK_RANKS_AND_SIZES(ReduceMax);
800 CHECK_RANKS_AND_SIZES(ReduceMin);
801 CHECK_RANKS_AND_SIZES(ReduceProduct);
802 CHECK_RANKS_AND_SIZES(ReduceSum);
803 // Data Layout Operators
804 CHECK_RANKS_AND_SIZES(Concat);
806 CHECK_RANKS_AND_SIZES(Reshape);
807 CHECK_RANKS_AND_SIZES(ReshapeBlockScaled);
808 CHECK_RANKS_AND_SIZES(Reverse);
811 CHECK_RANKS_AND_SIZES(Transpose);
812 // Type Conversion
814 CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
815 CHECK_RANKS_AND_SIZES(CastToBlockScaled);
816 CHECK_RANKS_AND_SIZES(Rescale);
817 // Data Nodes
819 CHECK_RANKS_AND_SIZES(Identity);
820 // Control Flow Operators
822 // Variable Operators
823 CHECK_RANKS_AND_SIZES(Variable);
824 CHECK_RANKS_AND_SIZES(VariableWrite);
825 CHECK_RANKS_AND_SIZES(VariableRead);
826 // Shape Operators
828
829 // For the following operators, check whether the size of each tensor
830 // operand is valid in a given Level.
831
832 // Tensor Operators
833 CHECK_SIZES(AvgPool2d);
834 CHECK_SIZES(AvgPool2dAdaptive);
835 CHECK_SIZES(Conv2D);
836 CHECK_SIZES(Conv2DBlockScaled);
837 CHECK_SIZES(Conv3D);
838 CHECK_SIZES(DepthwiseConv2D);
839 CHECK_SIZES(TransposeConv2D);
840 CHECK_SIZES(FFT2d);
841 CHECK_SIZES(MatMul);
842 CHECK_SIZES(MatMulT);
843 CHECK_SIZES(MatmulTBlockScaled);
844 CHECK_SIZES(MaxPool2d);
845 CHECK_SIZES(MaxPool2dAdaptive);
846 CHECK_SIZES(RFFT2d);
847 // Scatter/Gather Operators
849 CHECK_SIZES(RowGather);
850 CHECK_SIZES(Scatter);
851 // Image Operators
852 CHECK_SIZES(Resize);
853 // Custom Operators
854 CHECK_SIZES(Custom);
855 // Control Flow Operators
856 CHECK_SIZES(While);
857 // Shape Operators
858 CHECK_SIZES(ConstShape);
859
860 // For the following operations, check whether the shape length of each
861 // operand is valid given a level.
862
863 // Shape Operators
864 CHECK_SHAPE_LEN(AddShape);
865 CHECK_SHAPE_LEN(AssertEqualShape);
866 CHECK_SHAPE_LEN(ConcatShape);
867 CHECK_SHAPE_LEN(DivCeilShape);
868 CHECK_SHAPE_LEN(DivFloorShape);
869 CHECK_SHAPE_LEN(Exp2Shape);
870 CHECK_SHAPE_LEN(Log2CeilShape);
871 CHECK_SHAPE_LEN(Log2FloorShape);
872 CHECK_SHAPE_LEN(MaxShape);
873 CHECK_SHAPE_LEN(MinShape);
874 CHECK_SHAPE_LEN(ModShape);
875 CHECK_SHAPE_LEN(MulShape);
876 CHECK_SHAPE_LEN(SliceShape);
877 CHECK_SHAPE_LEN(SubShape);
878
879#undef CHECK_RANKS_AND_SIZES
880#undef CHECK_SIZES
881#undef CHECK_SHAPE_LEN
882 return success();
883}
884
885// Perform the Level tensor size check on the tensor type.
886LogicalResult TosaValidation::levelCheckSize(Operation *op,
887 const Type &typeToCheck,
888 const StringRef operandOrResult) {
889 if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
890 if (!type.hasRank())
891 return op->emitOpError() << "failed level check: unranked tensor";
892 auto shape = type.getShape();
893 for (auto dim : shape) {
894 const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
895 const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
896 const TosaSpecificationVersion minRequiredVersion(1, 1, true);
897 if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
898 dimIsDynamic)
899 // TOSA 1.1 and above supports dynamic dimensions, however, they must be
900 // resolved at backend compile time. Runtime dynamism is not currently
901 // supported. Checking this requirement is met is delegated to backends.
902 return success();
903
904 // When targeting TOSA 1.0 or below, dynamic dims are not supported
905 if (dimIsDynamic)
906 return op->emitOpError() << "failed level check: " << operandOrResult
907 << " shape dimension cannot be dynamic when"
908 << " targeting TOSA specification version 1.0"
909 << " or below";
910 }
911
912 int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
913 int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
914 int64_t size = element_bytes * type.getNumElements();
915
916 // According to 1.11. Tensor Definitions of Tosa spec, the value of
917 // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
918 // defined in 1.7. Levels.
919 // For each tensor, the number of tensor elements multiplied by the
920 // element size in bytes must be representable as a tensor_size_t.
921 const int64_t max_size =
922 (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
923 if (size > max_size)
924 return op->emitOpError()
925 << "failed level check: " << operandOrResult
926 << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
927 }
928 return success();
929}
930
931LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
932 if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
933 // no need to do level checks
934 return success();
935 }
936
937 // check rank and sizes early so later checks can assume shaped operands
938 if (failed(levelCheckRanksAndSizes(op)))
939 return failure();
940
941 if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
942 failed(levelCheckAdaptivePool<tosa::AvgPool2dAdaptiveOp>(op)) ||
943 failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
944 failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
945 failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
946 failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
947 failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
948 failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
949 failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
950 failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
951 failed(levelCheckConv2DBlockScaled(op))) {
952 return failure();
953 }
954
955 // level check MAX_TENSOR_LIST_SIZE
956 if (failed(levelCheckListSize(op))) {
957 return failure();
958 }
959
960 if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
961 if (failed(levelCheckMaxNesting(op))) {
962 return failure();
963 }
964 }
965
966 return success();
967}
968
969LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
970 if (failed(attributeCheckRescale(op)))
971 return failure();
972 return success();
973}
974
975inline bool CompatibleTypes(const mlir::Type &type,
976 const mlir::Type &declaredType) {
977 // for now, simply use type equality comparison
978 return type == declaredType;
979}
980
981LogicalResult TosaValidation::CheckVariable(Operation *op) {
982 if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
983 mlir::StringAttr nameAttr = variableOp.getNameAttr();
984
985 if (variablesMap.count(nameAttr))
986 return op->emitOpError() << "name has already been declared";
987
988 auto elementType = variableOp.getType();
989 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
990 SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
991 RankedTensorType variableType =
992 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
993
994 variablesMap[nameAttr] = variableType;
995 }
996
997 return success();
998}
999
1000LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
1001 if (isa<mlir::tosa::VariableReadOp>(op) ||
1002 isa<mlir::tosa::VariableWriteOp>(op)) {
1003 mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
1004 if (!variablesMap.count(nameAttr))
1005 return op->emitOpError() << "name has not been declared";
1006
1007 auto varType = variablesMap[nameAttr];
1008
1009 for (auto v : op->getOperands()) {
1010 auto type = v.getType();
1011 if (!CompatibleTypes(type, varType))
1012 return op->emitOpError() << "operand type does not equal variable type";
1013 }
1014
1015 for (auto v : op->getResults()) {
1016 auto type = v.getType();
1017 if (!CompatibleTypes(type, varType))
1018 return op->emitOpError() << "result type does not equal variable type";
1019 }
1020 }
1021
1022 return success();
1023}
1024
1025LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
1026 if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
1027 return failure();
1028 return success();
1029}
1030
1031LogicalResult checkErrorIfResize(Operation *op) {
1032 auto resize = dyn_cast<tosa::ResizeOp>(op);
1033 if (!resize)
1034 return success();
1035
1036 const Value input = resize.getInput();
1037 const Value output = resize.getOutput();
1038 const RankedTensorType inputType =
1039 llvm::dyn_cast<RankedTensorType>(input.getType());
1040 const RankedTensorType outputType =
1041 llvm::dyn_cast<RankedTensorType>(output.getType());
1042
1043 if (!inputType || !outputType)
1044 return op->emitOpError("expect ranked input/output tensor");
1045
1046 // Ensure the image size is supported by GPU APIs and that for integer
1047 // implementations, position * stride does not overflow int32_t.
1048 if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1049 const SmallVector<int64_t, 4> sizes = {
1050 outputType.getDimSize(1), outputType.getDimSize(2),
1051 inputType.getDimSize(1), inputType.getDimSize(2)};
1052 const int64_t *maxDim = llvm::max_element(sizes);
1053 if (maxDim != sizes.end() && *maxDim >= 16384)
1054 return op->emitOpError(
1055 "expect input/output height/width dims to be < 16384, ")
1056 << "got [OH, OW, IH, IW] = " << sizes;
1057 }
1058
1059 SmallVector<int64_t> scale;
1060 if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
1061 return failure();
1062
1063 const int64_t scaleYN = scale[0];
1064 const int64_t scaleYD = scale[1];
1065 const int64_t scaleXN = scale[2];
1066 const int64_t scaleXD = scale[3];
1067
1068 // Ensure scale values don't overflow int32 accumulator
1069 if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
1070 return op->emitOpError(
1071 "expect all scale numerator values to be <= (1 << 11), "
1072 "got scale_y_n=")
1073 << scaleYN << ", scale_x_n=" << scaleXN;
1074
1075 if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
1076 return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
1077 << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
1081 if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
1082 !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
1083 return failure();
1084
1085 const int64_t offsetY = offset[0];
1086 const int64_t offsetX = offset[1];
1087 // Set a consistent lower limit of 1/16 downscale to simplify
1088 // implementations
1089 if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
1090 return op->emitOpError(
1091 "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
1092 << offsetY << "/" << scaleYN;
1093 if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
1094 return op->emitOpError(
1095 "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
1096 << offsetX << "/" << scaleXN;
1097
1098 const int64_t borderY = border[0];
1099 const int64_t borderX = border[1];
1100 if (borderY < -16 * scaleYN || borderY >= scaleYN)
1101 return op->emitOpError(
1102 "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
1103 << borderY << "/" << scaleYN;
1104 if (borderX < -16 * scaleXN || borderX >= scaleXN)
1105 return op->emitOpError(
1106 "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
1107 << borderX << "/" << scaleXN;
1108
1109 // The following section of code is mostly duplicated with ResizeOp::verify().
1110 //
1111 // In TOSA specification, we do not support broadcast behavior.
1112 // However, there is a rewrite pattern to materialize broadcast ResizeOp.
1113 // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
1114 // existing code, we keep the rewrite pattern untouched. So, we need
1115 // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
1116 //
1117 // Here is a strict checking to conform TOSA specification.
1118 // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
1119 auto idivCheck = [](const int64_t lhs,
1120 const int64_t rhs) -> std::optional<int64_t> {
1121 if (lhs % rhs != 0)
1122 return std::nullopt;
1123 return lhs / rhs;
1125
1126 const int64_t oh = outputType.getDimSize(1);
1127 const int64_t ow = outputType.getDimSize(2);
1128 const int64_t ih = inputType.getDimSize(1);
1129 const int64_t iw = inputType.getDimSize(2);
1130
1131 if (ih != ShapedType::kDynamic) {
1132 const std::optional<int64_t> calculatedOutHeightMinusOne =
1133 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1134 if (!calculatedOutHeightMinusOne.has_value())
1135 return op->emitOpError(
1136 "expected (input_height - 1) * scale_y_n - offset_y + "
1137 "border_y ")
1138 << "to be wholly divisible by scale_y_d, got ((" << ih
1139 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1140 << ") / " << scaleYD;
1141 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1142 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1143 return op->emitOpError(
1144 "calculated output height did not match expected: ")
1145 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1146 }
1147
1148 if (iw != ShapedType::kDynamic) {
1149 const std::optional<int64_t> calculatedOutWidthMinusOne =
1150 idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1151 if (!calculatedOutWidthMinusOne.has_value())
1152 return op->emitOpError(
1153 "expected (input_width - 1) * scale_x_n - offset_x + "
1154 "border_x ")
1155 << "to be wholly divisible by scale_x_d, got ((" << iw
1156 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1157 << ") / " << scaleXD;
1158 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1159 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1160 return op->emitOpError("calculated output width did not match expected: ")
1161 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1162 }
1163
1164 return success();
1165}
1166
1167LogicalResult checkErrorIfMul(Operation *op) {
1168 auto mul = dyn_cast<tosa::MulOp>(op);
1169 if (!mul)
1170 return success();
1171
1172 // REQUIRE(0 <= shift && shift <= 63);
1173 // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1174 ElementsAttr shift_elem;
1175 if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
1176 return success();
1177 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1178 auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1179 if (inputElemType.isInteger(32)) {
1180 // 0 <= shift <= 63 for int32_t type
1181 if (shift < 0 || shift > 63)
1182 return op->emitOpError()
1183 << "requires 0 <= shift && shift <= 63, but got: " << shift;
1184 } else {
1185 // shift must be 0 for all other types
1186 if (shift != 0)
1187 return op->emitOpError()
1188 << "requires shift = 0 for all input data types that "
1189 "are not int32_t, but got: "
1190 << shift;
1191 }
1192
1193 return success();
1194}
1195
1196LogicalResult checkErrorIfTable(Operation *op) {
1197 auto table = dyn_cast<tosa::TableOp>(op);
1198 if (!table)
1199 return success();
1200
1201 // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1202 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1203 const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1204
1205 const ShapeAdaptor tableShape(table.getTable().getType());
1206 if (tableShape.hasStaticShape()) {
1207 const auto numElements = tableShape.getNumElements();
1208 if (numElements != tableSize)
1209 return op->emitOpError() << "requires table size of " << tableSize
1210 << ", got " << numElements;
1211 }
1212
1213 return success();
1214}
1215
1216LogicalResult checkErrorIfRescale(Operation *op) {
1217 auto rescale = dyn_cast<tosa::RescaleOp>(op);
1218 if (!rescale)
1219 return success();
1220
1221 auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1222 auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1223 if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1224 !outputType.getElementType().isInteger())
1225 return success();
1226
1227 auto inElemType = inputType.getElementType();
1228 auto outElemType = outputType.getElementType();
1229 auto inWidth = inElemType.getIntOrFloatBitWidth();
1230 auto outWidth = outElemType.getIntOrFloatBitWidth();
1231
1232 bool inputUnsigned = rescale.getInputUnsigned();
1233 bool outputUnsigned = rescale.getOutputUnsigned();
1234
1235 bool scale32 = rescale.getScale32();
1236 auto roundingMode = rescale.getRoundingMode();
1237
1238 // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1239 if (scale32 && inWidth == 48)
1240 return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1241
1242 // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1243 if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1244 return op->emitOpError()
1245 << "DOUBLE_ROUND is only allowed with scale32=true.";
1246
1247 // ERROR_IF(input_unsigned && output_unsigned)
1248 if (inputUnsigned && outputUnsigned)
1249 return op->emitOpError() << "input and output cannot be both unsigned.";
1250
1251 // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1252 if (outWidth == 32 && inputUnsigned)
1253 return op->emitOpError()
1254 << "i32 output type is not allowed with unsigned input.";
1255
1256 // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1257 if (inWidth == 32 && outputUnsigned)
1258 return op->emitOpError()
1259 << "i32 input type is not allowed with unsigned output.";
1260
1261 // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1262 if (inWidth == 48 && outputUnsigned)
1263 return op->emitOpError()
1264 << "i48 input type is not allowed with unsigned output.";
1265
1266 // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1267 if (inWidth == 48 && inputUnsigned)
1268 return op->emitOpError() << "i48 input type cannot be unsigned.";
1269
1270 // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1271 if (inWidth == 32 && inputUnsigned)
1272 return op->emitOpError() << "i32 input type cannot be unsigned.";
1273
1274 // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1275 if (outWidth == 32 && outputUnsigned)
1276 return op->emitOpError() << "i32 output type cannot be unsigned.";
1277
1278 return success();
1279}
1280
1281LogicalResult checkErrorIfPad(Operation *op) {
1282 auto pad = dyn_cast<tosa::PadOp>(op);
1283 if (!pad)
1284 return success();
1285
1286 DenseIntElementsAttr paddingAttr;
1287 if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1288 // Pad verifier will catch this
1289 return success();
1290
1291 for (const APInt &val : paddingAttr.getValues<APInt>()) {
1292 if (val.getSExtValue() < 0)
1293 return op->emitOpError() << "padding value must all be non-negative, got "
1294 << val.getSExtValue();
1295 }
1296
1297 return success();
1298}
1299
1300LogicalResult checkErrorIfReshape(Operation *op) {
1301 auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
1302 if (!reshapeOp)
1303 return success();
1304
1305 SmallVector<int64_t> shapeValues;
1306 if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
1307 shapeValues))
1308 return success();
1309
1310 if (llvm::is_contained(shapeValues, kInferableDimSize))
1311 return op->emitOpError("shape input contains inferable dimension (")
1313 << ") "
1314 "which does not conform to the TOSA specification";
1315
1316 return success();
1317}
1318
1319LogicalResult checkErrorIfSlice(Operation *op) {
1320 auto sliceOp = dyn_cast<tosa::SliceOp>(op);
1321 if (!sliceOp)
1322 return success();
1323
1324 SmallVector<int64_t> startValues;
1325 SmallVector<int64_t> sizeValues;
1326 const bool hasStartValues = tosa::getConstShapeValues(
1327 sliceOp.getStart().getDefiningOp(), startValues);
1328 const bool hasSizeValues =
1329 tosa::getConstShapeValues(sliceOp.getSize().getDefiningOp(), sizeValues);
1330
1331 if (hasStartValues && llvm::is_contained(startValues, kInferableDimSize))
1332 return op->emitOpError("start input contains inferable dimension (")
1334 << ") which does not conform to the TOSA specification";
1335 if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDimSize))
1336 return op->emitOpError("size input contains inferable dimension (")
1338 << ") which "
1339 "does not conform to the TOSA specification";
1340
1341 return success();
1342}
1343
1344static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1345 return llvm::all_of(op->getOperands(), [&](auto operand) {
1346 Region *operandRegion = operand.getParentRegion();
1347 return operandRegion && region->isAncestor(operandRegion);
1348 });
1349}
1350
1351static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1352 bool noLiveInValue = true;
1353 regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1354 if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1355 noLiveInValue = false;
1356 return WalkResult::interrupt();
1357 }
1358 return WalkResult::advance();
1359 });
1360 return noLiveInValue ? success() : failure();
1361}
1362
1363LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1364 StringRef regionName) {
1365 if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1366 return success();
1367 return op->emitOpError()
1368 << "is not conformant to the TOSA specification. It requires the '"
1369 << regionName << "' region is isolated from above.\n";
1370}
1371
1372LogicalResult checkErrorIfCondIf(Operation *op) {
1373 auto ifOp = dyn_cast<tosa::IfOp>(op);
1374 if (!ifOp)
1375 return success();
1376
1377 // Currently the dialect supports declaring cond_if operations that
1378 // have then/else regions that reference values from outside these
1379 // regions. According to the specification, all values used by the
1380 // then/else regions must be explicitly declared within the regions.
1381 // Therefore we must check that the then/else regions are
1382 // "isolated from above", in order to be conformant to the
1383 // specification.
1384 //
1385 // Note: the dialect currently supports two styles of syntax for
1386 // declaring "cond_if" operations. We'll refer to these as follows:
1387 //
1388 // Generic:
1389 // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1390 // ^bb0(%arg3, %arg4):
1391 // tosa.yield %arg3
1392 // }, {
1393 // ^bb0(%arg3, %arg4):
1394 // tosa.yield %arg4
1395 // })
1396 //
1397 // Simplified:
1398 // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1399 // ^bb0(%arg3, %arg4):
1400 // tosa.yield %arg3
1401 // } else {
1402 // ^bb0(%arg3, %arg4):
1403 // tosa.yield %arg4
1404 // }
1405
1406 if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1407 failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1408 return failure();
1409 return success();
1410}
1411
1412LogicalResult checkErrorIfWhileLoop(Operation *op) {
1413 auto whileOp = dyn_cast<tosa::WhileOp>(op);
1414 if (!whileOp)
1415 return success();
1416
1417 if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1418 failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1419 return failure();
1420 return success();
1421}
1422
1423LogicalResult checkErrorIfScatter(Operation *op) {
1424 auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1425 if (!scatterOp)
1426 return success();
1427
1428 // for constant indices, check that there are no duplicate values
1429 DenseIntElementsAttr indicesAttr;
1430 if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1431 return success();
1432
1433 auto const indicesType =
1434 dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1435 if (!indicesType || !indicesType.hasRank()) {
1436 op->emitOpError("expect ranked indices tensor");
1437 return failure();
1438 }
1439
1440 if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1441 op->emitOpError("indices values contain duplicates");
1442 return failure();
1443 }
1444
1445 return success();
1446}
1447
1448LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1449 if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1450 failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1451 failed(checkErrorIfPad(op)) || failed(checkErrorIfReshape(op)) ||
1452 failed(checkErrorIfSlice(op)) || failed(checkErrorIfCondIf(op)) ||
1453 failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1454 return failure();
1455 return success();
1456}
1457
1458LogicalResult TosaValidation::applyFunctionSignatureCheck(func::FuncOp op) {
1459 const auto isShapeType = [](Type type) { return isa<tosa::shapeType>(type); };
1460 if (llvm::any_of(op.getArgumentTypes(), isShapeType))
1461 return op.emitOpError()
1462 << "Function argument types must be a tensor type to be TOSA "
1463 "compliant, got !tosa.shape type";
1464 if (llvm::any_of(op.getResultTypes(), isShapeType))
1465 return op.emitOpError()
1466 << "Function return types must be a tensor type to be TOSA "
1467 "compliant, got !tosa.shape type";
1468 return success();
1469}
1470
1471bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1472 if (isa<FloatType>(type)) {
1473 return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1474 Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
1475 Float6E3M2FNType, Float8E8M0FNUType>(type);
1476 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1477 if (intTy.isSignless()) {
1478 switch (intTy.getWidth()) {
1479 case 1:
1480 case 4:
1481 case 8:
1482 case 16:
1483 case 32:
1484 case 48:
1485 case 64:
1486 return true;
1487 }
1488 } else if (allowUnsigned && intTy.isUnsigned()) {
1489 switch (intTy.getWidth()) {
1490 case 8:
1491 case 16:
1492 case 32:
1493 return true;
1494 }
1495 }
1496 } else if (isa<tosa::shapeType>(type))
1497 return true;
1498 else if (isa<tosa::mxint8Type>(type))
1499 return true;
1500 return false;
1501}
1502
1503void TosaValidation::runOnOperation() {
1504 ModuleOp modOp = getOperation();
1505 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1506 if (!tosaDialect)
1507 return;
1508
1509 const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
1510 const auto maybeTargetEnv =
1511 tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
1512 if (failed(maybeTargetEnv))
1513 return signalPassFailure();
1514 targetEnv = *maybeTargetEnv;
1515
1516 const auto functions = modOp.getOps<func::FuncOp>();
1517 if (llvm::any_of(functions, [&](func::FuncOp func) {
1518 return failed(applyFunctionSignatureCheck(func));
1519 }))
1520 return signalPassFailure();
1521
1522 modOp.walk([&](Operation *op) {
1523 if (op->getDialect() != tosaDialect)
1524 return;
1525
1526 // validate operator element types:
1527 // - rescale operator is allowed to have ui8/ui16/ui32
1528 // operands/results when strictOpSpecAlignment is false
1529 // - perform valid element type check at the beginning to
1530 // protect rest of code against quantized element types
1531 const bool allowUnsigned =
1532 !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1533 for (Value operand : op->getOperands()) {
1534 auto elementTy = getElementTypeOrSelf(operand);
1535 if (!isValidElementType(elementTy, allowUnsigned)) {
1536 op->emitOpError() << "is not profile-aligned: element type "
1537 << elementTy << " is not legal";
1538 return signalPassFailure();
1539 }
1540 }
1541 for (Type resultTy : op->getResultTypes()) {
1542 auto elementTy = getElementTypeOrSelf(resultTy);
1543 if (!isValidElementType(elementTy, allowUnsigned)) {
1544 op->emitOpError() << "is not profile-aligned: element type "
1545 << elementTy << " is not legal";
1546 return signalPassFailure();
1547 }
1548 }
1549
1550 if (strictOpSpecAlignment &&
1551 failed(profileComp.checkProfile(op, targetEnv)))
1552 return signalPassFailure();
1553
1554 if (strictOpSpecAlignment &&
1555 failed(profileComp.checkExtension(op, targetEnv)))
1556 return signalPassFailure();
1557
1558 if (!allowInvalidOpDatatypeCombinations &&
1559 failed(profileComp.checkInvalid(op)))
1560 return signalPassFailure();
1561
1562 // Some uses of TOSA rely on the constant operands of particular
1563 // operations.
1564 if (failed(applyConstantOperandCheck(op)))
1565 signalPassFailure();
1566
1567 // do level checks
1568 if (failed(applyLevelCheck(op)))
1569 signalPassFailure();
1570
1571 // check additional attribute restrictions
1572 if (failed(applyAttributeCheck(op)))
1573 signalPassFailure();
1574
1575 // do variable type checks
1576 if (failed(applyVariableCheck(op)))
1577 signalPassFailure();
1578
1579 // do error if checks
1580 if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1581 signalPassFailure();
1582 });
1583}
1584} // namespace
return success()
lhs
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:578
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
#define CHECK_SHAPE_LEN(tosaOp)
@ Gather
#define mul(a, b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:237
Value getOperand(unsigned idx)
Definition Operation.h:375
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
result_range getResults()
Definition Operation.h:440
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:296
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:119
TosaLevel getLevel() const
Definition TargetEnv.h:136
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
bool allows(Profile prof) const
Definition TargetEnv.h:139
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:132
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:69
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
RankedTensorType getVariableType(VariableOp variableOp)
static constexpr TosaLevel TOSA_LEVEL_NONE
Definition TargetEnv.h:45
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:106
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:630
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369