MLIR  22.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1 //===- TosaValidation.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Validate whether TOSA dialect input matches the specification for the
10 // given requirements.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
18 
19 #include <string>
20 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Pass/Pass.h"
30 #include "llvm/ADT/StringExtras.h"
31 
32 namespace mlir {
33 namespace tosa {
34 #define GEN_PASS_DEF_TOSAVALIDATION
35 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
36 } // namespace tosa
37 } // namespace mlir
38 
39 using namespace mlir;
40 using namespace mlir::tosa;
41 
42 namespace {
43 
44 static LogicalResult
45 checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
46  for (const auto index : operandIndices) {
47  Attribute attr;
48  if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
49  return op->emitOpError("expected compile time resolvable constant, but "
50  "got variable value for operand #")
51  << index;
52  }
53  }
54  return success();
55 }
56 
57 static LogicalResult checkConstantOperandMul(Operation *op,
58  const TargetEnv &env) {
59  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60  // Check 'shift'
61  return checkConstantOperands(op, {2});
62  }
63  return success();
64 }
65 
66 static LogicalResult checkConstantOperandTable(Operation *op,
67  const TargetEnv &env) {
68  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69  // Check 'table'
70  return checkConstantOperands(op, {1});
71  }
72  return success();
73 }
74 
75 static LogicalResult checkConstantOperandPad(Operation *op,
76  const TargetEnv &env) {
77  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
78  // Assume this op is zero-padding if padConst is not presented
79  if (!env.allows(Extension::dynamic) && padOp.getPadConst())
80  // Check 'pad_const'
81  // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82  return checkConstantOperands(op, {2});
83  }
84  return success();
85 }
86 
87 static LogicalResult checkConstantOperandRescale(Operation *op,
88  const TargetEnv &env) {
89  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90  // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91  return checkConstantOperands(op, {1, 2, 3, 4});
92  }
93  return success();
94 }
95 
96 template <typename T>
97 static LogicalResult checkConstantOperandConvOps(Operation *op,
98  const TargetEnv &env) {
99  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
100  // Check 'input_zp' and 'weight_zp'
101  return checkConstantOperands(op, {3, 4});
102  }
103  return success();
104 }
105 
106 static LogicalResult checkConstantOperandMatMul(Operation *op,
107  const TargetEnv &env) {
108  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109  // Check 'A_zp' and 'B_zp'
110  return checkConstantOperands(op, {2, 3});
111  }
112  return success();
113 }
114 
115 static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
116  const TargetEnv &env) {
117  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118  // Check 'input_zp' and 'output_zp'
119  return checkConstantOperands(op, {1, 2});
120  }
121  return success();
122 }
123 
124 static LogicalResult checkConstantOperandNegate(Operation *op,
125  const TargetEnv &env) {
126  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127  // Check 'input1_zp' and 'output_zp'
128  return checkConstantOperands(op, {1, 2});
129  }
130  return success();
131 }
132 
/// Numeric limits that make up a TOSA level. Each field mirrors the
/// identically named limit in the TOSA specification's level table.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;
  int32_t MAX_LOG2_SIZE = 0;
  int32_t MAX_NESTING = 0;
  int32_t MAX_TENSOR_LIST_SIZE = 0;

  // The comparisons are `const` (and `constexpr`) so that const levels such
  // as the `static constexpr` instances below can appear on either side of
  // `==`. A non-const member operator== also provokes C++20 ambiguity
  // diagnostics against the synthesized reversed candidate.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }

  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};

// Initializer field order: {MAX_RANK, MAX_KERNEL, MAX_STRIDE, MAX_SCALE,
//                           MAX_LOG2_SIZE, MAX_NESTING, MAX_TENSOR_LIST_SIZE}.

// Limits for the 8K level (selected when the pass option is
// TosaLevelEnum::EightK).
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
// Limits used when no level is requested. applyLevelCheck skips all level
// checks when the configured level equals this value.
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
                                              63, 256, 256};
