MLIR  21.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1 //===- TosaValidation.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Validate that TOSA dialect input matches the specification for the given
// requirements.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
18 
19 #include <string>
20 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Pass/Pass.h"
30 #include "llvm/ADT/StringExtras.h"
31 
32 namespace mlir {
33 namespace tosa {
34 #define GEN_PASS_DEF_TOSAVALIDATION
35 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
36 } // namespace tosa
37 } // namespace mlir
38 
39 using namespace mlir;
40 using namespace mlir::tosa;
41 
42 namespace {
43 
44 static LogicalResult
45 checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
46  for (const auto index : operandIndices) {
47  Attribute attr;
48  if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
49  return op->emitOpError("expected compile time resolvable constant, but "
50  "got variable value for operand #")
51  << index;
52  }
53  }
54  return success();
55 }
56 
57 static LogicalResult checkConstantOperandMul(Operation *op,
58  const TargetEnv &env) {
59  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60  // Check 'shift'
61  return checkConstantOperands(op, {2});
62  }
63  return success();
64 }
65 
66 static LogicalResult checkConstantOperandTable(Operation *op,
67  const TargetEnv &env) {
68  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69  // Check 'table'
70  return checkConstantOperands(op, {1});
71  }
72  return success();
73 }
74 
75 static LogicalResult checkConstantOperandPad(Operation *op,
76  const TargetEnv &env) {
77  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
78  // Assume this op is zero-padding if padConst is not presented
79  if (!env.allows(Extension::dynamic) && padOp.getPadConst())
80  // Check 'pad_const'
81  // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82  return checkConstantOperands(op, {2});
83  }
84  return success();
85 }
86 
87 static LogicalResult checkConstantOperandRescale(Operation *op,
88  const TargetEnv &env) {
89  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90  // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91  return checkConstantOperands(op, {1, 2, 3, 4});
92  }
93  return success();
94 }
95 
96 template <typename T>
97 static LogicalResult checkConstantOperandConvOps(Operation *op,
98  const TargetEnv &env) {
99  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
100  // Check 'input_zp' and 'weight_zp'
101  return checkConstantOperands(op, {3, 4});
102  }
103  return success();
104 }
105 
106 static LogicalResult checkConstantOperandMatMul(Operation *op,
107  const TargetEnv &env) {
108  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109  // Check 'A_zp' and 'B_zp'
110  return checkConstantOperands(op, {2, 3});
111  }
112  return success();
113 }
114 
115 static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
116  const TargetEnv &env) {
117  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118  // Check 'input_zp' and 'output_zp'
119  return checkConstantOperands(op, {1, 2});
120  }
121  return success();
122 }
123 
124 static LogicalResult checkConstantOperandNegate(Operation *op,
125  const TargetEnv &env) {
126  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127  // Check 'input1_zp' and 'output_zp'
128  return checkConstantOperands(op, {1, 2});
129  }
130  return success();
131 }
132 
/// The set of numeric limits that make up a TOSA level. Each field is an
/// upper bound consulted by the level checks of the validation pass.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;
  int32_t MAX_LOG2_SIZE = 0;
  int32_t MAX_NESTING = 0;
  int32_t MAX_TENSOR_LIST_SIZE = 0;

  /// Member-wise equality. Declared `const` (and `constexpr`) so that const
  /// levels, such as the predefined constants below, can be compared on
  /// either side of the operator.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }

  /// Negation of operator==.
  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};

/// Limits of the TOSA "8K" level.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
/// Limits used when no level is selected; the validation pass skips level
/// checking entirely when this level is active.
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
                                              63, 256, 256};
