1 //===- TosaValidation.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Validate that the TOSA dialect input matches the specification for the
10 // given requirements.
11 //
12 //===----------------------------------------------------------------------===//
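//
// A hedged usage sketch (the pass name and option spelling below come from the
// pass's TableGen definition and may differ between versions):
//
//   mlir-opt input.mlir --tosa-validate=strict-op-spec-alignment
//
// Profile, extension, and level requirements are taken from the target
// environment attribute looked up via lookupTargetEnvOrDefault() in
// runOnOperation() below.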
13 
17 
18 #include <string>
19 
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Pass/Pass.h"
29 #include "llvm/ADT/StringExtras.h"
30 
31 namespace mlir {
32 namespace tosa {
33 #define GEN_PASS_DEF_TOSAVALIDATION
34 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
35 } // namespace tosa
36 } // namespace mlir
37 
38 using namespace mlir;
39 using namespace mlir::tosa;
40 
41 namespace {
42 
43 static LogicalResult
44 checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
45  for (const auto index : operandIndices) {
46  Attribute attr;
47  if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
48  return op->emitOpError("expected compile time resolvable constant, but "
49  "got variable value for operand #")
50  << index;
51  }
52  }
53  return success();
54 }
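// Illustration (hedged, based on the callers below): for tosa.mul the 'shift'
// operand is operand #2, so unless the dynamic extension is enabled, a shift
// that is not produced by a constant op such as tosa.const triggers
// "expected compile time resolvable constant, but got variable value for
// operand #2".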
55 
56 static LogicalResult checkConstantOperandMul(Operation *op,
57  const TargetEnv &env) {
58  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
59  // Check 'shift'
60  return checkConstantOperands(op, {2});
61  }
62  return success();
63 }
64 
65 static LogicalResult checkConstantOperandTable(Operation *op,
66  const TargetEnv &env) {
67  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
68  // Check 'table'
69  return checkConstantOperands(op, {1});
70  }
71  return success();
72 }
73 
74 static LogicalResult checkConstantOperandPad(Operation *op,
75  const TargetEnv &env) {
76  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
 77  // Assume this op is zero-padding if padConst is not present
78  if (!env.allows(Extension::dynamic) && padOp.getPadConst())
79  // Check 'pad_const'
80  // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
81  return checkConstantOperands(op, {2});
82  }
83  return success();
84 }
85 
86 static LogicalResult checkConstantOperandRescale(Operation *op,
87  const TargetEnv &env) {
88  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
89  // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
90  return checkConstantOperands(op, {1, 2, 3, 4});
91  }
92  return success();
93 }
94 
95 template <typename T>
96 static LogicalResult checkConstantOperandConvOps(Operation *op,
97  const TargetEnv &env) {
98  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
99  // Check 'input_zp' and 'weight_zp'
100  return checkConstantOperands(op, {3, 4});
101  }
102  return success();
103 }
104 
105 static LogicalResult checkConstantOperandMatMul(Operation *op,
106  const TargetEnv &env) {
107  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
108  // Check 'A_zp' and 'B_zp'
109  return checkConstantOperands(op, {2, 3});
110  }
111  return success();
112 }
113 
114 static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
115  const TargetEnv &env) {
116  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
117  // Check 'input_zp' and 'output_zp'
118  return checkConstantOperands(op, {1, 2});
119  }
120  return success();
121 }
122 
123 static LogicalResult checkConstantOperandNegate(Operation *op,
124  const TargetEnv &env) {
125  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
126  // Check 'input1_zp' and 'output_zp'
127  return checkConstantOperands(op, {1, 2});
128  }
129  return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // TOSA Validation Pass.
134 //===----------------------------------------------------------------------===//
135 
136 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
137 public:
138  explicit TosaValidation() { populateConstantOperandChecks(); }
139 
140  explicit TosaValidation(const TosaValidationOptions &options)
141  : TosaValidation() {
142  this->strictOpSpecAlignment = options.strictOpSpecAlignment;
143  this->allowInvalidOpDatatypeCombinations =
144  options.allowInvalidOpDatatypeCombinations;
145  }
146  void runOnOperation() final;
147 
148  LogicalResult applyConstantOperandCheck(Operation *op) {
149  for (auto &checker : constCheckers) {
150  if (failed(checker(op, targetEnv)))
151  return failure();
152  }
153  return success();
154  }
155 
156  LogicalResult applyLevelCheck(Operation *op);
157  LogicalResult applyAttributeCheck(Operation *op);
158 
159  // check variable read/write data types against variable declarations
160  LogicalResult applyVariableCheck(Operation *op);
161 
 162  // check ERROR_IF conditions from the specification
163  LogicalResult applyErrorIfCheck(Operation *op);
164 
165 private:
166  void populateConstantOperandChecks() {
167  constCheckers.emplace_back(checkConstantOperandMul);
168  constCheckers.emplace_back(checkConstantOperandTable);
169  constCheckers.emplace_back(checkConstantOperandPad);
170  constCheckers.emplace_back(checkConstantOperandRescale);
171  constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
172  constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
173  constCheckers.emplace_back(
174  checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
175  constCheckers.emplace_back(
176  checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
177  constCheckers.emplace_back(checkConstantOperandMatMul);
178  constCheckers.emplace_back(checkConstantOperandAvgPool2d);
179  constCheckers.emplace_back(checkConstantOperandNegate);
180  }
181 
182  LogicalResult levelCheckKernel(Operation *op, int32_t v,
183  const StringRef checkDesc) {
184  if (v > targetEnv.getLevel().MAX_KERNEL)
185  return op->emitOpError() << "failed level check: " << checkDesc;
186  return success();
187  }
188 
189  LogicalResult levelCheckStride(Operation *op, int32_t v,
190  const StringRef checkDesc) {
191  if (v > targetEnv.getLevel().MAX_STRIDE)
192  return op->emitOpError() << "failed level check: " << checkDesc;
193  return success();
194  }
195 
196  LogicalResult levelCheckScale(Operation *op, int32_t v,
197  const StringRef checkDesc) {
198  if (v > targetEnv.getLevel().MAX_SCALE)
199  return op->emitOpError() << "failed level check: " << checkDesc;
200  return success();
201  }
202 
203  LogicalResult levelCheckListSize(Operation *op, int32_t v,
204  const StringRef checkDesc) {
205  if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
206  return op->emitOpError()
207  << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
208  return success();
209  }
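  // For reference, the commonly used 8K level of the TOSA specification sets
  // (values quoted from the spec and treated here as an assumption, since the
  // active level comes from targetEnv): MAX_RANK=6, MAX_KERNEL=8192,
  // MAX_STRIDE=8192, MAX_SCALE=256, MAX_LOG2_SIZE=31, MAX_NESTING=6,
  // MAX_TENSOR_LIST_SIZE=64.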
210 
211  // Perform the Level Rank check on the tensor type.
212  LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
213  const StringRef operandOrResult,
214  int32_t highest_rank) {
215  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
216  if (!type.hasRank())
217  return op->emitOpError() << "failed level check: unranked tensor";
218  if (type.getRank() > highest_rank)
219  return op->emitOpError() << "failed level check: " << operandOrResult
220  << " rank(shape) <= MAX_RANK";
221  }
222  return success();
223  }
224 
225  // Perform the Level Rank check on the tensor value.
226  LogicalResult levelCheckRank(Operation *op, const Value &v,
227  const StringRef operandOrResult,
228  int32_t highest_rank) {
229  return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
230  }
231 
232  // Perform the Level tensor size check on the tensor type.
233  LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
234  const StringRef operandOrResult);
235 
236  // Perform the Level tensor size check on the tensor value.
237  LogicalResult levelCheckSize(Operation *op, const Value &v,
238  const StringRef operandOrResult) {
239  return levelCheckSize(op, v.getType(), operandOrResult);
240  }
241 
242  // Level check sizes of all operands and results of the operation.
243  template <typename T>
244  LogicalResult levelCheckSizes(T tosaOp) {
245  auto op = tosaOp.getOperation();
246  for (auto v : op->getOperands()) {
247  if (failed(levelCheckSize(op, v, "operand")))
248  return failure();
249  }
250 
251  for (auto v : op->getResults()) {
252  if (failed(levelCheckSize(op, v, "result")))
253  return failure();
254  }
255  return success();
256  }
257 
258  // Level check ranks of all operands, attribute and results of the operation.
259  template <typename T>
260  LogicalResult levelCheckRanks(T tosaOp) {
261  auto op = tosaOp.getOperation();
262  const TosaLevel tosaLevel = targetEnv.getLevel();
263  for (auto v : op->getOperands()) {
264  if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
265  return failure();
266  }
267 
268  for (auto v : op->getResults()) {
269  if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
270  return failure();
271  }
272  return success();
273  }
274 
275  // Level check ranks and sizes.
276  LogicalResult levelCheckRanksAndSizes(Operation *op);
277 
278  // Pool Op: level check kernel/stride/pad values
279  template <typename T>
280  LogicalResult levelCheckPool(Operation *op) {
281  if (auto poolOp = dyn_cast<T>(op)) {
282  for (auto k : poolOp.getKernel()) {
283  if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
284  return failure();
285  }
286  }
287  for (auto s : poolOp.getStride()) {
288  if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
289  return failure();
290  }
291  }
292  for (auto p : poolOp.getPad()) {
293  if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
294  return failure();
295  }
296  }
297  }
298  return success();
299  }
300 
301  // Conv Op: level check dilation/stride/pad values
302  template <typename T>
303  LogicalResult levelCheckConv(Operation *op) {
304  if (auto convOp = dyn_cast<T>(op)) {
305 
306  for (auto k : convOp.getDilation()) {
307  if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
308  return failure();
309  }
310  }
311  for (auto p : convOp.getPad()) {
312  if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
313  return failure();
314  }
315  }
316  for (auto s : convOp.getStride()) {
317  if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
318  return failure();
319  }
320  }
321  auto dilation = convOp.getDilation();
322  if (ShapedType weightType =
323  dyn_cast<ShapedType>(op->getOperand(1).getType())) {
324  auto shape = weightType.getShape();
325  if (isa<tosa::Conv2DOp>(op)) {
326  assert(shape.size() == 4);
327  assert(dilation.size() == 2);
 328  if (failed(levelCheckKernel(op, dilation[0] * shape[1],
 329  "dilation_y * KH <= MAX_KERNEL")) ||
 330  failed(levelCheckKernel(op, dilation[1] * shape[2],
 331  "dilation_x * KW <= MAX_KERNEL")))
332  return failure();
333  } else if (isa<tosa::Conv3DOp>(op)) {
334  assert(shape.size() == 5);
335  assert(dilation.size() == 3);
 336  if (failed(levelCheckKernel(op, dilation[0] * shape[1],
 337  "dilation_d * KD <= MAX_KERNEL")) ||
 338  failed(levelCheckKernel(op, dilation[1] * shape[2],
 339  "dilation_y * KH <= MAX_KERNEL")) ||
 340  failed(levelCheckKernel(op, dilation[2] * shape[3],
 341  "dilation_x * KW <= MAX_KERNEL")))
342  return failure();
343  } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
344  assert(shape.size() == 4);
345  assert(dilation.size() == 2);
 346  if (failed(levelCheckKernel(op, dilation[0] * shape[0],
 347  "dilation_y * KH <= MAX_KERNEL")) ||
 348  failed(levelCheckKernel(op, dilation[1] * shape[1],
 349  "dilation_x * KW <= MAX_KERNEL")))
350  return failure();
351  }
352  }
353  }
354  return success();
355  }
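  // Worked example (assuming the 8K level where MAX_KERNEL=8192): a
  // tosa.conv2d with dilation = [2, 2] and a weight of shape [OC, 5000, 3, IC]
  // fails "dilation_y * KH <= MAX_KERNEL" because 2 * 5000 = 10000 > 8192,
  // while dilation_x * KW = 2 * 3 = 6 passes.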
356 
357  // FFT op: level check H, W in input shape [N,H,W]
358  template <typename T>
359  LogicalResult levelCheckFFT(Operation *op) {
360  if (isa<T>(op)) {
361  for (auto v : op->getOperands()) {
362  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
363  auto shape = type.getShape();
364  assert(shape.size() == 3);
365  if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
366  failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
367  return failure();
368  }
369  }
370  }
371  }
372  return success();
373  }
374 
375  // TransposeConv2d op: level check kH/kW, outpad, and stride
376  LogicalResult levelCheckTransposeConv2d(Operation *op) {
377  if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
378  if (ShapedType filterType =
379  dyn_cast<ShapedType>(transpose.getWeight().getType())) {
380  auto shape = filterType.getShape();
381  assert(shape.size() == 4);
382  // level check kernel sizes for kH and KW
383  if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
384  failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
385  return failure();
386  }
387  }
388  for (auto p : transpose.getOutPad()) {
389  if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
390  return failure();
391  }
392  }
393  for (auto s : transpose.getStride()) {
394  if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
395  return failure();
396  }
397  }
398  }
399  return success();
400  }
401 
402  // Resize op: level check max scales
403  LogicalResult levelCheckResize(Operation *op) {
404  if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
405  SmallVector<int64_t> scale;
406  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
407  scale)) {
408  return failure();
409  }
410  const int64_t scaleYN = scale[0];
411  const int64_t scaleYD = scale[1];
412  const int64_t scaleXN = scale[2];
413  const int64_t scaleXD = scale[3];
414  if (failed(levelCheckScale(op, scaleYN / scaleYD,
415  "scale_y_n/scale_y_d <= MAX_SCALE")) ||
416  failed(levelCheckScale(op, scaleXN / scaleXD,
417  "scale_x_n/scale_x_d <= MAX_SCALE"))) {
418  return failure();
419  }
420  }
421  return success();
422  }
423 
424  // Recursively perform a bottom-up search to determine the maximum nesting
425  // depth, starting from a specific operation and continuing up to the function
426  // or module scope. Tosa nesting_depth starts at 0 and increments by one each
427  // time a new nested `region` is encountered.
428  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
429  if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
430  return;
431 
432  op = op->getParentOp();
433  if (!op)
434  return;
435 
436  depth++;
437  getMaxNestedDepth(op, depth);
438  }
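  // Worked example: for a tosa.cond_if sitting in the body region of a
  // tosa.while_loop that is directly inside a func.func, the recursion above
  // visits cond_if -> while_loop -> func.func and computes a depth of 2, which
  // levelCheckMaxNesting below compares against MAX_NESTING.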
439 
440  LogicalResult levelCheckMaxNesting(Operation *op) {
441  int32_t maxNestedDepth = 0;
442  getMaxNestedDepth(op, maxNestedDepth);
443 
444  if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
445  op->emitOpError() << "failed level check: " << maxNestedDepth
446  << " >= MAX_NESTING";
447  return failure();
448  }
449  return success();
450  }
451 
452  LogicalResult levelCheckListSize(Operation *op) {
453  if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454  return levelCheckListSize(op, concat.getInput1().size(), "input1");
455  }
456  if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
457  if (failed(levelCheckListSize(op, custom.getInputList().size(),
458  "input_list")) ||
459  failed(levelCheckListSize(op, custom.getOutputList().size(),
460  "output_list"))) {
461  return failure();
462  }
463  }
464  if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
465  if (failed(
466  levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
467  failed(levelCheckListSize(op, condIf.getOutputList().size(),
468  "outputs"))) {
469  return failure();
470  }
471  }
472  if (auto w = dyn_cast<tosa::WhileOp>(op)) {
473  if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
474  failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
475  return failure();
476  }
477  }
478  return success();
479  }
480 
481  LogicalResult attributeCheckRescale(Operation *op) {
482  if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
483  if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
484  !targetEnv.allows(Extension::doubleround)) {
485  op->emitOpError()
486  << "failed attribute check: rounding_mode = DOUBLE_ROUND "
487  << "requires extension [doubleround]";
488  return failure();
489  }
490  if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
491  !targetEnv.allows(Extension::inexactround)) {
492  op->emitOpError()
493  << "failed attribute check: rounding_mode = INEXACT_ROUND "
494  << "requires extension [inexactround]";
495  return failure();
496  }
497  }
498  return success();
499  }
500 
501  LogicalResult CheckVariable(Operation *op);
502  LogicalResult CheckVariableReadOrWrite(Operation *op);
503  bool isValidElementType(Type type, const bool allowUnsigned = false);
504 
505  SmallVector<
506  std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
 507  constCheckers;
 508  DenseMap<StringAttr, mlir::Type> variablesMap; // declared variable name -> type
 509  TosaProfileCompliance profileComp;
510  tosa::TargetEnv targetEnv;
511 };
512 
513 template <>
514 LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
515  auto *op = tosaOp.getOperation();
516  if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
517  targetEnv.getLevel().MAX_RANK)))
518  return failure();
519 
520  // rank(output) = rank(input) - 1
521  if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
522  targetEnv.getLevel().MAX_RANK - 1)))
523  return failure();
524 
525  return success();
526 }
527 
528 template <>
529 LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
530  auto *op = tosaOp.getOperation();
531 
532  // Only the condition input has rank limitation.
533  if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
534  targetEnv.getLevel().MAX_RANK)))
535  return failure();
536 
537  return success();
538 }
539 
540 template <>
541 LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
542  auto *op = tosaOp.getOperation();
543  auto variableType = getVariableType(tosaOp);
544  if (failed(levelCheckRank(op, variableType, "variable type",
545  targetEnv.getLevel().MAX_RANK)))
546  return failure();
547 
548  return success();
549 }
550 
551 template <>
552 LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
553  auto *op = tosaOp.getOperation();
554  auto variableType = getVariableType(tosaOp);
555  if (failed(levelCheckSize(op, variableType, "variable type")))
556  return failure();
557 
558  return success();
559 }
560 
561 LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
562 #define CHECK_RANKS_AND_SIZES(tosaOp) \
563  if (isa<tosa::tosaOp##Op>(op)) { \
564  if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
565  return failure(); \
566  if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
567  return failure(); \
568  }
569 
570 #define CHECK_SIZES(tosaOp) \
571  if (isa<tosa::tosaOp##Op>(op)) { \
572  if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op)))) \
573  return failure(); \
574  }
575 
576  // Tensor Operators
577  CHECK_RANKS_AND_SIZES(ArgMax);
578  // Activation Functions
579  CHECK_RANKS_AND_SIZES(Clamp);
581  CHECK_RANKS_AND_SIZES(Sigmoid);
582  CHECK_RANKS_AND_SIZES(Tanh);
583  // Elementwise Binary Operators
585  CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
586  CHECK_RANKS_AND_SIZES(BitwiseAnd);
587  CHECK_RANKS_AND_SIZES(BitwiseOr);
588  CHECK_RANKS_AND_SIZES(BitwiseXor);
589  CHECK_RANKS_AND_SIZES(IntDiv);
590  CHECK_RANKS_AND_SIZES(LogicalAnd);
591  CHECK_RANKS_AND_SIZES(LogicalLeftShift);
592  CHECK_RANKS_AND_SIZES(LogicalRightShift);
593  CHECK_RANKS_AND_SIZES(LogicalOr);
594  CHECK_RANKS_AND_SIZES(LogicalXor);
595  CHECK_RANKS_AND_SIZES(Maximum);
596  CHECK_RANKS_AND_SIZES(Minimum);
600  CHECK_RANKS_AND_SIZES(Table);
601  // Elementwise Unary Operators
603  CHECK_RANKS_AND_SIZES(BitwiseNot);
604  CHECK_RANKS_AND_SIZES(Ceil);
608  CHECK_RANKS_AND_SIZES(Floor);
610  CHECK_RANKS_AND_SIZES(LogicalNot);
611  CHECK_RANKS_AND_SIZES(Negate);
612  CHECK_RANKS_AND_SIZES(Reciprocal);
613  CHECK_RANKS_AND_SIZES(Rsqrt);
615  // Elementwise Ternary Operators
616  CHECK_RANKS_AND_SIZES(Select);
617  // Comparison Operators
618  CHECK_RANKS_AND_SIZES(Equal);
619  CHECK_RANKS_AND_SIZES(Greater);
620  CHECK_RANKS_AND_SIZES(GreaterEqual);
621  // Reduction Operators
622  CHECK_RANKS_AND_SIZES(ReduceAll);
623  CHECK_RANKS_AND_SIZES(ReduceAny);
624  CHECK_RANKS_AND_SIZES(ReduceMax);
625  CHECK_RANKS_AND_SIZES(ReduceMin);
626  CHECK_RANKS_AND_SIZES(ReduceProduct);
627  CHECK_RANKS_AND_SIZES(ReduceSum);
628  // Data Layout Operators
629  CHECK_RANKS_AND_SIZES(Concat);
631  CHECK_RANKS_AND_SIZES(Reshape);
632  CHECK_RANKS_AND_SIZES(Reverse);
633  CHECK_RANKS_AND_SIZES(Slice);
634  CHECK_RANKS_AND_SIZES(Tile);
635  CHECK_RANKS_AND_SIZES(Transpose);
636  // Type Conversion
637  CHECK_RANKS_AND_SIZES(Cast);
638  CHECK_RANKS_AND_SIZES(Rescale);
639  // Control Flow Operators
641  // Variable Operators
642  CHECK_RANKS_AND_SIZES(Variable);
643  CHECK_RANKS_AND_SIZES(VariableWrite);
644  CHECK_RANKS_AND_SIZES(VariableRead);
645  // Data Nodes
646  CHECK_RANKS_AND_SIZES(Const);
647  CHECK_RANKS_AND_SIZES(Identity);
648 
649  // For the following operators, check whether the size of each tensor
650  // operand is valid in a given Level.
651 
652  // Tensor Operators
653  CHECK_SIZES(AvgPool2d);
654  CHECK_SIZES(Conv2D);
655  CHECK_SIZES(Conv3D);
656  CHECK_SIZES(DepthwiseConv2D);
657  CHECK_SIZES(TransposeConv2D);
658  CHECK_SIZES(FFT2d);
659  CHECK_SIZES(MatMul);
660  CHECK_SIZES(MaxPool2d);
661  CHECK_SIZES(RFFT2d);
662  // Scatter/Gather Operators
664  CHECK_SIZES(Scatter);
665  // Image Operators
666  CHECK_SIZES(Resize);
667  // Custom Operators
668  CHECK_SIZES(Custom);
669  // Control Flow Operators
670  CHECK_SIZES(While);
671  // Shape Operators
672  CHECK_SIZES(ConstShape);
673 
674 #undef CHECK_RANKS_AND_SIZES
675 #undef CHECK_SIZES
676  return success();
677 }
678 
679 // Perform the Level tensor size check on the tensor type.
680 LogicalResult TosaValidation::levelCheckSize(Operation *op,
681  const Type &typeToCheck,
682  const StringRef operandOrResult) {
683  if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
684  if (!type.hasRank())
685  return op->emitOpError() << "failed level check: unranked tensor";
686  auto shape = type.getShape();
687  for (auto dim : shape) {
688  if (mlir::ShapedType::isDynamic(dim))
689  return op->emitOpError() << "failed level check: " << operandOrResult
690  << " shape dimension cannot be dynamic";
691  }
692 
693  int64_t element_bits = type.getElementTypeBitWidth();
694  int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
695  int64_t size = element_bytes * type.getNumElements();
696 
 697  // According to 1.11. Tensor Definitions of the TOSA spec, the value of
 698  // tensor_size_t is (1 << MAX_LOG2_SIZE) - 1, where MAX_LOG2_SIZE is
 699  // defined in 1.7. Levels.
700  // For each tensor, the number of tensor elements multiplied by the
701  // element size in bytes must be representable as a tensor_size_t.
702  const int64_t max_size =
703  (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
704  if (size > max_size)
705  return op->emitOpError()
706  << "failed level check: " << operandOrResult
 707  << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE) - 1";
708  }
709  return success();
710 }
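// Worked example for levelCheckSize above (assuming the 8K level where
// MAX_LOG2_SIZE=31): a static tensor<4096x4096xf32> has 16777216 elements of
// 4 bytes each, i.e. 67108864 bytes, which is below (1 << 31) - 1 = 2147483647
// and passes the check; any dynamic dimension is rejected before the size is
// computed.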
711 
712 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
713  if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
714  // no need to do level checks
715  return success();
716  }
717 
718  // check rank and sizes early so later checks can assume shaped operands
719  if (failed(levelCheckRanksAndSizes(op)))
720  return failure();
721 
722  // additional level checks from spec 0.70
723  if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
724  failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
725  failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
726  failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
727  failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
728  failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
729  failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
730  failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
731  return failure();
732  }
733 
734  // level check MAX_TENSOR_LIST_SIZE
735  if (failed(levelCheckListSize(op))) {
736  return failure();
737  }
738 
739  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
740  if (failed(levelCheckMaxNesting(op))) {
741  return failure();
742  }
743  }
744 
745  return success();
746 }
747 
748 LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
749  if (failed(attributeCheckRescale(op)))
750  return failure();
751  return success();
752 }
753 
754 inline bool CompatibleTypes(const mlir::Type &type,
755  const mlir::Type &declaredType) {
756  // for now, simply use type equality comparison
757  return type == declaredType;
758 }
759 
760 LogicalResult TosaValidation::CheckVariable(Operation *op) {
761  if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
762  mlir::StringAttr nameAttr = variableOp.getNameAttr();
763 
764  if (variablesMap.count(nameAttr))
765  return op->emitOpError() << "name has already been declared";
766 
767  auto elementType = variableOp.getType();
768  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
769  SmallVector<int64_t> shape = to_vector(varShapeAttr.getValues<int64_t>());
770  RankedTensorType variableType =
771  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
772 
773  variablesMap[nameAttr] = variableType;
774  }
775 
776  return success();
777 }
778 
779 LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
780  if (isa<mlir::tosa::VariableReadOp>(op) ||
781  isa<mlir::tosa::VariableWriteOp>(op)) {
782  mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
783  if (!variablesMap.count(nameAttr))
784  return op->emitOpError() << "name has not been declared";
785 
786  auto varType = variablesMap[nameAttr];
787 
788  for (auto v : op->getOperands()) {
789  auto type = v.getType();
790  if (!CompatibleTypes(type, varType))
791  return op->emitOpError() << "operand type does not equal variable type";
792  }
793 
794  for (auto v : op->getResults()) {
795  auto type = v.getType();
796  if (!CompatibleTypes(type, varType))
797  return op->emitOpError() << "result type does not equal variable type";
798  }
799  }
800 
801  return success();
802 }
803 
804 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
805  if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
806  return failure();
807  return success();
808 }
809 
810 LogicalResult checkErrorIfResize(Operation *op) {
811  auto resize = dyn_cast<tosa::ResizeOp>(op);
812  if (!resize)
813  return success();
814 
815  const Value input = resize.getInput();
816  const Value output = resize.getOutput();
817  const RankedTensorType inputType =
818  llvm::dyn_cast<RankedTensorType>(input.getType());
819  const RankedTensorType outputType =
820  llvm::dyn_cast<RankedTensorType>(output.getType());
821 
822  if (!inputType || !outputType)
823  return op->emitOpError("expect ranked input/output tensor");
824 
825  // Ensure the image size is supported by GPU APIs and that for integer
826  // implementations, position * stride does not overflow int32_t.
827  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
828  const SmallVector<int64_t, 4> sizes = {
829  outputType.getDimSize(1), outputType.getDimSize(2),
830  inputType.getDimSize(1), inputType.getDimSize(2)};
831  const int64_t *maxDim = llvm::max_element(sizes);
832  if (maxDim != sizes.end() && *maxDim >= 16384)
833  return op->emitOpError(
834  "expect input/output height/width dims to be < 16384, ")
835  << "got [OH, OW, IH, IW] = " << sizes;
836  }
837 
838  SmallVector<int64_t> scale;
839  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
840  return failure();
841 
842  const int64_t scaleYN = scale[0];
843  const int64_t scaleYD = scale[1];
844  const int64_t scaleXN = scale[2];
845  const int64_t scaleXD = scale[3];
846 
847  // Ensure scale values don't overflow int32 accumulator
848  if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
849  return op->emitOpError(
850  "expect all scale numerator values to be <= (1 << 11), "
851  "got scale_y_n=")
852  << scaleYN << ", scale_x_n=" << scaleXN;
853 
854  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
855  return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
856  << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
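  // Worked example for the two checks above (values are illustrative only):
  // scale = [1, 17, 1, 1] is rejected because scale_y_d = 17 >= 16 *
  // scale_y_n = 16, i.e. a downscale of more than 1/16, while
  // scale = [4096, 1, 4096, 1] is rejected earlier because 4096 > (1 << 11) = 2048.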
857 
858  SmallVector<int64_t> offset;
859  SmallVector<int64_t> border;
860  if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
861  !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
862  return failure();
863 
864  const int64_t offsetY = offset[0];
865  const int64_t offsetX = offset[1];
866  // Set a consistent lower limit of 1/16 downscale to simplify
867  // implementations
868  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
869  return op->emitOpError(
870  "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
871  << offsetY << "/" << scaleYN;
872  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
873  return op->emitOpError(
874  "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
875  << offsetX << "/" << scaleXN;
876 
877  const int64_t borderY = border[0];
878  const int64_t borderX = border[1];
879  if (borderY < -16 * scaleYN || borderY >= scaleYN)
880  return op->emitOpError(
881  "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
882  << borderY << "/" << scaleYN;
883  if (borderX < -16 * scaleXN || borderX >= scaleXN)
884  return op->emitOpError(
885  "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
886  << borderX << "/" << scaleXN;
887 
 888  // The following section of code is mostly duplicated from ResizeOp::verify().
 889  //
 890  // The TOSA specification does not support broadcast behavior. However,
 891  // there is a rewrite pattern that materializes a broadcast ResizeOp,
 892  // turning an invalid TOSA ResizeOp into a valid one. To avoid breaking
 893  // existing code, we keep that rewrite pattern untouched, so the checking in
 894  // ResizeOp::verify() has to stay loose enough to accept broadcast ResizeOp.
 895  //
 896  // The check below is the strict one that conforms to the TOSA specification.
 897  // FIXME: Remove the duplicated checks once broadcast ResizeOp is removed.
898  auto idivCheck = [](const int64_t lhs,
899  const int64_t rhs) -> std::optional<int64_t> {
900  if (lhs % rhs != 0)
901  return std::nullopt;
902  return lhs / rhs;
903  };
904 
905  const int64_t oh = outputType.getDimSize(1);
906  const int64_t ow = outputType.getDimSize(2);
907  const int64_t ih = inputType.getDimSize(1);
908  const int64_t iw = inputType.getDimSize(2);
909 
910  if (ih != ShapedType::kDynamic) {
911  const std::optional<int64_t> calculatedOutHeightMinusOne =
912  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
913  if (!calculatedOutHeightMinusOne.has_value())
914  return op->emitOpError(
915  "expected (input_height - 1) * scale_y_n - offset_y + "
916  "border_y ")
917  << "to be wholly divisible by scale_y_d, got ((" << ih
918  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
919  << ") / " << scaleYD;
920  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
921  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
922  return op->emitOpError(
923  "calculated output height did not match expected: ")
924  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
925  }
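  // Worked example for the height formula above: with input_height = 8,
  // scale_y = 4/2, offset_y = 0 and border_y = 0, the numerator is
  // (8 - 1) * 4 - 0 + 0 = 28, which divides evenly by scale_y_d = 2, giving
  // 14, so the expected output height is 14 + 1 = 15.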
926 
927  if (iw != ShapedType::kDynamic) {
928  const std::optional<int64_t> calculatedOutWidthMinusOne =
929  idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
930  if (!calculatedOutWidthMinusOne.has_value())
931  return op->emitOpError(
932  "expected (input_width - 1) * scale_x_n - offset_x + "
933  "border_x ")
934  << "to be wholly divisible by scale_x_d, got ((" << iw
935  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
936  << ") / " << scaleXD;
937  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
938  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
939  return op->emitOpError("calculated output width did not match expected: ")
940  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
941  }
942 
943  return success();
944 }
945 
946 LogicalResult checkErrorIfMul(Operation *op) {
947  auto mul = dyn_cast<tosa::MulOp>(op);
948  if (!mul)
949  return success();
950 
951  // REQUIRE(0 <= shift && shift <= 63);
952  // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
953  ElementsAttr shift_elem;
954  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
955  return success();
956  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
957  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
958  if (inputElemType.isInteger(32)) {
959  // 0 <= shift <= 63 for int32_t type
960  if (shift < 0 || shift > 63)
961  return op->emitOpError()
962  << "requires 0 <= shift && shift <= 63, but got: " << shift;
963  } else {
964  // shift must be 0 for all other types
965  if (shift != 0)
966  return op->emitOpError()
967  << "requires shift = 0 for all input data types that "
968  "are not int32_t, but got: "
969  << shift;
970  }
971 
972  return success();
973 }
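// Illustration of checkErrorIfMul above: a tosa.mul on i32 tensors may carry
// any constant shift in [0, 63]; a tosa.mul on f32 tensors with a constant
// shift of 1 is rejected; a non-constant shift is skipped here and left to the
// constant-operand check.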
974 
975 LogicalResult checkErrorIfTable(Operation *op) {
976  auto table = dyn_cast<tosa::TableOp>(op);
977  if (!table)
978  return success();
979 
980  // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
981  const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
982  const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
983 
984  const ShapeAdaptor tableShape(table.getTable().getType());
985  if (tableShape.hasStaticShape()) {
986  const auto numElements = tableShape.getNumElements();
987  if (numElements != tableSize)
988  return op->emitOpError() << "requires table size of " << tableSize
989  << ", got " << numElements;
990  }
991 
992  return success();
993 }
994 
995 LogicalResult checkErrorIfRescale(Operation *op) {
996  auto rescale = dyn_cast<tosa::RescaleOp>(op);
997  if (!rescale)
998  return success();
999 
1000  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1001  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1002  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1003  !outputType.getElementType().isInteger())
1004  return success();
1005 
1006  auto inElemType = inputType.getElementType();
1007  auto outElemType = outputType.getElementType();
1008  auto inWidth = inElemType.getIntOrFloatBitWidth();
1009  auto outWidth = outElemType.getIntOrFloatBitWidth();
1010 
1011  bool inputUnsigned = rescale.getInputUnsigned();
1012  bool outputUnsigned = rescale.getOutputUnsigned();
1013 
1014  bool scale32 = rescale.getScale32();
1015  auto roundingMode = rescale.getRoundingMode();
1016 
1017  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1018  if (scale32 && inWidth == 48)
1019  return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1020 
1021  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1022  if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
1023  return op->emitOpError()
1024  << "DOUBLE_ROUND is only allowed with scale32=true.";
1025 
1026  // ERROR_IF(input_unsigned && output_unsigned)
1027  if (inputUnsigned && outputUnsigned)
1028  return op->emitOpError() << "input and output cannot be both unsigned.";
1029 
1030  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1031  if (outWidth == 32 && inputUnsigned)
1032  return op->emitOpError()
1033  << "i32 output type is not allowed with unsigned input.";
1034 
1035  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1036  if (inWidth == 32 && outputUnsigned)
1037  return op->emitOpError()
1038  << "i32 input type is not allowed with unsigned output.";
1039 
1040  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1041  if (inWidth == 48 && outputUnsigned)
1042  return op->emitOpError()
1043  << "i48 input type is not allowed with unsigned output.";
1044 
1045  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1046  if (inWidth == 48 && inputUnsigned)
1047  return op->emitOpError() << "i48 input type cannot be unsigned.";
1048 
1049  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1050  if (inWidth == 32 && inputUnsigned)
1051  return op->emitOpError() << "i32 input type cannot be unsigned.";
1052 
1053  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1054  if (outWidth == 32 && outputUnsigned)
1055  return op->emitOpError() << "i32 output type cannot be unsigned.";
1056 
1057  return success();
1058 }
1059 
1060 LogicalResult checkErrorIfPad(Operation *op) {
1061  auto pad = dyn_cast<tosa::PadOp>(op);
1062  if (!pad)
1063  return success();
1064 
1065  DenseIntElementsAttr paddingAttr;
1066  if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1067  // Pad verifier will catch this
1068  return success();
1069 
1070  for (const APInt &val : paddingAttr.getValues<APInt>()) {
1071  if (val.getSExtValue() < 0)
 1072  return op->emitOpError() << "padding values must all be non-negative, got "
1073  << val.getSExtValue();
1074  }
1075 
1076  return success();
1077 }
1078 
1079 static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
1080  return llvm::all_of(op->getOperands(), [&](auto operand) {
1081  Region *operandRegion = operand.getParentRegion();
1082  return operandRegion && region->isAncestor(operandRegion);
1083  });
1084 }
1085 
1086 static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
1087  bool noLiveInValue = true;
1088  regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
1089  if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
1090  noLiveInValue = false;
1091  return WalkResult::interrupt();
1092  }
1093  return WalkResult::advance();
1094  });
1095  return noLiveInValue ? success() : failure();
1096 }
1097 
1098 LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
1099  StringRef regionName) {
1100  if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
1101  return success();
 1102  return op->emitOpError()
 1103  << "is not conformant to the TOSA specification. It requires that the '"
 1104  << regionName << "' region be isolated from above.\n";
1105 }
1106 
1107 LogicalResult checkErrorIfCondIf(Operation *op) {
1108  auto ifOp = dyn_cast<tosa::IfOp>(op);
1109  if (!ifOp)
1110  return success();
1111 
1112  // Currently the dialect supports declaring cond_if operations that
1113  // have then/else regions that reference values from outside these
1114  // regions. According to the specification, all values used by the
1115  // then/else regions must be explicitly declared within the regions.
1116  // Therefore we must check that the then/else regions are
1117  // "isolated from above", in order to be conformant to the
1118  // specification.
1119  //
1120  // Note: the dialect currently supports two styles of syntax for
1121  // declaring "cond_if" operations. We'll refer to these as follows:
1122  //
1123  // Generic:
1124  // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
1125  // ^bb0(%arg3, %arg4):
1126  // tosa.yield %arg3
1127  // }, {
1128  // ^bb0(%arg3, %arg4):
1129  // tosa.yield %arg4
1130  // })
1131  //
1132  // Simplified:
1133  // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
1134  // ^bb0(%arg3, %arg4):
1135  // tosa.yield %arg3
1136  // } else {
1137  // ^bb0(%arg3, %arg4):
1138  // tosa.yield %arg4
1139  // }
1140 
1141  if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
1142  failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
1143  return failure();
1144  return success();
1145 }
1146 
1147 LogicalResult checkErrorIfWhileLoop(Operation *op) {
1148  auto whileOp = dyn_cast<tosa::WhileOp>(op);
1149  if (!whileOp)
1150  return success();
1151 
1152  if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
1153  failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
1154  return failure();
1155  return success();
1156 }
1157 
1158 LogicalResult checkErrorIfScatter(Operation *op) {
1159  auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1160  if (!scatterOp)
1161  return success();
1162 
1163  // for constant indices, check that there are no duplicate values
1164  DenseIntElementsAttr indicesAttr;
1165  if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1166  return success();
1167 
1168  auto const indicesType =
1169  dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1170  if (!indicesType || !indicesType.hasRank()) {
1171  op->emitOpError("expect ranked indices tensor");
1172  return failure();
1173  }
1174 
1175  if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1176  op->emitOpError("indices values contain duplicates");
1177  return failure();
1178  }
1179 
1180  return success();
1181 }
1182 
1183 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1184  if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
1185  failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
1186  failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
1187  failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
1188  return failure();
1189  return success();
1190 }
1191 
1192 bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
1193  if (isa<FloatType>(type)) {
1194  return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1195  Float8E5M2Type>(type);
1196  }
1197  if (auto intTy = dyn_cast<IntegerType>(type)) {
1198  if (intTy.isSignless()) {
1199  switch (intTy.getWidth()) {
1200  case 1:
1201  case 4:
1202  case 8:
1203  case 16:
1204  case 32:
1205  case 48:
1206  return true;
1207  }
1208  } else if (allowUnsigned && intTy.isUnsigned()) {
1209  switch (intTy.getWidth()) {
1210  case 8:
1211  case 16:
1212  case 32:
1213  return true;
1214  }
1215  }
1216  } else if (mlir::isa<tosa::shapeType>(type)) {
1217  return true;
1218  }
1219  return false;
1220 }
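// Examples of the policy implemented by isValidElementType above: f32, bf16,
// f8E4M3FN, i8, i16, i48, and !tosa.shape types are accepted; f64 and i64 are
// rejected; ui8/ui16/ui32 are accepted only when allowUnsigned is true, which
// runOnOperation() below sets for tosa.rescale when strictOpSpecAlignment is
// off.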
1221 
1222 void TosaValidation::runOnOperation() {
1223  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
1224  if (!tosaDialect)
1225  return;
1226 
1227  targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
1228 
1229  getOperation().walk([&](Operation *op) {
1230  if (op->getDialect() != tosaDialect)
1231  return;
1232 
1233  // validate operator element types:
1234  // - rescale operator is allowed to have ui8/ui16/ui32
1235  // operands/results when strictOpSpecAlignment is false
 1236  // - perform the valid element type check at the beginning to
 1237  // protect the rest of the code against quantized element types
1238  const bool allowUnsigned =
1239  !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
1240  for (Value operand : op->getOperands()) {
1241  auto elementTy = getElementTypeOrSelf(operand);
1242  if (!isValidElementType(elementTy, allowUnsigned)) {
1243  op->emitOpError() << "is not profile-aligned: element type "
1244  << elementTy << " is not legal";
1245  return signalPassFailure();
1246  }
1247  }
1248  for (Type resultTy : op->getResultTypes()) {
1249  auto elementTy = getElementTypeOrSelf(resultTy);
1250  if (!isValidElementType(elementTy, allowUnsigned)) {
1251  op->emitOpError() << "is not profile-aligned: element type "
1252  << elementTy << " is not legal";
1253  return signalPassFailure();
1254  }
1255  }
1256 
1257  if (strictOpSpecAlignment &&
1258  failed(profileComp.checkProfile(op, targetEnv)))
1259  return signalPassFailure();
1260 
1261  if (strictOpSpecAlignment &&
1262  failed(profileComp.checkExtension(op, targetEnv)))
1263  return signalPassFailure();
1264 
1265  if (!allowInvalidOpDatatypeCombinations &&
1266  failed(profileComp.checkInvalid(op)))
1267  return signalPassFailure();
1268 
1269  // Some uses of TOSA rely on the constant operands of particular
1270  // operations.
1271  if (failed(applyConstantOperandCheck(op)))
1272  signalPassFailure();
1273 
1274  // do level checks
1275  if (failed(applyLevelCheck(op)))
1276  signalPassFailure();
1277 
1278  // check additional attribute restrictions
1279  if (failed(applyAttributeCheck(op)))
1280  signalPassFailure();
1281 
1282  // do variable type checks
1283  if (failed(applyVariableCheck(op)))
1284  signalPassFailure();
1285 
1286  // do error if checks
1287  if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
1288  signalPassFailure();
1289  });
1290 }
1291 } // namespace