154 
155 //===----------------------------------------------------------------------===//
156 // TOSA Validation Pass.
157 //===----------------------------------------------------------------------===//
158 
159 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
160 public:
161  explicit TosaValidation() { populateConstantOperandChecks(); }
162 
163  explicit TosaValidation(const TosaValidationOptions &options)
164  : TosaValidation() {
165  this->profile = options.profile;
166  this->extension = options.extension;
167  this->strictOpSpecAlignment = options.strictOpSpecAlignment;
168  this->allowInvalidOpDatatypeCombinations =
169  options.allowInvalidOpDatatypeCombinations;
170  this->level = options.level;
171  }
172  void runOnOperation() final;
173 
174  LogicalResult applyConstantOperandCheck(Operation *op) {
175  for (auto &checker : constCheckers) {
176  if (failed(checker(op, targetEnv)))
177  return failure();
178  }
179  return success();
180  }
181 
182  LogicalResult applyLevelCheck(Operation *op);
183  LogicalResult applyAttributeCheck(Operation *op);
184 
185  // check variable read/write data types against variable declarations
186  LogicalResult applyVariableCheck(Operation *op);
187 
188  // check error if conditions
189  LogicalResult applyErrorIfCheck(Operation *op);
190 
191 private:
192  void populateConstantOperandChecks() {
193  constCheckers.emplace_back(checkConstantOperandMul);
194  constCheckers.emplace_back(checkConstantOperandTable);
195  constCheckers.emplace_back(checkConstantOperandPad);
196  constCheckers.emplace_back(checkConstantOperandRescale);
197  constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
198  constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
199  constCheckers.emplace_back(
200  checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
201  constCheckers.emplace_back(
202  checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
203  constCheckers.emplace_back(checkConstantOperandMatMul);
204  constCheckers.emplace_back(checkConstantOperandAvgPool2d);
205  constCheckers.emplace_back(checkConstantOperandNegate);
206  }
207 
208  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
209  if (v > tosaLevel.MAX_KERNEL) {
210  op->emitOpError() << "failed level check: " << checkDesc;
211  return false;
212  }
213  return true;
214  }
215 
216  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
217  if (v > tosaLevel.MAX_STRIDE) {
218  op->emitOpError() << "failed level check: " << checkDesc;
219  return false;
220  }
221  return true;
222  }
223 
224  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
225  if (v > tosaLevel.MAX_SCALE) {
226  op->emitOpError() << "failed level check: " << checkDesc;
227  return false;
228  }
229  return true;
230  }
231 
232  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
233  if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
234  op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
235  << checkDesc;
236  return false;
237  }
238  return true;
239  }
240 
241  // Perform the Level Rank check on the tensor type.
242  bool levelCheckRank(Operation *op, const Type typeToCheck,
243  const StringRef operandOrResult, int32_t highest_rank) {
244  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
245  if (!type.hasRank()) {
246  op->emitOpError() << "failed level check: unranked tensor";
247  return false;
248  }
249  if (type.getRank() > highest_rank) {
250  op->emitOpError() << "failed level check: " << operandOrResult
251  << " rank(shape) <= MAX_RANK";
252  return false;
253  }
254  }
255  return true;
256  }
257 
258  // Perform the Level Rank check on the tensor value.
259  bool levelCheckRank(Operation *op, const Value &v,
260  const StringRef operandOrResult, int32_t highest_rank) {
261  return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
262  }
263 
264  // Perform the Level tensor size check on the tensor type.
265  bool levelCheckSize(Operation *op, const Type &typeToCheck,
266  const StringRef operandOrResult);
267 
268  // Perform the Level tensor size check on the tensor value.
269  bool levelCheckSize(Operation *op, const Value &v,
270  const StringRef operandOrResult) {
271  return levelCheckSize(op, v.getType(), operandOrResult);
272  }
273 
274  // Level check sizes of all operands and results of the operation.
275  template <typename T>
276  bool levelCheckSizes(T tosaOp) {
277  auto op = tosaOp.getOperation();
278  for (auto v : op->getOperands()) {
279  if (!levelCheckSize(op, v, "operand"))
280  return false;
281  }
282 
283  for (auto v : op->getResults()) {
284  if (!levelCheckSize(op, v, "result"))
285  return false;
286  }
287  return true;
288  }
289 
290  // Level check ranks of all operands, attribute and results of the operation.
291  template <typename T>
292  bool levelCheckRanks(T tosaOp) {
293  auto op = tosaOp.getOperation();
294  for (auto v : op->getOperands()) {
295  if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
296  return false;
297  }
298 
299  for (auto v : op->getResults()) {
300  if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
301  return false;
302  }
303  return true;
304  }
305 
306  // Level check ranks and sizes.
307  bool levelCheckRanksAndSizes(Operation *op);
308 
309  // Pool Op: level check kernel/stride/pad values
310  template <typename T>
311  bool levelCheckPool(Operation *op) {
312  if (auto poolOp = dyn_cast<T>(op)) {
313  for (auto k : poolOp.getKernel()) {
314  if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
315  return false;
316  }
317  }
318  for (auto s : poolOp.getStride()) {
319  if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
320  return false;
321  }
322  }
323  for (auto p : poolOp.getPad()) {
324  if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
325  return false;
326  }
327  }
328  }
329  return true;
330  }
331 
332  // Conv Op: level check dilation/stride/pad values
333  template <typename T>
334  bool levelCheckConv(Operation *op) {
335  if (auto convOp = dyn_cast<T>(op)) {
336 
337  for (auto k : convOp.getDilation()) {
338  if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
339  return false;
340  }
341  }
342  for (auto p : convOp.getPad()) {
343  if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
344  return false;
345  }
346  }
347  for (auto s : convOp.getStride()) {
348  if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
349  return false;
350  }
351  }
352  auto dilation = convOp.getDilation();
353  if (ShapedType weightType =
354  dyn_cast<ShapedType>(op->getOperand(1).getType())) {
355  auto shape = weightType.getShape();
356  if (isa<tosa::Conv2DOp>(op)) {
357  assert(shape.size() == 4);
358  assert(dilation.size() == 2);
359  if (!levelCheckKernel(op, dilation[0] * shape[1],
360  "dilation_y * KH <= MAX_KERNEL)") ||
361  !levelCheckKernel(op, dilation[1] * shape[2],
362  "dilation_x * KW <= MAX_KERNEL)"))
363  return false;
364  } else if (isa<tosa::Conv3DOp>(op)) {
365  assert(shape.size() == 5);
366  assert(dilation.size() == 3);
367  if (!levelCheckKernel(op, dilation[0] * shape[1],
368  "dilation_d * KD <= MAX_KERNEL)") ||
369  !levelCheckKernel(op, dilation[1] * shape[2],
370  "dilation_y * KH <= MAX_KERNEL)") ||
371  !levelCheckKernel(op, dilation[2] * shape[3],
372  "dilation_x * KW <= MAX_KERNEL)"))
373  return false;
374  } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
375  assert(shape.size() == 4);
376  assert(dilation.size() == 2);
377  if (!levelCheckKernel(op, dilation[0] * shape[0],
378  "dilation_y * KH <= MAX_KERNEL)") ||
379  !levelCheckKernel(op, dilation[1] * shape[1],
380  "dilation_x * KW <= MAX_KERNEL)"))
381  return false;
382  }
383  }
384  }
385  return true;
386  }
387 
388  // FFT op: level check H, W in input shape [N,H,W]
389  template <typename T>
390  bool levelCheckFFT(Operation *op) {
391  if (isa<T>(op)) {
392  for (auto v : op->getOperands()) {
393  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
394  auto shape = type.getShape();
395  assert(shape.size() == 3);
396  if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
397  !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
398  return false;
399  }
400  }
401  }
402  }
403  return true;
404  }
405 
406  // TransposeConv2d op: level check kH/kW, outpad, and stride
407  bool levelCheckTransposeConv2d(Operation *op) {
408  if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
409  if (ShapedType filterType =
410  dyn_cast<ShapedType>(transpose.getWeight().getType())) {
411  auto shape = filterType.getShape();
412  assert(shape.size() == 4);
413  // level check kernel sizes for kH and KW
414  if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
415  !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
416  return false;
417  }
418  }
419  for (auto p : transpose.getOutPad()) {
420  if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
421  return false;
422  }
423  }
424  for (auto s : transpose.getStride()) {
425  if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
426  return false;
427  }
428  }
429  }
430  return true;
431  }
432 
433  // Resize op: level check max scales
434  bool levelCheckResize(Operation *op) {
435  if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
436  SmallVector<int64_t> scale;
437  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
438  scale)) {
439  return false;
440  }
441  const int64_t scaleYN = scale[0];
442  const int64_t scaleYD = scale[1];
443  const int64_t scaleXN = scale[2];
444  const int64_t scaleXD = scale[3];
445  if (!levelCheckScale(op, scaleYN / scaleYD,
446  "scale_y_n/scale_y_d <= MAX_SCALE") ||
447  !levelCheckScale(op, scaleXN / scaleXD,
448  "scale_x_n/scale_x_d <= MAX_SCALE")) {
449  return false;
450  }
451  }
452  return true;
453  }
454 
455  // Recursively perform a bottom-up search to determine the maximum nesting
456  // depth, starting from a specific operation and continuing up to the function
457  // or module scope. Tosa nesting_depth starts at 0 and increments by one each
458  // time a new nested `region` is encountered.
459  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
460  if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
461  return;
462 
463  op = op->getParentOp();
464  if (!op)
465  return;
466 
467  depth++;
468  getMaxNestedDepth(op, depth);
469  }
470 
471  bool levelCheckMaxNesting(Operation *op) {
472  int32_t maxNestedDepth = 0;
473  getMaxNestedDepth(op, maxNestedDepth);
474 
475  if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
476  op->emitOpError() << "failed level check: " << maxNestedDepth
477  << " >= MAX_NESTING";
478  return false;
479  }
480  return true;
481  }
482 
483  bool levelCheckListSize(Operation *op) {
484  if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
485  return levelCheckListSize(op, concat.getInput1().size(), "input1");
486  }
487  if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
488  if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
489  !levelCheckListSize(op, custom.getOutputList().size(),
490  "output_list")) {
491  return false;
492  }
493  }
494  if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
495  if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
496  !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
497  return false;
498  }
499  }
500  if (auto w = dyn_cast<tosa::WhileOp>(op)) {
501  if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
502  !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
503  return false;
504  }
505  }
506  return true;
507  }
508 
509  bool attributeCheckRescale(Operation *op) {
510  if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
511  if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
512  !targetEnv.allows(Extension::doubleround)) {
513  op->emitOpError()
514  << "failed attribute check: rounding_mode = DOUBLE_ROUND "
515  << "requires extension [doubleround]";
516  return false;
517  }
518  if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
519  !targetEnv.allows(Extension::inexactround)) {
520  op->emitOpError()
521  << "failed attribute check: rounding_mode = INEXACT_ROUND "
522  << "requires extension [inexactround]";
523  return false;
524  }
525  }
526  return true;
527  }
528 
529  // configure profile and level values from pass options profileName and
530  // levelName
531  void configLevelAndProfile() {
532  tosaLevel = TOSA_LEVEL_NONE;
533  if (level == TosaLevelEnum::EightK) {
534  tosaLevel = TOSA_LEVEL_EIGHTK;
535  }
536 
537  if (!profile.empty()) {
538  for (std::string &prof : profile) {
539  auto profSymbol = symbolizeProfile(prof);
540  if (profSymbol) {
541  targetEnv.addProfile(profSymbol.value());
542  } else {
543  llvm::errs() << "unknown TOSA profile name passed in: " << prof
544  << ", supported profiles are `pro_int` and `pro_fp`\n";
545  return signalPassFailure();
546  }
547  }
548  }
549 
550  if (!extension.empty()) {
551  for (std::string &ext : extension) {
552  auto extSymbol = symbolizeExtension(ext);
553  if (extSymbol) {
554  targetEnv.addExtension(extSymbol.value());
555  } else {
556  llvm::errs() << "unknown TOSA extension name passed in: " << ext
557  << ", supported extension are int16, int4, bf16, "
558  << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
559  << "doubleround, inexactround and dynamic\n";
560  return signalPassFailure();
561  }
562  }
563  }
564  }
565 
566  bool CheckVariable(Operation *op);
567  bool CheckVariableReadOrWrite(Operation *op);
568  bool isValidElementType(Type type, const bool allowUnsigned = false);
569 
570  SmallVector<
571  std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
572  constCheckers;
573  TosaLevel tosaLevel;
575  TosaProfileCompliance profileComp;
576  tosa::TargetEnv targetEnv;
577 };
578 
579 template <>
580 bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
581  auto *op = tosaOp.getOperation();
582  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
583  return false;
584 
585  // rank(output) = rank(input) - 1
586  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
587  return false;
588 
589  return true;
590 }
591 
592 template <>
593 bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
594  auto *op = tosaOp.getOperation();
595 
596  // Only the condition input has rank limitation.
597  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
598  return false;
599 
600  return true;
601 }
602 
603 template <>
604 bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
605  auto *op = tosaOp.getOperation();
606  auto variableType = getVariableType(tosaOp);
607  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
608  return false;
609 
610  return true;
611 }
612 
613 template <>
614 bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
615  auto *op = tosaOp.getOperation();
616  auto variableType = getVariableType(tosaOp);
617  if (!levelCheckSize(op, variableType, "variable type"))
618  return false;
619 
620  return true;
621 }
622 
623 bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
624 #define CHECK_RANKS_AND_SIZES(tosaOp) \
625  if (isa<tosa::tosaOp##Op>(op)) { \
626  if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \
627  return false; \
628  if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
629  return false; \
630  }
631 
632 #define CHECK_SIZES(tosaOp) \
633  if (isa<tosa::tosaOp##Op>(op)) { \
634  if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
635  return false; \
636  }
637 
638  // Tensor Operators
639  CHECK_RANKS_AND_SIZES(ArgMax);
640  // Activation Functions
641  CHECK_RANKS_AND_SIZES(Clamp);
643  CHECK_RANKS_AND_SIZES(Sigmoid);
644  CHECK_RANKS_AND_SIZES(Tanh);
645  // Elementwise Binary Operators
647  CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
648  CHECK_RANKS_AND_SIZES(BitwiseAnd);
649  CHECK_RANKS_AND_SIZES(BitwiseOr);
650  CHECK_RANKS_AND_SIZES(BitwiseXor);
651  CHECK_RANKS_AND_SIZES(IntDiv);
652  CHECK_RANKS_AND_SIZES(LogicalAnd);
653  CHECK_RANKS_AND_SIZES(LogicalLeftShift);
654  CHECK_RANKS_AND_SIZES(LogicalRightShift);
655  CHECK_RANKS_AND_SIZES(LogicalOr);
656  CHECK_RANKS_AND_SIZES(LogicalXor);
657  CHECK_RANKS_AND_SIZES(Maximum);
658  CHECK_RANKS_AND_SIZES(Minimum);
662  CHECK_RANKS_AND_SIZES(Table);
663  // Elementwise Unary Operators
665  CHECK_RANKS_AND_SIZES(BitwiseNot);
666  CHECK_RANKS_AND_SIZES(Ceil);
670  CHECK_RANKS_AND_SIZES(Floor);
672  CHECK_RANKS_AND_SIZES(LogicalNot);
673  CHECK_RANKS_AND_SIZES(Negate);
674  CHECK_RANKS_AND_SIZES(Reciprocal);
675  CHECK_RANKS_AND_SIZES(Rsqrt);
677  // Elementwise Ternary Operators
678  CHECK_RANKS_AND_SIZES(Select);
679  // Comparison Operators
680  CHECK_RANKS_AND_SIZES(Equal);
681  CHECK_RANKS_AND_SIZES(Greater);
682  CHECK_RANKS_AND_SIZES(GreaterEqual);
683  // Reduction Operators
684  CHECK_RANKS_AND_SIZES(ReduceAll);
685  CHECK_RANKS_AND_SIZES(ReduceAny);
686  CHECK_RANKS_AND_SIZES(ReduceMax);
687  CHECK_RANKS_AND_SIZES(ReduceMin);
688  CHECK_RANKS_AND_SIZES(ReduceProduct);
689  CHECK_RANKS_AND_SIZES(ReduceSum);
690  // Data Layout Operators
691  CHECK_RANKS_AND_SIZES(Concat);
693  CHECK_RANKS_AND_SIZES(Reshape);
694  CHECK_RANKS_AND_SIZES(Reverse);
695  CHECK_RANKS_AND_SIZES(Slice);
696  CHECK_RANKS_AND_SIZES(Tile);
697  CHECK_RANKS_AND_SIZES(Transpose);
698  // Type Conversion
699  CHECK_RANKS_AND_SIZES(Cast);
700  CHECK_RANKS_AND_SIZES(Rescale);
701  // Control Flow Operators
703  // Variable Operators
704  CHECK_RANKS_AND_SIZES(Variable);
705  CHECK_RANKS_AND_SIZES(VariableWrite);
706  CHECK_RANKS_AND_SIZES(VariableRead);
707  // Data Nodes
708  CHECK_RANKS_AND_SIZES(Const);
709  CHECK_RANKS_AND_SIZES(Identity);
710 
711  // For the following operators, check whether the size of each tensor
712  // operand is valid in a given Level.
713 
714  // Tensor Operators
715  CHECK_SIZES(AvgPool2d);
716  CHECK_SIZES(Conv2D);
717  CHECK_SIZES(Conv3D);
718  CHECK_SIZES(DepthwiseConv2D);
719  CHECK_SIZES(TransposeConv2D);
720  CHECK_SIZES(FFT2d);
721  CHECK_SIZES(MatMul);
722  CHECK_SIZES(MaxPool2d);
723  CHECK_SIZES(RFFT2d);
724  // Scatter/Gather Operators
726  CHECK_SIZES(Scatter);
727  // Image Operators
728  CHECK_SIZES(Resize);
729  // Custom Operators
730  CHECK_SIZES(Custom);
731  // Control Flow Operators
732  CHECK_SIZES(While);
733  // Shape Operators
734  CHECK_SIZES(ConstShape);
735 
736 #undef CHECK_RANKS_AND_SIZES
737 #undef CHECK_SIZES
738  return true;
739 }
740 
741 // Perform the Level tensor size check on the tensor type.
742 bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
743  const StringRef operandOrResult) {
744  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
745  if (!type.hasRank()) {
746  op->emitOpError() << "failed level check: unranked tensor";
747  return false;
748  }
749  auto shape = type.getShape();
750  for (auto dim : shape) {
751  if (mlir::ShapedType::isDynamic(dim)) {
752  op->emitOpError() << "failed level check: " << operandOrResult
753  << " shape dimension cannot be dynamic";
754  return false;
755  }
756  }
757 
758  int64_t element_bits = type.getElementTypeBitWidth();
759  int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
760  int64_t size = element_bytes * type.getNumElements();
761 
762  // According to 1.11. Tensor Definitions of Tosa spec, the value of
763  // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
764  // defined in 1.7. Levels.
765  // For each tensor, the number of tensor elements multiplied by the
766  // element size in bytes must be representable as a tensor_size_t.
767  const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
768  if (size > max_size) {
769  op->emitOpError()
770  << "failed level check: " << operandOrResult
771  << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
772  return false;
773  }
774  }
775  return true;
776 }
777 
778 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
779  if (tosaLevel == TOSA_LEVEL_NONE) {
780  // no need to do level checks
781  return success();
782  }
783 
784  // check rank and sizes early so later checks can assume shaped operands
785  if (!levelCheckRanksAndSizes(op))
786  return failure();
787 
788  // additional level checks from spec 0.70
789  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
790  !levelCheckConv<tosa::Conv2DOp>(op) ||
791  !levelCheckConv<tosa::Conv3DOp>(op) ||
792  !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
793  !levelCheckFFT<tosa::FFT2dOp>(op) ||
794  !levelCheckPool<tosa::MaxPool2dOp>(op) ||
795  !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
796  !levelCheckResize(op)) {
797  return failure();
798  }
799 
800  // level check MAX_TENSOR_LIST_SIZE
801  if (!levelCheckListSize(op)) {
802  return failure();
803  }
804 
805  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
806  if (!levelCheckMaxNesting(op)) {
807  return failure();
808  }
809  }
810 
811  return success();
812 }
813 
814 LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
815  if (!attributeCheckRescale(op))
816  return failure();
817  return success();
818 }
819 
820 inline bool CompatibleTypes(const mlir::Type &type,
821  const mlir::Type &declaredType) {
822  // for now, simply use type equality comparison
823  return type == declaredType;
824 }
825 
826 bool TosaValidation::CheckVariable(Operation *op) {
827  if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
828  mlir::StringAttr nameAttr = variableOp.getNameAttr();
829 
830  if (variablesMap.count(nameAttr)) {
831  op->emitOpError() << "name has already been declared";
832  return false;
833  }
834 
835  auto elementType = variableOp.getType();
836  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
837  SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
838  RankedTensorType variableType =
839  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
840 
841  variablesMap[nameAttr] = variableType;
842  }
843 
844  return true;
845 }
846 
847 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
848  if (isa<mlir::tosa::VariableReadOp>(op) ||
849  isa<mlir::tosa::VariableWriteOp>(op)) {
850  mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
851  if (!variablesMap.count(nameAttr)) {
852  op->emitOpError() << "name has not been declared";
853  return false;
854  }
855 
856  auto varType = variablesMap[nameAttr];
857 
858  for (auto v : op->getOperands()) {
859  auto type = v.getType();
860  if (!CompatibleTypes(type, varType)) {
861  op->emitOpError() << "operand type does not equal variable type";
862  return false;
863  }
864  }
865 
866  for (auto v : op->getResults()) {
867  auto type = v.getType();
868  if (!CompatibleTypes(type, varType)) {
869  op->emitOpError() << "result type does not equal variable type";
870  return false;
871  }
872  }
873  }
874 
875  return true;
876 }
877 
878 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
879  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
880  return failure();
881  }
882  return success();
883 }
884 
885 bool checkErrorIfResize(Operation *op) {
886  auto resize = dyn_cast<tosa::ResizeOp>(op);
887  if (!resize)
888  return true;
889 
890  const Value input = resize.getInput();
891  const Value output = resize.getOutput();
892  const RankedTensorType inputType =
893  llvm::dyn_cast<RankedTensorType>(input.getType());
894  const RankedTensorType outputType =
895  llvm::dyn_cast<RankedTensorType>(output.getType());
896 
897  if (!inputType || !outputType) {
898  op->emitOpError("expect ranked input/output tensor");
899  return false;
900  }
901 
902  // Ensure the image size is supported by GPU APIs and that for integer
903  // implementations, position * stride does not overflow int32_t.
904  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
905  const SmallVector<int64_t, 4> sizes = {
906  outputType.getDimSize(1), outputType.getDimSize(2),
907  inputType.getDimSize(1), inputType.getDimSize(2)};
908  const int64_t *maxDim = llvm::max_element(sizes);
909  if (maxDim != sizes.end() && *maxDim >= 16384) {
910  op->emitOpError("expect input/output height/width dims to be < 16384, ")
911  << "got [OH, OW, IH, IW] = " << sizes;
912  return false;
913  }
914  }
915 
916  SmallVector<int64_t> scale;
917  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
918  return false;
919  }
920 
921  const int64_t scaleYN = scale[0];
922  const int64_t scaleYD = scale[1];
923  const int64_t scaleXN = scale[2];
924  const int64_t scaleXD = scale[3];
925 
926  // Ensure scale values don't overflow int32 accumulator
927  if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
928  op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
929  "got scale_y_n=")
930  << scaleYN << ", scale_x_n=" << scaleXN;
931  return false;
932  }
933 
934  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
935  op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
936  << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
937  return false;
938  }
939 
940  SmallVector<int64_t> offset;
941  SmallVector<int64_t> border;
942  if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
943  !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
944  return false;
945  }
946 
947  const int64_t offsetY = offset[0];
948  const int64_t offsetX = offset[1];
949  // Set a consistent lower limit of 1/16 downscale to simplify
950  // implementations
951  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
952  op->emitOpError(
953  "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
954  << offsetY << "/" << scaleYN;
955  return false;
956  }
957  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
958  op->emitOpError(
959  "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
960  << offsetX << "/" << scaleXN;
961  return false;
962  }
963 
964  const int64_t borderY = border[0];
965  const int64_t borderX = border[1];
966  if (borderY < -16 * scaleYN || borderY >= scaleYN) {
967  op->emitOpError(
968  "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
969  << borderY << "/" << scaleYN;
970  return false;
971  }
972  if (borderX < -16 * scaleXN || borderX >= scaleXN) {
973  op->emitOpError(
974  "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
975  << borderX << "/" << scaleXN;
976  return false;
977  }
978 
979  // The following section of code is mostly duplicated with ResizeOp::verify().
980  //
981  // In TOSA specification, we do not support broadcast behavior.
982  // However, there is a rewrite pattern to materialize broadcast ResizeOp.
983  // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
984  // existing code, we keep the rewrite pattern untouched. So, we need
985  // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
986  //
987  // Here is a strict checking to conform TOSA specification.
988  // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
989  auto idivCheck = [](const int64_t lhs,
990  const int64_t rhs) -> std::optional<int64_t> {
991  if (lhs % rhs != 0)
992  return std::nullopt;
993  return lhs / rhs;
994  };
995 
996  const int64_t oh = outputType.getDimSize(1);
997  const int64_t ow = outputType.getDimSize(2);
998  const int64_t ih = inputType.getDimSize(1);
999  const int64_t iw = inputType.getDimSize(2);
1000 
1001  if (ih != ShapedType::kDynamic) {
1002  const std::optional<int64_t> calculatedOutHeightMinusOne =
1003  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1004  if (!calculatedOutHeightMinusOne.has_value()) {
1005  op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
1006  "border_y ")
1007  << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
1008  << scaleYN << " - " << offsetY << " + " << borderY << ") / "
1009  << scaleYD;
1010  return false;
1011  }
1012  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1013  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
1014  op->emitOpError("calculated output height did not match expected: ")
1015  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1016  return false;
1017  }
1018  }
1019 
1020  if (iw != ShapedType::kDynamic) {
1021  const std::optional<int64_t> calculatedOutWidthMinusOne =
1022  idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1023  if (!calculatedOutWidthMinusOne.has_value()) {
1024  op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
1025  "border_x ")
1026  << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
1027  << scaleXN << " - " << offsetX << " + " << borderX << ") / "
1028  << scaleXD;
1029  return false;
1030  }
1031  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1032  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
1033  op->emitOpError("calculated output width did not match expected: ")
1034  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1035  return false;
1036  }
1037  }
1038 
1039  return true;
1040 }
1041 
1042 bool checkErrorIfMul(Operation *op) {
1043  auto mul = dyn_cast<tosa::MulOp>(op);
1044  if (!mul)
1045  return true;
1046 
1047  // REQUIRE(0 <= shift && shift <= 63);
1048  // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1049  ElementsAttr shift_elem;
1050  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
1051  return true;
1052  }
1053  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1054  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1055  if (inputElemType.isInteger(32)) {
1056  // 0 <= shift <= 63 for int32_t type
1057  if (shift < 0 || shift > 63) {
1058  op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
1059  << shift;
1060  return false;
1061  }
1062  } else {
1063  // shift must be 0 for all other types
1064  if (shift != 0) {
1065  op->emitOpError() << "requires shift = 0 for all input data types that "
1066  "are not int32_t, but got: "
1067  << shift;
1068  return false;
1069  }
1070  }
1071 
1072  return true;
1073 }
1074 
1075 bool checkErrorIfTable(Operation *op) {
1076  auto table = dyn_cast<tosa::TableOp>(op);
1077  if (!table)
1078  return true;
1079 
1080  // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1081  const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1082  const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1083 
1084  const ShapeAdaptor tableShape(table.getTable().getType());
1085  if (tableShape.hasStaticShape()) {
1086  const auto numElements = tableShape.getNumElements();
1087  if (numElements != tableSize) {
1088  op->emitOpError() << "requires table size of " << tableSize << ", got "
1089  << numElements;
1090  return false;
1091  }
1092  }
1093 
1094  return true;
1095 }
1096 
1097 bool checkErrorIfRescale(Operation *op) {
1098  auto rescale = dyn_cast<tosa::RescaleOp>(op);
1099  if (!rescale)
1100  return true;
1101 
1102  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1103  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1104  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1105  !outputType.getElementType().isInteger())
1106  return true;
1107 
1108  auto inElemType = inputType.getElementType();
1109  auto outElemType = outputType.getElementType();
1110  auto inWidth = inElemType.getIntOrFloatBitWidth();
1111  auto outWidth = outElemType.getIntOrFloatBitWidth();
1112 
1113  bool inputUnsigned = rescale.getInputUnsigned();
1114  bool outputUnsigned = rescale.getOutputUnsigned();
1115 
1116  bool scale32 = rescale.getScale32();
1117  auto roundingMode = rescale.getRoundingMode();
1118 
1119  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1120  if (scale32 && inWidth == 48) {
1121  op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1122  return false;
1123  }
1124 
1125  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1126  if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) {
1127  op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
1128  return false;
1129  }
1130 
1131  // ERROR_IF(input_unsigned && output_unsigned)
1132  if (inputUnsigned && outputUnsigned) {
1133  op->emitOpError() << "input and output cannot be both unsigned.";
1134  return false;
1135  }
1136 
1137  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1138  if (outWidth == 32 && inputUnsigned) {
1139  op->emitOpError() << "i32 output type is not allowed with unsigned input.";
1140  return false;
1141  }
1142 
1143  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1144  if (inWidth == 32 && outputUnsigned) {
1145  op->emitOpError() << "i32 input type is not allowed with unsigned output.";
1146  return false;
1147  }
1148 
1149  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1150  if (inWidth == 48 && outputUnsigned) {
1151  op->emitOpError() << "i48 input type is not allowed with unsigned output.";
1152  return false;
1153  }
1154 
1155  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1156  if (inWidth == 48 && inputUnsigned) {
1157  op->emitOpError() << "i48 input type cannot be unsigned.";
1158  return false;
1159  }
1160 
1161  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1162  if (inWidth == 32 && inputUnsigned) {
1163  op->emitOpError() << "i32 input type cannot be unsigned.";
1164  return false;
1165  }
1166 
1167  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1168  if (outWidth == 32 && outputUnsigned) {
1169  op->emitOpError() << "i32 output type cannot be unsigned.";
1170  return false;
1171  }
1172 
1173  return true;
1174 }
1175 
1176 bool checkErrorIfPad(Operation *op) {
1177  auto pad = dyn_cast<tosa::PadOp>(op);
1178  if (!pad)
1179  return true;
1180 
1181  DenseIntElementsAttr paddingAttr;
1182  if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1183  // Pad verifier will catch this
1184  return true;
1185 
1186  for (const APInt &val : paddingAttr.getValues<APInt>()) {
1187  if (val.getSExtValue() < 0) {
1188  op->emitOpError() << "padding value must all be non-negative, got "
1189  << val.getSExtValue();
1190  return false;
1191  }
1192  }
1193 
1194  return true;
1195 }
1196 
1197 static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1198  return llvm::all_of(op->getOperands(), [&](auto operand) {
1199  Region *operandRegion = operand.getParentRegion();
1200  return operandRegion && region->isAncestor(operandRegion);
1201  });
1202 }
1203 
1204 static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
1205  bool noLiveInValue = true;
1206  regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1207  if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1208  noLiveInValue = false;
1209  return WalkResult::interrupt();
1210  }
1211  return WalkResult::advance();
1212  });
1213  return noLiveInValue;
1214 }
1215 
1216 LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1217  StringRef regionName) {
1218  if (isRegionIsolatedFromAbove(regionToCheck))
1219  return success();
1220  op->emitOpError()
1221  << "is not conformant to the TOSA specification. It requires the '"
1222  << regionName << "' region is isolated from above.\n";
1223  return failure();
1224 }
1225 
1226 bool checkErrorIfCondIf(Operation *op) {
1227  auto ifOp = dyn_cast<tosa::IfOp>(op);
1228  if (!ifOp)
1229  return true;
1230 
1231  // Currently the dialect supports declaring cond_if operations that
1232  // have then/else regions that reference values from outside these
1233  // regions. According to the specification, all values used by the
1234  // then/else regions must be explicitly declared within the regions.
1235  // Therefore we must check that the then/else regions are
1236  // "isolated from above", in order to be conformant to the
1237  // specification.
1238  //
1239  // Note: the dialect currently supports two styles of syntax for
1240  // declaring "cond_if" operations. We'll refer to these as follows:
1241  //
1242  // Generic:
1243  // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1244  // ^bb0(%arg3, %arg4):
1245  // tosa.yield %arg3
1246  // }, {
1247  // ^bb0(%arg3, %arg4):
1248  // tosa.yield %arg4
1249  // })
1250  //
1251  // Simplified:
1252  // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1253  // ^bb0(%arg3, %arg4):
1254  // tosa.yield %arg3
1255  // } else {
1256  // ^bb0(%arg3, %arg4):
1257  // tosa.yield %arg4
1258  // }
1259 
1260  return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1261  failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
1262 }
1263 
1264 bool checkErrorIfWhileLoop(Operation *op) {
1265  auto whileOp = dyn_cast<tosa::WhileOp>(op);
1266  if (!whileOp)
1267  return true;
1268 
1269  return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1270  failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
1271 }
1272 
1273 bool checkErrorIfScatter(Operation *op) {
1274  auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1275  if (!scatterOp)
1276  return true;
1277 
1278  // for constant indices, check that there are no duplicate values
1279  DenseIntElementsAttr indicesAttr;
1280  if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1281  return true;
1282 
1283  auto const indicesType =
1284  dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1285  if (!indicesType || !indicesType.hasRank()) {
1286  op->emitOpError("expect ranked indices tensor");
1287  return false;
1288  }
1289 
1290  if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1291  op->emitOpError("indices values contain duplicates");
1292  return false;
1293  }
1294 
1295  return true;
1296 }
1297 
1298 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1299  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1300  !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1301  !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
1302  !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
1303  return failure();
1304  return success();
1305 }
1306 
1307 bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1308  if (isa<FloatType>(type)) {
1309  return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1310  Float8E5M2Type>(type);
1311  }
1312  if (auto intTy = dyn_cast<IntegerType>(type)) {
1313  if (intTy.isSignless()) {
1314  switch (intTy.getWidth()) {
1315  case 1:
1316  case 4:
1317  case 8:
1318  case 16:
1319  case 32:
1320  case 48:
1321  return true;
1322  }
1323  } else if (allowUnsigned && intTy.isUnsigned()) {
1324  switch (intTy.getWidth()) {
1325  case 8:
1326  case 16:
1327  case 32:
1328  return true;
1329  }
1330  }
1331  } else if (mlir::isa<tosa::shapeType>(type)) {
1332  return true;
1333  }
1334  return false;
1335 }
1336 
1337 void TosaValidation::runOnOperation() {
1338  configLevelAndProfile();
1339 
1340  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1341  if (!tosaDialect)
1342  return;
1343 
1344  getOperation().walk([&](Operation *op) {
1345  if (op->getDialect() != tosaDialect)
1346  return;
1347 
1348  // validate operator element types:
1349  // - rescale operator is allowed to have ui8/ui16/ui32
1350  // operands/results when strictOpSpecAlignment is false
1351  // - perform valid element type check at the beginning to
1352  // protect rest of code against quantized element types
1353  const bool allowUnsigned =
1354  !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1355  for (Value operand : op->getOperands()) {
1356  auto elementTy = getElementTypeOrSelf(operand);
1357  if (!isValidElementType(elementTy, allowUnsigned)) {
1358  op->emitOpError() << "is not profile-aligned: element type "
1359  << elementTy << " is not legal";
1360  return signalPassFailure();
1361  }
1362  }
1363  for (Type resultTy : op->getResultTypes()) {
1364  auto elementTy = getElementTypeOrSelf(resultTy);
1365  if (!isValidElementType(elementTy, allowUnsigned)) {
1366  op->emitOpError() << "is not profile-aligned: element type "
1367  << elementTy << " is not legal";
1368  return signalPassFailure();
1369  }
1370  }
1371 
1372  if (strictOpSpecAlignment &&
1373  failed(profileComp.checkProfile(op, targetEnv)))
1374  return signalPassFailure();
1375 
1376  if (strictOpSpecAlignment &&
1377  failed(profileComp.checkExtension(op, targetEnv)))
1378  return signalPassFailure();
1379 
1380  if (!allowInvalidOpDatatypeCombinations &&
1381  failed(profileComp.checkInvalid(op)))
1382  return signalPassFailure();
1383 
1384  // Some uses of TOSA rely on the constant operands of particular
1385  // operations.
1386  if (failed(applyConstantOperandCheck(op)))
1387  signalPassFailure();
1388 
1389  // do level checks
1390  if (failed(applyLevelCheck(op)))
1391  signalPassFailure();
1392 
1393  // check additional attribute restrictions
1394  if (failed(applyAttributeCheck(op)))
1395  signalPassFailure();
1396 
1397  // do variable type checks
1398  if (failed(applyVariableCheck(op)))
1399  signalPassFailure();
1400 
1401  // do error if checks
1402  if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1403  signalPassFailure();
1404  });
1405 }
1406 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:515
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
@ Gather
const float * table
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
bool allows(Profile prof) const
Definition: TargetEnv.h:43
NestedPattern If(const NestedPattern &child)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2465
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
RankedTensorType getVariableType(VariableOp variableOp)
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369