154 
155 //===----------------------------------------------------------------------===//
156 // TOSA Validation Pass.
157 //===----------------------------------------------------------------------===//
158 
159 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
160 public:
161  explicit TosaValidation() { populateConstantOperandChecks(); }
162 
163  explicit TosaValidation(const TosaValidationOptions &options)
164  : TosaValidation() {
165  this->profile = options.profile;
166  this->extension = options.extension;
167  this->strictOpSpecAlignment = options.strictOpSpecAlignment;
168  this->allowInvalidOpDatatypeCombinations =
169  options.allowInvalidOpDatatypeCombinations;
170  this->level = options.level;
171  }
172  void runOnOperation() final;
173 
174  LogicalResult applyConstantOperandCheck(Operation *op) {
175  for (auto &checker : constCheckers) {
176  if (failed(checker(op, targetEnv)))
177  return failure();
178  }
179  return success();
180  }
181 
182  LogicalResult applyLevelCheck(Operation *op);
183  LogicalResult applyAttributeCheck(Operation *op);
184 
185  // check variable read/write data types against variable declarations
186  LogicalResult applyVariableCheck(Operation *op);
187 
188  // check error if conditions
189  LogicalResult applyErrorIfCheck(Operation *op);
190 
191 private:
192  void populateConstantOperandChecks() {
193  constCheckers.emplace_back(checkConstantOperandMul);
194  constCheckers.emplace_back(checkConstantOperandTable);
195  constCheckers.emplace_back(checkConstantOperandPad);
196  constCheckers.emplace_back(checkConstantOperandRescale);
197  constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
198  constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
199  constCheckers.emplace_back(
200  checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
201  constCheckers.emplace_back(
202  checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
203  constCheckers.emplace_back(checkConstantOperandMatMul);
204  constCheckers.emplace_back(checkConstantOperandAvgPool2d);
205  constCheckers.emplace_back(checkConstantOperandNegate);
206  }
207 
208  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
209  if (v > tosaLevel.MAX_KERNEL) {
210  op->emitOpError() << "failed level check: " << checkDesc;
211  return false;
212  }
213  return true;
214  }
215 
216  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
217  if (v > tosaLevel.MAX_STRIDE) {
218  op->emitOpError() << "failed level check: " << checkDesc;
219  return false;
220  }
221  return true;
222  }
223 
224  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
225  if (v > tosaLevel.MAX_SCALE) {
226  op->emitOpError() << "failed level check: " << checkDesc;
227  return false;
228  }
229  return true;
230  }
231 
232  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
233  if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
234  op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
235  << checkDesc;
236  return false;
237  }
238  return true;
239  }
240 
241  // Perform the Level Rank check on the tensor type.
242  bool levelCheckRank(Operation *op, const Type typeToCheck,
243  const StringRef operandOrResult, int32_t highest_rank) {
244  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
245  if (!type.hasRank()) {
246  op->emitOpError() << "failed level check: unranked tensor";
247  return false;
248  }
249  if (type.getRank() > highest_rank) {
250  op->emitOpError() << "failed level check: " << operandOrResult
251  << " rank(shape) <= MAX_RANK";
252  return false;
253  }
254  }
255  return true;
256  }
257 
258  // Perform the Level Rank check on the tensor value.
259  bool levelCheckRank(Operation *op, const Value &v,
260  const StringRef operandOrResult, int32_t highest_rank) {
261  return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
262  }
263 
264  // Perform the Level tensor size check on the tensor type.
265  bool levelCheckSize(Operation *op, const Type &typeToCheck,
266  const StringRef operandOrResult);
267 
268  // Perform the Level tensor size check on the tensor value.
269  bool levelCheckSize(Operation *op, const Value &v,
270  const StringRef operandOrResult) {
271  return levelCheckSize(op, v.getType(), operandOrResult);
272  }
273 
274  // Level check sizes of all operands and results of the operation.
275  template <typename T>
276  bool levelCheckSizes(T tosaOp) {
277  auto op = tosaOp.getOperation();
278  for (auto v : op->getOperands()) {
279  if (!levelCheckSize(op, v, "operand"))
280  return false;
281  }
282 
283  for (auto v : op->getResults()) {
284  if (!levelCheckSize(op, v, "result"))
285  return false;
286  }
287  return true;
288  }
289 
290  // Level check ranks of all operands, attribute and results of the operation.
291  template <typename T>
292  bool levelCheckRanks(T tosaOp) {
293  auto op = tosaOp.getOperation();
294  for (auto v : op->getOperands()) {
295  if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
296  return false;
297  }
298 
299  for (auto v : op->getResults()) {
300  if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
301  return false;
302  }
303  return true;
304  }
305 
306  // Level check ranks and sizes.
307  bool levelCheckRanksAndSizes(Operation *op);
308 
309  // Pool Op: level check kernel/stride/pad values
310  template <typename T>
311  bool levelCheckPool(Operation *op) {
312  if (auto poolOp = dyn_cast<T>(op)) {
313  for (auto k : poolOp.getKernel()) {
314  if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
315  return false;
316  }
317  }
318  for (auto s : poolOp.getStride()) {
319  if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
320  return false;
321  }
322  }
323  for (auto p : poolOp.getPad()) {
324  if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
325  return false;
326  }
327  }
328  }
329  return true;
330  }
331 
332  // Conv Op: level check dilation/stride/pad values
333  template <typename T>
334  bool levelCheckConv(Operation *op) {
335  if (auto convOp = dyn_cast<T>(op)) {
336 
337  for (auto k : convOp.getDilation()) {
338  if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
339  return false;
340  }
341  }
342  for (auto p : convOp.getPad()) {
343  if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
344  return false;
345  }
346  }
347  for (auto s : convOp.getStride()) {
348  if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
349  return false;
350  }
351  }
352  auto dilation = convOp.getDilation();
353  if (ShapedType weightType =
354  dyn_cast<ShapedType>(op->getOperand(1).getType())) {
355  auto shape = weightType.getShape();
356  if (isa<tosa::Conv2DOp>(op)) {
357  assert(shape.size() == 4);
358  assert(dilation.size() == 2);
359  if (!levelCheckKernel(op, dilation[0] * shape[1],
360  "dilation_y * KH <= MAX_KERNEL)") ||
361  !levelCheckKernel(op, dilation[1] * shape[2],
362  "dilation_x * KW <= MAX_KERNEL)"))
363  return false;
364  } else if (isa<tosa::Conv3DOp>(op)) {
365  assert(shape.size() == 5);
366  assert(dilation.size() == 3);
367  if (!levelCheckKernel(op, dilation[0] * shape[1],
368  "dilation_d * KD <= MAX_KERNEL)") ||
369  !levelCheckKernel(op, dilation[1] * shape[2],
370  "dilation_y * KH <= MAX_KERNEL)") ||
371  !levelCheckKernel(op, dilation[2] * shape[3],
372  "dilation_x * KW <= MAX_KERNEL)"))
373  return false;
374  } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
375  assert(shape.size() == 4);
376  assert(dilation.size() == 2);
377  if (!levelCheckKernel(op, dilation[0] * shape[0],
378  "dilation_y * KH <= MAX_KERNEL)") ||
379  !levelCheckKernel(op, dilation[1] * shape[1],
380  "dilation_x * KW <= MAX_KERNEL)"))
381  return false;
382  }
383  }
384  }
385  return true;
386  }
387 
388  // FFT op: level check H, W in input shape [N,H,W]
389  template <typename T>
390  bool levelCheckFFT(Operation *op) {
391  if (isa<T>(op)) {
392  for (auto v : op->getOperands()) {
393  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
394  auto shape = type.getShape();
395  assert(shape.size() == 3);
396  if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
397  !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
398  return false;
399  }
400  }
401  }
402  }
403  return true;
404  }
405 
406  // TransposeConv2d op: level check kH/kW, outpad, and stride
407  bool levelCheckTransposeConv2d(Operation *op) {
408  if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
409  if (ShapedType filterType =
410  dyn_cast<ShapedType>(transpose.getWeight().getType())) {
411  auto shape = filterType.getShape();
412  assert(shape.size() == 4);
413  // level check kernel sizes for kH and KW
414  if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
415  !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
416  return false;
417  }
418  }
419  for (auto p : transpose.getOutPad()) {
420  if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
421  return false;
422  }
423  }
424  for (auto s : transpose.getStride()) {
425  if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
426  return false;
427  }
428  }
429  }
430  return true;
431  }
432 
433  // Resize op: level check max scales
434  bool levelCheckResize(Operation *op) {
435  if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
436  SmallVector<int64_t> scale;
437  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
438  scale)) {
439  return false;
440  }
441  const int64_t scaleYN = scale[0];
442  const int64_t scaleYD = scale[1];
443  const int64_t scaleXN = scale[2];
444  const int64_t scaleXD = scale[3];
445  if (!levelCheckScale(op, scaleYN / scaleYD,
446  "scale_y_n/scale_y_d <= MAX_SCALE") ||
447  !levelCheckScale(op, scaleXN / scaleXD,
448  "scale_x_n/scale_x_d <= MAX_SCALE")) {
449  return false;
450  }
451  }
452  return true;
453  }
454 
455  // Recursively perform a bottom-up search to determine the maximum nesting
456  // depth, starting from a specific operation and continuing up to the function
457  // or module scope. Tosa nesting_depth starts at 0 and increments by one each
458  // time a new nested `region` is encountered.
459  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
460  if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
461  return;
462 
463  op = op->getParentOp();
464  if (!op)
465  return;
466 
467  depth++;
468  getMaxNestedDepth(op, depth);
469  }
470 
471  bool levelCheckMaxNesting(Operation *op) {
472  int32_t maxNestedDepth = 0;
473  getMaxNestedDepth(op, maxNestedDepth);
474 
475  if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
476  op->emitOpError() << "failed level check: " << maxNestedDepth
477  << " >= MAX_NESTING";
478  return false;
479  }
480  return true;
481  }
482 
483  bool levelCheckListSize(Operation *op) {
484  if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
485  return levelCheckListSize(op, concat.getInput1().size(), "input1");
486  }
487  if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
488  if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
489  !levelCheckListSize(op, custom.getOutputList().size(),
490  "output_list")) {
491  return false;
492  }
493  }
494  if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
495  if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
496  !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
497  return false;
498  }
499  }
500  if (auto w = dyn_cast<tosa::WhileOp>(op)) {
501  if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
502  !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
503  return false;
504  }
505  }
506  return true;
507  }
508 
509  bool attributeCheckRescale(Operation *op) {
510  if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
511  if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
512  !targetEnv.allows(Extension::doubleround)) {
513  op->emitOpError()
514  << "failed attribute check: rounding_mode = DOUBLE_ROUND "
515  << "requires extension [doubleround]";
516  return false;
517  } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
518  !targetEnv.allows(Extension::inexactround)) {
519  op->emitOpError()
520  << "failed attribute check: rounding_mode = INEXACT_ROUND "
521  << "requires extension [inexactround]";
522  return false;
523  }
524  }
525  return true;
526  }
527 
528  // configure profile and level values from pass options profileName and
529  // levelName
530  void configLevelAndProfile() {
531  tosaLevel = TOSA_LEVEL_NONE;
532  if (level == TosaLevelEnum::EightK) {
533  tosaLevel = TOSA_LEVEL_EIGHTK;
534  }
535 
536  if (!profile.empty()) {
537  for (std::string &prof : profile) {
538  auto profSymbol = symbolizeProfile(prof);
539  if (profSymbol) {
540  targetEnv.addProfile(profSymbol.value());
541  } else {
542  llvm::errs() << "unknown TOSA profile name passed in: " << prof
543  << ", supported profiles are `pro_int` and `pro_fp`\n";
544  return signalPassFailure();
545  }
546  }
547  }
548 
549  if (!extension.empty()) {
550  for (std::string &ext : extension) {
551  auto extSymbol = symbolizeExtension(ext);
552  if (extSymbol) {
553  targetEnv.addExtension(extSymbol.value());
554  } else {
555  llvm::errs() << "unknown TOSA extension name passed in: " << ext
556  << ", supported extension are int16, int4, bf16, "
557  << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
558  << "doubleround, inexactround and dynamic\n";
559  return signalPassFailure();
560  }
561  }
562  }
563  }
564 
565  bool CheckVariable(Operation *op);
566  bool CheckVariableReadOrWrite(Operation *op);
567  bool isValidElementType(Type type, const bool allowUnsigned = false);
568 
569  SmallVector<
570  std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
571  constCheckers;
572  TosaLevel tosaLevel;
574  TosaProfileCompliance profileComp;
575  tosa::TargetEnv targetEnv;
576 };
577 
578 template <>
579 bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
580  auto op = tosaOp.getOperation();
581  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
582  return false;
583 
584  // rank(output) = rank(input) - 1
585  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
586  return false;
587 
588  return true;
589 }
590 
591 template <>
592 bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
593  auto op = tosaOp.getOperation();
594 
595  // Only the condition input has rank limitation.
596  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
597  return false;
598 
599  return true;
600 }
601 
602 template <>
603 bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
604  auto op = tosaOp.getOperation();
605  auto variableType = getVariableType(tosaOp);
606  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
607  return false;
608 
609  return true;
610 }
611 
612 template <>
613 bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
614  auto op = tosaOp.getOperation();
615  auto variableType = getVariableType(tosaOp);
616  if (!levelCheckSize(op, variableType, "variable type"))
617  return false;
618 
619  return true;
620 }
621 
622 bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
623 #define CHECK_RANKS_AND_SIZES(tosaOp) \
624  if (isa<tosa::tosaOp##Op>(op)) { \
625  if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \
626  return false; \
627  if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
628  return false; \
629  }
630 
631 #define CHECK_SIZES(tosaOp) \
632  if (isa<tosa::tosaOp##Op>(op)) { \
633  if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
634  return false; \
635  }
636 
637  // Tensor Operators
638  CHECK_RANKS_AND_SIZES(ArgMax);
639  // Activation Functions
640  CHECK_RANKS_AND_SIZES(Clamp);
642  CHECK_RANKS_AND_SIZES(Sigmoid);
643  CHECK_RANKS_AND_SIZES(Tanh);
644  // Elementwise Binary Operators
646  CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
647  CHECK_RANKS_AND_SIZES(BitwiseAnd);
648  CHECK_RANKS_AND_SIZES(BitwiseOr);
649  CHECK_RANKS_AND_SIZES(BitwiseXor);
650  CHECK_RANKS_AND_SIZES(IntDiv);
651  CHECK_RANKS_AND_SIZES(LogicalAnd);
652  CHECK_RANKS_AND_SIZES(LogicalLeftShift);
653  CHECK_RANKS_AND_SIZES(LogicalRightShift);
654  CHECK_RANKS_AND_SIZES(LogicalOr);
655  CHECK_RANKS_AND_SIZES(LogicalXor);
656  CHECK_RANKS_AND_SIZES(Maximum);
657  CHECK_RANKS_AND_SIZES(Minimum);
661  CHECK_RANKS_AND_SIZES(Table);
662  // Elementwise Unary Operators
664  CHECK_RANKS_AND_SIZES(BitwiseNot);
665  CHECK_RANKS_AND_SIZES(Ceil);
669  CHECK_RANKS_AND_SIZES(Floor);
671  CHECK_RANKS_AND_SIZES(LogicalNot);
672  CHECK_RANKS_AND_SIZES(Negate);
673  CHECK_RANKS_AND_SIZES(Reciprocal);
674  CHECK_RANKS_AND_SIZES(Rsqrt);
676  // Elementwise Ternary Operators
677  CHECK_RANKS_AND_SIZES(Select);
678  // Comparison Operators
679  CHECK_RANKS_AND_SIZES(Equal);
680  CHECK_RANKS_AND_SIZES(Greater);
681  CHECK_RANKS_AND_SIZES(GreaterEqual);
682  // Reduction Operators
683  CHECK_RANKS_AND_SIZES(ReduceAll);
684  CHECK_RANKS_AND_SIZES(ReduceAny);
685  CHECK_RANKS_AND_SIZES(ReduceMax);
686  CHECK_RANKS_AND_SIZES(ReduceMin);
687  CHECK_RANKS_AND_SIZES(ReduceProduct);
688  CHECK_RANKS_AND_SIZES(ReduceSum);
689  // Data Layout Operators
690  CHECK_RANKS_AND_SIZES(Concat);
692  CHECK_RANKS_AND_SIZES(Reshape);
693  CHECK_RANKS_AND_SIZES(Reverse);
694  CHECK_RANKS_AND_SIZES(Slice);
695  CHECK_RANKS_AND_SIZES(Tile);
696  CHECK_RANKS_AND_SIZES(Transpose);
697  // Type Conversion
698  CHECK_RANKS_AND_SIZES(Cast);
699  CHECK_RANKS_AND_SIZES(Rescale);
700  // Control Flow Operators
702  // Variable Operators
703  CHECK_RANKS_AND_SIZES(Variable);
704  CHECK_RANKS_AND_SIZES(VariableWrite);
705  CHECK_RANKS_AND_SIZES(VariableRead);
706  // Data Nodes
707  CHECK_RANKS_AND_SIZES(Const);
708  CHECK_RANKS_AND_SIZES(Identity);
709 
710  // For the following operators, check whether the size of each tensor
711  // operand is valid in a given Level.
712 
713  // Tensor Operators
714  CHECK_SIZES(AvgPool2d);
715  CHECK_SIZES(Conv2D);
716  CHECK_SIZES(Conv3D);
717  CHECK_SIZES(DepthwiseConv2D);
718  CHECK_SIZES(TransposeConv2D);
719  CHECK_SIZES(FFT2d);
720  CHECK_SIZES(MatMul);
721  CHECK_SIZES(MaxPool2d);
722  CHECK_SIZES(RFFT2d);
723  // Scatter/Gather Operators
725  CHECK_SIZES(Scatter);
726  // Image Operators
727  CHECK_SIZES(Resize);
728  // Custom Operators
729  CHECK_SIZES(Custom);
730  // Control Flow Operators
731  CHECK_SIZES(While);
732  // Shape Operators
733  CHECK_SIZES(ConstShape);
734 
735 #undef CHECK_RANKS_AND_SIZES
736 #undef CHECK_SIZES
737  return true;
738 }
739 
740 // Perform the Level tensor size check on the tensor type.
741 bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
742  const StringRef operandOrResult) {
743  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
744  if (!type.hasRank()) {
745  op->emitOpError() << "failed level check: unranked tensor";
746  return false;
747  }
748  auto shape = type.getShape();
749  for (auto dim : shape) {
750  if (mlir::ShapedType::isDynamic(dim)) {
751  op->emitOpError() << "failed level check: " << operandOrResult
752  << " shape dimension cannot be dynamic";
753  return false;
754  }
755  }
756 
757  int64_t element_bits = type.getElementTypeBitWidth();
758  int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
759  int64_t size = element_bytes * type.getNumElements();
760 
761  // According to 1.11. Tensor Definitions of Tosa spec, the value of
762  // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
763  // defined in 1.7. Levels.
764  // For each tensor, the number of tensor elements multiplied by the
765  // element size in bytes must be representable as a tensor_size_t.
766  const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
767  if (size > max_size) {
768  op->emitOpError()
769  << "failed level check: " << operandOrResult
770  << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
771  return false;
772  }
773  }
774  return true;
775 }
776 
777 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
778  if (tosaLevel == TOSA_LEVEL_NONE) {
779  // no need to do level checks
780  return success();
781  }
782 
783  // check rank and sizes early so later checks can assume shaped operands
784  if (!levelCheckRanksAndSizes(op))
785  return failure();
786 
787  // additional level checks from spec 0.70
788  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
789  !levelCheckConv<tosa::Conv2DOp>(op) ||
790  !levelCheckConv<tosa::Conv3DOp>(op) ||
791  !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
792  !levelCheckFFT<tosa::FFT2dOp>(op) ||
793  !levelCheckPool<tosa::MaxPool2dOp>(op) ||
794  !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
795  !levelCheckResize(op)) {
796  return failure();
797  }
798 
799  // level check MAX_TENSOR_LIST_SIZE
800  if (!levelCheckListSize(op)) {
801  return failure();
802  }
803 
804  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
805  if (!levelCheckMaxNesting(op)) {
806  return failure();
807  }
808  }
809 
810  return success();
811 }
812 
813 LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
814  if (!attributeCheckRescale(op))
815  return failure();
816  return success();
817 }
818 
819 inline bool CompatibleTypes(const mlir::Type &type,
820  const mlir::Type &declaredType) {
821  // for now, simply use type equality comparison
822  return type == declaredType;
823 }
824 
825 bool TosaValidation::CheckVariable(Operation *op) {
826  if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
827  mlir::StringAttr nameAttr = variableOp.getNameAttr();
828 
829  if (variablesMap.count(nameAttr)) {
830  op->emitOpError() << "name has already been declared";
831  return false;
832  }
833 
834  auto elementType = variableOp.getType();
835  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
836  SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
837  RankedTensorType variableType =
838  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
839 
840  variablesMap[nameAttr] = variableType;
841  }
842 
843  return true;
844 }
845 
846 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
847  if (isa<mlir::tosa::VariableReadOp>(op) ||
848  isa<mlir::tosa::VariableWriteOp>(op)) {
849  mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
850  if (!variablesMap.count(nameAttr)) {
851  op->emitOpError() << "name has not been declared";
852  return false;
853  }
854 
855  auto varType = variablesMap[nameAttr];
856 
857  for (auto v : op->getOperands()) {
858  auto type = v.getType();
859  if (!CompatibleTypes(type, varType)) {
860  op->emitOpError() << "operand type does not equal variable type";
861  return false;
862  }
863  }
864 
865  for (auto v : op->getResults()) {
866  auto type = v.getType();
867  if (!CompatibleTypes(type, varType)) {
868  op->emitOpError() << "result type does not equal variable type";
869  return false;
870  }
871  }
872  }
873 
874  return true;
875 }
876 
877 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
878  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
879  return failure();
880  }
881  return success();
882 }
883 
884 bool checkErrorIfResize(Operation *op) {
885  auto resize = dyn_cast<tosa::ResizeOp>(op);
886  if (!resize)
887  return true;
888 
889  const Value input = resize.getInput();
890  const Value output = resize.getOutput();
891  const RankedTensorType inputType =
892  llvm::dyn_cast<RankedTensorType>(input.getType());
893  const RankedTensorType outputType =
894  llvm::dyn_cast<RankedTensorType>(output.getType());
895 
896  if (!inputType || !outputType) {
897  op->emitOpError("expect ranked input/output tensor");
898  return false;
899  }
900 
901  // Ensure the image size is supported by GPU APIs and that for integer
902  // implementations, position * stride does not overflow int32_t.
903  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
904  const SmallVector<int64_t, 4> sizes = {
905  outputType.getDimSize(1), outputType.getDimSize(2),
906  inputType.getDimSize(1), inputType.getDimSize(2)};
907  const int64_t *maxDim = llvm::max_element(sizes);
908  if (maxDim != sizes.end() && *maxDim >= 16384) {
909  op->emitOpError("expect input/output height/width dims to be < 16384, ")
910  << "got [OH, OW, IH, IW] = " << sizes;
911  return false;
912  }
913  }
914 
915  SmallVector<int64_t> scale;
916  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
917  return false;
918  }
919 
920  const int64_t scaleYN = scale[0];
921  const int64_t scaleYD = scale[1];
922  const int64_t scaleXN = scale[2];
923  const int64_t scaleXD = scale[3];
924 
925  // Ensure scale values don't overflow int32 accumulator
926  if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
927  op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
928  "got scale_y_n=")
929  << scaleYN << ", scale_x_n=" << scaleXN;
930  return false;
931  }
932 
933  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
934  op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
935  << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
936  return false;
937  }
938 
939  SmallVector<int64_t> offset;
940  SmallVector<int64_t> border;
941  if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
942  !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
943  return false;
944  }
945 
946  const int64_t offsetY = offset[0];
947  const int64_t offsetX = offset[1];
948  // Set a consistent lower limit of 1/16 downscale to simplify
949  // implementations
950  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
951  op->emitOpError(
952  "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
953  << offsetY << "/" << scaleYN;
954  return false;
955  }
956  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
957  op->emitOpError(
958  "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
959  << offsetX << "/" << scaleXN;
960  return false;
961  }
962 
963  const int64_t borderY = border[0];
964  const int64_t borderX = border[1];
965  if (borderY < -16 * scaleYN || borderY >= scaleYN) {
966  op->emitOpError(
967  "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
968  << borderY << "/" << scaleYN;
969  return false;
970  }
971  if (borderX < -16 * scaleXN || borderX >= scaleXN) {
972  op->emitOpError(
973  "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
974  << borderX << "/" << scaleXN;
975  return false;
976  }
977 
978  // The following section of code is mostly duplicated with ResizeOp::verify().
979  //
980  // In TOSA specification, we do not support broadcast behavior.
981  // However, there is a rewrite pattern to materialize broadcast ResizeOp.
982  // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
983  // existing code, we keep the rewrite pattern untouched. So, we need
984  // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
985  //
986  // Here is a strict checking to conform TOSA specification.
987  // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
988  auto idivCheck = [](const int64_t lhs,
989  const int64_t rhs) -> std::optional<int64_t> {
990  if (lhs % rhs != 0)
991  return std::nullopt;
992  return lhs / rhs;
993  };
994 
995  const int64_t oh = outputType.getDimSize(1);
996  const int64_t ow = outputType.getDimSize(2);
997  const int64_t ih = inputType.getDimSize(1);
998  const int64_t iw = inputType.getDimSize(2);
999 
1000  if (ih != ShapedType::kDynamic) {
1001  const std::optional<int64_t> calculatedOutHeightMinusOne =
1002  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1003  if (!calculatedOutHeightMinusOne.has_value()) {
1004  op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
1005  "border_y ")
1006  << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
1007  << scaleYN << " - " << offsetY << " + " << borderY << ") / "
1008  << scaleYD;
1009  return false;
1010  }
1011  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1012  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
1013  op->emitOpError("calculated output height did not match expected: ")
1014  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
1015  return false;
1016  }
1017  }
1018 
1019  if (iw != ShapedType::kDynamic) {
1020  const std::optional<int64_t> calculatedOutWidthMinusOne =
1021  idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
1022  if (!calculatedOutWidthMinusOne.has_value()) {
1023  op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
1024  "border_x ")
1025  << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
1026  << scaleXN << " - " << offsetX << " + " << borderX << ") / "
1027  << scaleXD;
1028  return false;
1029  }
1030  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1031  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
1032  op->emitOpError("calculated output width did not match expected: ")
1033  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
1034  return false;
1035  }
1036  }
1037 
1038  return true;
1039 }
1040 
1041 bool checkErrorIfMul(Operation *op) {
1042  auto mul = dyn_cast<tosa::MulOp>(op);
1043  if (!mul)
1044  return true;
1045 
1046  // REQUIRE(0 <= shift && shift <= 63);
1047  // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1048  ElementsAttr shift_elem;
1049  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
1050  return true;
1051  }
1052  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1053  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1054  if (inputElemType.isInteger(32)) {
1055  // 0 <= shift <= 63 for int32_t type
1056  if (shift < 0 || shift > 63) {
1057  op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
1058  << shift;
1059  return false;
1060  }
1061  } else {
1062  // shift must be 0 for all other types
1063  if (shift != 0) {
1064  op->emitOpError() << "requires shift = 0 for all input data types that "
1065  "are not int32_t, but got: "
1066  << shift;
1067  return false;
1068  }
1069  }
1070 
1071  return true;
1072 }
1073 
1074 bool checkErrorIfTable(Operation *op) {
1075  auto table = dyn_cast<tosa::TableOp>(op);
1076  if (!table)
1077  return true;
1078 
1079  // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1080  const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1081  const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1082 
1083  const ShapeAdaptor tableShape(table.getTable().getType());
1084  if (tableShape.hasStaticShape()) {
1085  const auto numElements = tableShape.getNumElements();
1086  if (numElements != tableSize) {
1087  op->emitOpError() << "requires table size of " << tableSize << ", got "
1088  << numElements;
1089  return false;
1090  }
1091  }
1092 
1093  return true;
1094 }
1095 
1096 bool checkErrorIfRescale(Operation *op) {
1097  auto rescale = dyn_cast<tosa::RescaleOp>(op);
1098  if (!rescale)
1099  return true;
1100 
1101  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1102  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1103  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1104  !outputType.getElementType().isInteger())
1105  return true;
1106 
1107  auto inElemType = inputType.getElementType();
1108  auto outElemType = outputType.getElementType();
1109  auto inWidth = inElemType.getIntOrFloatBitWidth();
1110  auto outWidth = outElemType.getIntOrFloatBitWidth();
1111 
1112  bool inputUnsigned = rescale.getInputUnsigned();
1113  bool outputUnsigned = rescale.getOutputUnsigned();
1114 
1115  bool scale32 = rescale.getScale32();
1116  auto roundingMode = rescale.getRoundingMode();
1117 
1118  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1119  if (scale32 && inWidth == 48) {
1120  op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1121  return false;
1122  }
1123 
1124  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1125  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
1126  op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
1127  return false;
1128  }
1129 
1130  // ERROR_IF(input_unsigned && output_unsigned)
1131  if (inputUnsigned && outputUnsigned) {
1132  op->emitOpError() << "input and output cannot be both unsigned.";
1133  return false;
1134  }
1135 
1136  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1137  if (outWidth == 32 && inputUnsigned) {
1138  op->emitOpError() << "i32 output type is not allowed with unsigned input.";
1139  return false;
1140  }
1141 
1142  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1143  if (inWidth == 32 && outputUnsigned) {
1144  op->emitOpError() << "i32 input type is not allowed with unsigned output.";
1145  return false;
1146  }
1147 
1148  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1149  if (inWidth == 48 && outputUnsigned) {
1150  op->emitOpError() << "i48 input type is not allowed with unsigned output.";
1151  return false;
1152  }
1153 
1154  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1155  if (inWidth == 48 && inputUnsigned) {
1156  op->emitOpError() << "i48 input type cannot be unsigned.";
1157  return false;
1158  }
1159 
1160  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1161  if (inWidth == 32 && inputUnsigned) {
1162  op->emitOpError() << "i32 input type cannot be unsigned.";
1163  return false;
1164  }
1165 
1166  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1167  if (outWidth == 32 && outputUnsigned) {
1168  op->emitOpError() << "i32 output type cannot be unsigned.";
1169  return false;
1170  }
1171 
1172  return true;
1173 }
1174 
1175 bool checkErrorIfPad(Operation *op) {
1176  auto pad = dyn_cast<tosa::PadOp>(op);
1177  if (!pad)
1178  return true;
1179 
1180  DenseIntElementsAttr paddingAttr;
1181  if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1182  // Pad verifier will catch this
1183  return true;
1184 
1185  for (const APInt &val : paddingAttr.getValues<APInt>()) {
1186  if (val.getSExtValue() < 0) {
1187  op->emitOpError() << "padding value must all be non-negative, got "
1188  << val.getSExtValue();
1189  return false;
1190  }
1191  }
1192 
1193  return true;
1194 }
1195 
1196 // Returns true if the operation takes no input operands, excluding attributes.
1197 static bool isNullaryOperation(Operation *op) {
1198  if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
1199  isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
1200  return true;
1201  return false;
1202 }
1203 
1204 bool checkErrorIfCondIf(Operation *op) {
1205  auto ifOp = dyn_cast<tosa::IfOp>(op);
1206  if (!ifOp)
1207  return true;
1208 
1209  // Whether the types and shapes of operands between the input/output list and
1210  // internal regions are validated by the operation verifier. However, with
1211  // support for the simplified form - where redundant operand notations are
1212  // omitted - is not conformant to the specification. According to the
1213  // specification, all operands passed into an operation must be explicitly
1214  // declared at each operation's structure. This code section verify that the
1215  // operation's form complies with this requirement.
1216 
1217  // Returns true if the region uses no external input operands.
1218  auto isNullaryRegion = [](Region &region) -> bool {
1219  bool noLiveInValue = true;
1220  region.walk([&noLiveInValue](Operation *op) {
1221  if (!isNullaryOperation(op)) {
1222  noLiveInValue = false;
1223  return WalkResult::interrupt();
1224  }
1225  return WalkResult::advance();
1226  });
1227  return noLiveInValue;
1228  };
1229 
1230  mlir::Region &thenGraph = ifOp.getThenGraph();
1231  mlir::Region &elseGraph = ifOp.getElseGraph();
1232  bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
1233  bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
1234  bool isInputListEmpty = ifOp.getInputList().size() == 0;
1235 
1236  if ((isInputListEmpty != isThenGraphNullaryRegion) ||
1237  (isInputListEmpty != isElseGraphNullaryRegion)) {
1238  op->emitOpError()
1239  << "the current simplified form is not strictly conformant to the "
1240  "spec, please use the generic format\n";
1241  return false;
1242  }
1243 
1244  return true;
1245 }
1246 
1247 bool checkErrorIfScatter(Operation *op) {
1248  auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1249  if (!scatterOp)
1250  return true;
1251 
1252  // for constant indices, check that there are no duplicate values
1253  DenseIntElementsAttr indicesAttr;
1254  if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1255  return true;
1256 
1257  auto const indicesType =
1258  dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1259  if (!indicesType || !indicesType.hasRank()) {
1260  op->emitOpError("expect ranked indices tensor");
1261  return false;
1262  }
1263 
1264  if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1265  op->emitOpError("indices values contain duplicates");
1266  return false;
1267  }
1268 
1269  return true;
1270 }
1271 
1272 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1273  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1274  !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1275  !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
1276  !checkErrorIfScatter(op))
1277  return failure();
1278  return success();
1279 }
1280 
1281 bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1282  if (isa<FloatType>(type)) {
1283  return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1284  Float8E5M2Type>(type);
1285  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1286  if (intTy.isSignless()) {
1287  switch (intTy.getWidth()) {
1288  case 1:
1289  case 4:
1290  case 8:
1291  case 16:
1292  case 32:
1293  case 48:
1294  return true;
1295  }
1296  } else if (allowUnsigned && intTy.isUnsigned()) {
1297  switch (intTy.getWidth()) {
1298  case 8:
1299  case 16:
1300  case 32:
1301  return true;
1302  }
1303  }
1304  } else if (mlir::isa<tosa::shapeType>(type)) {
1305  return true;
1306  }
1307  return false;
1308 }
1309 
1310 void TosaValidation::runOnOperation() {
1311  configLevelAndProfile();
1312 
1313  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1314  if (!tosaDialect)
1315  return;
1316 
1317  getOperation().walk([&](Operation *op) {
1318  if (op->getDialect() != tosaDialect)
1319  return;
1320 
1321  // validate operator element types:
1322  // - rescale operator is allowed to have ui8/ui16/ui32
1323  // operands/results when strictOpSpecAlignment is false
1324  // - perform valid element type check at the beginning to
1325  // protect rest of code against quantized element types
1326  const bool allowUnsigned =
1327  !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1328  for (Value operand : op->getOperands()) {
1329  auto elementTy = getElementTypeOrSelf(operand);
1330  if (!isValidElementType(elementTy, allowUnsigned)) {
1331  op->emitOpError() << "is not profile-aligned: element type "
1332  << elementTy << " is not legal";
1333  return signalPassFailure();
1334  }
1335  }
1336  for (Type resultTy : op->getResultTypes()) {
1337  auto elementTy = getElementTypeOrSelf(resultTy);
1338  if (!isValidElementType(elementTy, allowUnsigned)) {
1339  op->emitOpError() << "is not profile-aligned: element type "
1340  << elementTy << " is not legal";
1341  return signalPassFailure();
1342  }
1343  }
1344 
1345  if (strictOpSpecAlignment &&
1346  failed(profileComp.checkProfile(op, targetEnv)))
1347  return signalPassFailure();
1348 
1349  if (strictOpSpecAlignment &&
1350  failed(profileComp.checkExtension(op, targetEnv)))
1351  return signalPassFailure();
1352 
1353  if (!allowInvalidOpDatatypeCombinations &&
1354  failed(profileComp.checkInvalid(op)))
1355  return signalPassFailure();
1356 
1357  // Some uses of TOSA rely on the constant operands of particular
1358  // operations.
1359  if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
1360  signalPassFailure();
1361 
1362  // do level checks
1363  if (failed(applyLevelCheck(op)))
1364  signalPassFailure();
1365 
1366  // check additional attribute restrictions
1367  if (failed(applyAttributeCheck(op)))
1368  signalPassFailure();
1369 
1370  // do variable type checks
1371  if (failed(applyVariableCheck(op)))
1372  signalPassFailure();
1373 
1374  // do error if checks
1375  if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1376  signalPassFailure();
1377  });
1378 }
1379 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:279
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
@ Gather
const float * table
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
bool allows(Profile prof) const
Definition: TargetEnv.h:43
NestedPattern If(const NestedPattern &child)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2424
RankedTensorType getVariableType(VariableOp variableOp)
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369