MLIR  21.0.0git
TosaValidation.cpp
Go to the documentation of this file.
1 //===- TosaValidation.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Validate if TOSA dialect input matchs with the specification for given
10 // requirements.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
18 
19 #include <string>
20 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Pass/Pass.h"
30 #include "llvm/ADT/StringExtras.h"
31 
32 namespace mlir {
33 namespace tosa {
34 #define GEN_PASS_DEF_TOSAVALIDATION
35 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
36 } // namespace tosa
37 } // namespace mlir
38 
39 using namespace mlir;
40 using namespace mlir::tosa;
41 
42 namespace {
43 
44 static LogicalResult
45 checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
46  for (const auto index : operandIndices) {
47  Attribute attr;
48  if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
49  return op->emitOpError("expected compile time resolvable constant, but "
50  "got variable value for operand #")
51  << index;
52  }
53  }
54  return success();
55 }
56 
57 static LogicalResult checkConstantOperandMul(Operation *op,
58  const TargetEnv &env) {
59  if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
60  // Check 'shift'
61  return checkConstantOperands(op, {2});
62  }
63  return success();
64 }
65 
66 static LogicalResult checkConstantOperandTable(Operation *op,
67  const TargetEnv &env) {
68  if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
69  // Check 'table'
70  return checkConstantOperands(op, {1});
71  }
72  return success();
73 }
74 
75 static LogicalResult checkConstantOperandPad(Operation *op,
76  const TargetEnv &env) {
77  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
78  // Assume this op is zero-padding if padConst is not presented
79  if (!env.allows(Extension::dynamic) && padOp.getPadConst())
80  // Check 'pad_const'
81  // Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
82  return checkConstantOperands(op, {2});
83  }
84  return success();
85 }
86 
87 static LogicalResult checkConstantOperandRescale(Operation *op,
88  const TargetEnv &env) {
89  if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
90  // Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
91  return checkConstantOperands(op, {1, 2, 3, 4});
92  }
93  return success();
94 }
95 
96 template <typename T>
97 static LogicalResult checkConstantOperandConvOps(Operation *op,
98  const TargetEnv &env) {
99  if (!env.allows(Extension::dynamic) && isa<T>(op)) {
100  // Check 'input_zp' and 'weight_zp'
101  return checkConstantOperands(op, {3, 4});
102  }
103  return success();
104 }
105 
106 static LogicalResult checkConstantOperandMatMul(Operation *op,
107  const TargetEnv &env) {
108  if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
109  // Check 'A_zp' and 'B_zp'
110  return checkConstantOperands(op, {2, 3});
111  }
112  return success();
113 }
114 
115 static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
116  const TargetEnv &env) {
117  if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
118  // Check 'input_zp' and 'output_zp'
119  return checkConstantOperands(op, {1, 2});
120  }
121  return success();
122 }
123 
124 static LogicalResult checkConstantOperandNegate(Operation *op,
125  const TargetEnv &env) {
126  if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
127  // Check 'input1_zp' and 'output_zp'
128  return checkConstantOperands(op, {1, 2});
129  }
130  return success();
131 }
132 
// Limits prescribed by a TOSA "level" — see TOSA spec section 1.7 (Levels).
// A level bounds the operation parameters a conforming implementation must
// support; a value of 0 in a default-constructed instance means "unset".
struct TosaLevel {
  int32_t MAX_RANK = 0;             // max tensor rank
  int32_t MAX_KERNEL = 0;           // max kernel/pad/dilation extent
  int32_t MAX_STRIDE = 0;           // max stride value
  int32_t MAX_SCALE = 0;            // max resize scale ratio
  int32_t MAX_LOG2_SIZE = 0;        // log2 bound on tensor size in bytes
  int32_t MAX_NESTING = 0;          // max control-flow nesting depth
  int32_t MAX_TENSOR_LIST_SIZE = 0; // max operand/result list length

  // Member-wise equality. Declared const so it can also be invoked on
  // const/constexpr instances (the previous non-const overload could not).
  bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
           MAX_NESTING == rhs.MAX_NESTING &&
           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
  }
};

// Limits for the "8K" level.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
// "none" level: effectively unconstrained; used to disable level checking.
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
                                              63, 256, 256};
154 
155 //===----------------------------------------------------------------------===//
156 // TOSA Validation Pass.
157 //===----------------------------------------------------------------------===//
158 
/// Pass that validates TOSA dialect IR against the specification for a given
/// target environment (profiles/extensions) and level. Holds the registered
/// constant-operand checkers and the resolved level/profile state.
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }

  /// Construct from pass options, forwarding each option onto the generated
  /// pass-option members.
  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->extension = options.extension;
    this->strictOpSpecAlignment = options.strictOpSpecAlignment;
    this->allowInvalidOpDatatypeCombinations =
        options.allowInvalidOpDatatypeCombinations;
    this->level = options.level;
  }
  void runOnOperation() final;

  /// Run every registered constant-operand checker on `op`; fails on the
  /// first checker that reports failure.
  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op, targetEnv)))
        return failure();
    }
    return success();
  }

  LogicalResult applyLevelCheck(Operation *op);
  LogicalResult applyAttributeCheck(Operation *op);

  // check variable read/write data types against variable declarations
  LogicalResult applyVariableCheck(Operation *op);

  // check error if conditions
  LogicalResult applyErrorIfCheck(Operation *op);

private:
  /// Register the per-op compile-time-constant operand checks run by
  /// applyConstantOperandCheck().
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandMul);
    constCheckers.emplace_back(checkConstantOperandTable);
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandRescale);
    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
    constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
    constCheckers.emplace_back(
        checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
    constCheckers.emplace_back(
        checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
    constCheckers.emplace_back(checkConstantOperandMatMul);
    constCheckers.emplace_back(checkConstantOperandAvgPool2d);
    constCheckers.emplace_back(checkConstantOperandNegate);
  }

  // Emit an error and return false when `v` exceeds the level's MAX_KERNEL.
  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_KERNEL) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  // Emit an error and return false when `v` exceeds the level's MAX_STRIDE.
  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_STRIDE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  // Emit an error and return false when `v` exceeds the level's MAX_SCALE.
  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_SCALE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  // Emit an error and return false when a tensor-list length `v` exceeds the
  // level's MAX_TENSOR_LIST_SIZE.
  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
                        << checkDesc;
      return false;
    }
    return true;
  }

  // Check that `v` (a Value or an ElementsAttr) has a ranked type whose rank
  // does not exceed `highest_rank`. Unranked shaped types always fail;
  // non-shaped types pass trivially.
  template <typename T>
  bool levelCheckRank(Operation *op, const T &v,
                      const StringRef operandOrResult, int32_t highest_rank) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > highest_rank) {
        op->emitOpError() << "failed level check: " << operandOrResult
                          << " rank(shape) <= MAX_RANK";
        return false;
      }
    }
    return true;
  }

  // Perform the Level tensor size check on the input tensor.
  bool levelCheckSize(Operation *op, const Value &v,
                      const StringRef operandOrResult);

  // Level check sizes of all operands and results of the operation.
  template <typename T>
  bool levelCheckSizes(T tosaOp) {
    auto op = tosaOp.getOperation();
    for (auto v : op->getOperands()) {
      if (!levelCheckSize(op, v, "operand"))
        return false;
    }

    for (auto v : op->getResults()) {
      if (!levelCheckSize(op, v, "result"))
        return false;
    }
    return true;
  }

  // Level check ranks of all operands, attribute and results of the operation.
  // Specialized below for ArgMaxOp and IfOp, which have different rank rules.
  template <typename T>
  bool levelCheckRanks(T tosaOp) {
    auto op = tosaOp.getOperation();
    for (auto v : op->getOperands()) {
      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
        return false;
    }

    // Elements attributes (e.g. constant payloads) are rank-checked too.
    if (!op->getAttrs().empty()) {
      for (NamedAttribute attr : op->getAttrs()) {
        if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
          if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
            return false;
        }
      }
    }

    for (auto v : op->getResults()) {
      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
        return false;
    }
    return true;
  }

  // Level check ranks and sizes.
  bool levelCheckRanksAndSizes(Operation *op);

  // Pool Op: level check kernel/stride/pad values
  template <typename T>
  bool levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : poolOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      // Pad values share the MAX_KERNEL bound per the spec levels table.
      for (auto p : poolOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
    }
    return true;
  }

  // Conv Op: level check dilation/stride/pad values
  template <typename T>
  bool levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {

      for (auto k : convOp.getDilation()) {
        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : convOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : convOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      // The effective (dilated) kernel extent must also respect MAX_KERNEL.
      // Kernel spatial dims live at different weight-shape indices per op.
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(op->getOperand(1).getType())) {
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_y * KH <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_x * KW <= MAX_KERNEL)"))
            return false;
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_d * KD <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_y * KH <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[2] * shape[3],
                                "dilation_x * KW <= MAX_KERNEL)"))
            return false;
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[0],
                                "dilation_y * KH <= MAX_KERNEL)") ||
              !levelCheckKernel(op, dilation[1] * shape[1],
                                "dilation_x * KW <= MAX_KERNEL)"))
            return false;
        }
      }
    }
    return true;
  }

  // FFT op: level check H, W in input shape [N,H,W]
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
            return false;
          }
        }
      }
    }
    return true;
  }

  // TransposeConv2d op: level check kH/kW, outpad, and stride
  bool levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getWeight().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // level check kernel sizes for kH and KW
        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : transpose.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
    }
    return true;
  }

  // Resize op: level check max scales
  bool levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      // The scale operand must resolve to constant shape values; if it does
      // not, the check fails without emitting a level diagnostic here.
      SmallVector<int64_t> scale;
      if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                     scale)) {
        return false;
      }
      const int64_t scaleYN = scale[0];
      const int64_t scaleYD = scale[1];
      const int64_t scaleXN = scale[2];
      const int64_t scaleXD = scale[3];
      if (!levelCheckScale(op, scaleYN / scaleYD,
                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
          !levelCheckScale(op, scaleXN / scaleXD,
                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
        return false;
      }
    }
    return true;
  }

  // Recursively perform a bottom-up search to determine the maximum nesting
  // depth, starting from a specific operation and continuing up to the function
  // or module scope. Tosa nesting_depth starts at 0 and increments by one each
  // time a new nested `region` is encountered.
  static void getMaxNestedDepth(Operation *op, int32_t &depth) {
    if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
      return;

    op = op->getParentOp();
    if (!op)
      return;

    depth++;
    getMaxNestedDepth(op, depth);
    return;
  }

  // Check the nesting depth of `op` against the level's MAX_NESTING bound.
  bool levelCheckMaxNesting(Operation *op) {
    int32_t maxNestedDepth = 0;
    getMaxNestedDepth(op, maxNestedDepth);

    if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
      op->emitOpError() << "failed level check: " << maxNestedDepth
                        << " >= MAX_NESTING";
      return false;
    }
    return true;
  }

  // Check operand/result list lengths of list-carrying ops against
  // MAX_TENSOR_LIST_SIZE (delegates to the (op, size, desc) overload).
  bool levelCheckListSize(Operation *op) {
    if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
      return levelCheckListSize(op, concat.getInput1().size(), "input1");
    }
    if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
          !levelCheckListSize(op, custom.getOutputList().size(),
                              "output_list")) {
        return false;
      }
    }
    if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
        return false;
      }
    }
    if (auto w = dyn_cast<tosa::WhileOp>(op)) {
      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
        return false;
      }
    }
    return true;
  }

  // tosa.rescale: certain rounding modes require their matching extension to
  // be enabled in the target environment.
  bool attributeCheckRescale(Operation *op) {
    if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
      if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
          !targetEnv.allows(Extension::doubleround)) {
        op->emitOpError()
            << "failed attribute check: rounding_mode = DOUBLE_ROUND "
            << "requires extension [doubleround]";
        return false;
      } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
                 !targetEnv.allows(Extension::inexactround)) {
        op->emitOpError()
            << "failed attribute check: rounding_mode = INEXACT_ROUND "
            << "requires extension [inexactround]";
        return false;
      }
    }
    return true;
  }

  // configure profile and level values from pass options profileName and
  // levelName
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }

    if (!profile.empty()) {
      for (std::string &prof : profile) {
        auto profSymbol = symbolizeProfile(prof);
        if (profSymbol) {
          targetEnv.addProfile(profSymbol.value());
        } else {
          llvm::errs() << "unknown TOSA profile name passed in: " << prof
                       << ", supported profiles are `pro_int` and `pro_fp`\n";
          return signalPassFailure();
        }
      }
    }

    if (!extension.empty()) {
      for (std::string &ext : extension) {
        auto extSymbol = symbolizeExtension(ext);
        if (extSymbol) {
          targetEnv.addExtension(extSymbol.value());
        } else {
          llvm::errs() << "unknown TOSA extension name passed in: " << ext
                       << ", supported extension are int16, int4, bf16, "
                       << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
                       << "doubleround, inexactround and dynamic\n";
          return signalPassFailure();
        }
      }
    }
  }

  bool CheckVariable(Operation *op);
  bool CheckVariableReadOrWrite(Operation *op);
  bool isValidElementType(Type type);

  // Registered constant-operand checkers, run in registration order.
  SmallVector<
      std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
      constCheckers;
  // Active level limits, resolved by configLevelAndProfile().
  TosaLevel tosaLevel;
  // NOTE(review): a variable name -> declared type map (`variablesMap`) is
  // used by CheckVariable/CheckVariableReadOrWrite below; its declaration is
  // not visible in this chunk — confirm it is declared here in the original.
  TosaProfileCompliance profileComp;
  tosa::TargetEnv targetEnv;
};
575 
// ArgMax: the result drops the reduced axis, so its rank bound is one less
// than the operand's MAX_RANK bound.
template <>
bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
  auto op = tosaOp.getOperation();
  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
    return false;

  // rank(output) = rank(input) - 1
  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
    return false;

  return true;
}
588 
// cond_if: the other inputs/outputs are region-carried and not rank-limited
// here; only the condition input is checked.
template <>
bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
  auto op = tosaOp.getOperation();

  // Only the condition input has rank limitation.
  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
    return false;

  return true;
}
599 
// Dispatch table: for each TOSA op kind, run the rank and/or size level
// checks. Ops listed under CHECK_RANKS_AND_SIZES get both; ops listed under
// CHECK_SIZES get only the per-tensor size check (their rank constraints are
// enforced elsewhere, e.g. by op verifiers or the checks in applyLevelCheck).
bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
#define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op)))                          \
      return false;                                                            \
    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
      return false;                                                            \
  }

#define CHECK_SIZES(tosaOp)                                                    \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
      return false;                                                            \
  }

  // Tensor Operators
  CHECK_RANKS_AND_SIZES(ArgMax);
  // Activation Functions
  CHECK_RANKS_AND_SIZES(Clamp);
  CHECK_RANKS_AND_SIZES(Sigmoid);
  CHECK_RANKS_AND_SIZES(Tanh);
  // Elementwise Binary Operators
  CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
  CHECK_RANKS_AND_SIZES(BitwiseAnd);
  CHECK_RANKS_AND_SIZES(BitwiseOr);
  CHECK_RANKS_AND_SIZES(BitwiseXor);
  CHECK_RANKS_AND_SIZES(IntDiv);
  CHECK_RANKS_AND_SIZES(LogicalAnd);
  CHECK_RANKS_AND_SIZES(LogicalLeftShift);
  CHECK_RANKS_AND_SIZES(LogicalRightShift);
  CHECK_RANKS_AND_SIZES(LogicalOr);
  CHECK_RANKS_AND_SIZES(LogicalXor);
  CHECK_RANKS_AND_SIZES(Maximum);
  CHECK_RANKS_AND_SIZES(Minimum);
  CHECK_RANKS_AND_SIZES(Table);
  // Elementwise Unary Operators
  CHECK_RANKS_AND_SIZES(BitwiseNot);
  CHECK_RANKS_AND_SIZES(Ceil);
  CHECK_RANKS_AND_SIZES(Floor);
  CHECK_RANKS_AND_SIZES(LogicalNot);
  CHECK_RANKS_AND_SIZES(Negate);
  CHECK_RANKS_AND_SIZES(Reciprocal);
  CHECK_RANKS_AND_SIZES(Rsqrt);
  // Elementwise Ternary Operators
  CHECK_RANKS_AND_SIZES(Select);
  // Comparison Operators
  CHECK_RANKS_AND_SIZES(Equal);
  CHECK_RANKS_AND_SIZES(Greater);
  CHECK_RANKS_AND_SIZES(GreaterEqual);
  // Reduction Operators
  CHECK_RANKS_AND_SIZES(ReduceAll);
  CHECK_RANKS_AND_SIZES(ReduceAny);
  CHECK_RANKS_AND_SIZES(ReduceMax);
  CHECK_RANKS_AND_SIZES(ReduceMin);
  CHECK_RANKS_AND_SIZES(ReduceProduct);
  CHECK_RANKS_AND_SIZES(ReduceSum);
  // Data Layout Operators
  CHECK_RANKS_AND_SIZES(Concat);
  CHECK_RANKS_AND_SIZES(Reshape);
  CHECK_RANKS_AND_SIZES(Reverse);
  CHECK_RANKS_AND_SIZES(Slice);
  CHECK_RANKS_AND_SIZES(Tile);
  CHECK_RANKS_AND_SIZES(Transpose);
  // Type Conversion
  CHECK_RANKS_AND_SIZES(Cast);
  CHECK_RANKS_AND_SIZES(Rescale);
  // Control Flow Operators
  // Variable Operators
  CHECK_RANKS_AND_SIZES(Variable);
  CHECK_RANKS_AND_SIZES(VariableWrite);
  CHECK_RANKS_AND_SIZES(VariableRead);
  // Data Nodes
  CHECK_RANKS_AND_SIZES(Const);
  CHECK_RANKS_AND_SIZES(Identity);

  // For the following operators, check whether the size of each tensor
  // operand is valid in a given Level.

  // Tensor Operators
  CHECK_SIZES(AvgPool2d);
  CHECK_SIZES(Conv2D);
  CHECK_SIZES(Conv3D);
  CHECK_SIZES(DepthwiseConv2D);
  CHECK_SIZES(TransposeConv2D);
  CHECK_SIZES(FFT2d);
  CHECK_SIZES(MatMul);
  CHECK_SIZES(MaxPool2d);
  CHECK_SIZES(RFFT2d);
  // Scatter/Gather Operators
  CHECK_SIZES(Scatter);
  // Image Operators
  CHECK_SIZES(Resize);
  // Custom Operators
  CHECK_SIZES(Custom);
  // Control Flow Operators
  CHECK_SIZES(While);
  // Shape Operators
  CHECK_SIZES(ConstShape);

#undef CHECK_RANKS_AND_SIZES
#undef CHECK_SIZES
  return true;
}
717 
718 // Perform the Level tensor size check
719 bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
720  const StringRef operandOrResult) {
721  if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
722  if (!type.hasRank()) {
723  op->emitOpError() << "failed level check: unranked tensor";
724  return false;
725  }
726  auto shape = type.getShape();
727  for (auto dim : shape) {
728  if (mlir::ShapedType::isDynamic(dim)) {
729  op->emitOpError() << "failed level check: " << operandOrResult
730  << " shape dimension cannot be dynamic";
731  return false;
732  }
733  }
734 
735  int64_t element_bits = type.getElementTypeBitWidth();
736  int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
737  int64_t size = element_bytes * type.getNumElements();
738 
739  // According to 1.11. Tensor Definitions of Tosa spec, the value of
740  // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
741  // defined in 1.7. Levels.
742  // For each tensor, the number of tensor elements multiplied by the
743  // element size in bytes must be representable as a tensor_size_t.
744  const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
745  if (size > max_size) {
746  op->emitOpError()
747  << "failed level check: " << operandOrResult
748  << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
749  return false;
750  }
751  }
752  return true;
753 }
754 
755 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
756  if (tosaLevel == TOSA_LEVEL_NONE) {
757  // no need to do level checks
758  return success();
759  }
760 
761  // additional level checks from spec 0.70
762  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
763  !levelCheckConv<tosa::Conv2DOp>(op) ||
764  !levelCheckConv<tosa::Conv3DOp>(op) ||
765  !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
766  !levelCheckFFT<tosa::FFT2dOp>(op) ||
767  !levelCheckPool<tosa::MaxPool2dOp>(op) ||
768  !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
769  !levelCheckResize(op)) {
770  return failure();
771  }
772 
773  if (!levelCheckRanksAndSizes(op)) {
774  return failure();
775  }
776 
777  // level check MAX_TENSOR_LIST_SIZE
778  if (!levelCheckListSize(op)) {
779  return failure();
780  }
781 
782  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
783  if (!levelCheckMaxNesting(op)) {
784  return failure();
785  }
786  }
787 
788  return success();
789 }
790 
791 LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
792  if (!attributeCheckRescale(op))
793  return failure();
794  return success();
795 }
796 
797 inline bool CompatibleTypes(const mlir::Type &type,
798  const mlir::Type &declaredType) {
799  // for now, simply use type equality comparison
800  return type == declaredType;
801 }
802 
803 bool TosaValidation::CheckVariable(Operation *op) {
804  if (isa<mlir::tosa::VariableOp>(op)) {
805  mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
806 
807  if (variablesMap.count(nameAttr)) {
808  op->emitOpError() << "name has already been declared";
809  return false;
810  }
811 
812  auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
813  mlir::Type type = typeAttr.getValue();
814 
815  variablesMap[nameAttr] = type;
816  }
817 
818  return true;
819 }
820 
821 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
822  if (isa<mlir::tosa::VariableReadOp>(op) ||
823  isa<mlir::tosa::VariableWriteOp>(op)) {
824  mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
825  if (!variablesMap.count(nameAttr)) {
826  op->emitOpError() << "name has not been declared";
827  return false;
828  }
829 
830  auto varType = variablesMap[nameAttr];
831 
832  for (auto v : op->getOperands()) {
833  auto type = v.getType();
834  if (!CompatibleTypes(type, varType)) {
835  op->emitOpError() << "operand type does not equal variable type";
836  return false;
837  }
838  }
839 
840  for (auto v : op->getResults()) {
841  auto type = v.getType();
842  if (!CompatibleTypes(type, varType)) {
843  op->emitOpError() << "result type does not equal variable type";
844  return false;
845  }
846  }
847  }
848 
849  return true;
850 }
851 
852 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
853  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
854  return failure();
855  }
856  return success();
857 }
858 
// ERROR_IF checks for tosa.resize: validates image-size limits, scale/offset/
// border ranges, and that static output dims match the computed resize shape.
// Returns false (after emitting a diagnostic where applicable) on violation.
bool checkErrorIfResize(Operation *op) {
  auto resize = dyn_cast<tosa::ResizeOp>(op);
  if (!resize)
    return true;

  const Value input = resize.getInput();
  const Value output = resize.getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  if (!inputType || !outputType) {
    op->emitOpError("expect ranked input/output tensor");
    return false;
  }

  // Ensure the image size is supported by GPU APIs and that for integer
  // implementations, position * stride does not overflow int32_t.
  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
    const SmallVector<int64_t, 4> sizes = {
        outputType.getDimSize(1), outputType.getDimSize(2),
        inputType.getDimSize(1), inputType.getDimSize(2)};
    const int64_t *maxDim = llvm::max_element(sizes);
    if (maxDim != sizes.end() && *maxDim >= 16384) {
      op->emitOpError("expect input/output height/width dims to be < 16384, ")
          << "got [OH, OW, IH, IW] = " << sizes;
      return false;
    }
  }

  // Scale must be a compile-time constant shape value: [y_n, y_d, x_n, x_d].
  SmallVector<int64_t> scale;
  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
    return false;
  }

  const int64_t scaleYN = scale[0];
  const int64_t scaleYD = scale[1];
  const int64_t scaleXN = scale[2];
  const int64_t scaleXD = scale[3];

  // Ensure scale values don't overflow int32 accumulator
  if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
    op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
                    "got scale_y_n=")
        << scaleYN << ", scale_x_n=" << scaleXN;
    return false;
  }

  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
    op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
        << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
    return false;
  }

  // Offset and border must likewise be constant shape values: [y, x] each.
  SmallVector<int64_t> offset;
  SmallVector<int64_t> border;
  if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
      !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
    return false;
  }

  const int64_t offsetY = offset[0];
  const int64_t offsetX = offset[1];
  // Set a consistent lower limit of 1/16 downscale to simplify
  // implementations
  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
    op->emitOpError(
        "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
        << offsetY << "/" << scaleYN;
    return false;
  }
  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
    op->emitOpError(
        "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
        << offsetX << "/" << scaleXN;
    return false;
  }

  const int64_t borderY = border[0];
  const int64_t borderX = border[1];
  if (borderY < -16 * scaleYN || borderY >= scaleYN) {
    op->emitOpError(
        "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
        << borderY << "/" << scaleYN;
    return false;
  }
  if (borderX < -16 * scaleXN || borderX >= scaleXN) {
    op->emitOpError(
        "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
        << borderX << "/" << scaleXN;
    return false;
  }

  // The following section of code is mostly duplicated with ResizeOp::verify().
  //
  // In TOSA specification, we do not support broadcast behavior.
  // However, there is a rewrite pattern to materialize broadcast ResizeOp.
  // It makes invalid TOSA ResizeOp into valid one. To avoid breaking
  // existing code, we keep the rewrite pattern untouched. So, we need
  // loose the checking in ResizeOp::verify() to support broadcast ResizeOp.
  //
  // Here is a strict checking to conform TOSA specification.
  // FIXME: Remove the duplicated checkings when broadcast ResizeOp is removed.
  auto idivCheck = [](const int64_t lhs,
                      const int64_t rhs) -> std::optional<int64_t> {
    if (lhs % rhs != 0)
      return std::nullopt;
    return lhs / rhs;
  };

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  // output_height = (input_height - 1) * scale_y_n - offset_y + border_y,
  // all divided by scale_y_d, plus one — and it must divide evenly.
  if (ih != ShapedType::kDynamic) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value()) {
      op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                      "border_y ")
          << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
          << scaleYN << " - " << offsetY << " + " << borderY << ") / "
          << scaleYD;
      return false;
    }
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
      op->emitOpError("calculated output height did not match expected: ")
          << "calculated=" << calculatedOutHeight << ", expected=" << oh;
      return false;
    }
  }

  // Same derivation for the width dimension.
  if (iw != ShapedType::kDynamic) {
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value()) {
      op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                      "border_x ")
          << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
          << scaleXN << " - " << offsetX << " + " << borderX << ") / "
          << scaleXD;
      return false;
    }
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
      op->emitOpError("calculated output width did not match expected: ")
          << "calculated=" << calculatedOutWidth << ", expected=" << ow;
      return false;
    }
  }

  return true;
}
1015 
1016 bool checkErrorIfMul(Operation *op) {
1017  auto mul = dyn_cast<tosa::MulOp>(op);
1018  if (!mul)
1019  return true;
1020 
1021  // REQUIRE(0 <= shift && shift <= 63);
1022  // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
1023  ElementsAttr shift_elem;
1024  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
1025  return true;
1026  }
1027  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1028  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
1029  if (inputElemType.isInteger(32)) {
1030  // 0 <= shift <= 63 for int32_t type
1031  if (shift < 0 || shift > 63) {
1032  op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
1033  << shift;
1034  return false;
1035  }
1036  } else {
1037  // shift must be 0 for all other types
1038  if (shift != 0) {
1039  op->emitOpError() << "requires shift = 0 for all input data types that "
1040  "are not int32_t, but got: "
1041  << shift;
1042  return false;
1043  }
1044  }
1045 
1046  return true;
1047 }
1048 
1049 bool checkErrorIfTable(Operation *op) {
1050  auto table = dyn_cast<tosa::TableOp>(op);
1051  if (!table)
1052  return true;
1053 
1054  // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1055  const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1056  const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1057 
1058  const ShapeAdaptor tableShape(table.getTable().getType());
1059  if (tableShape.hasStaticShape()) {
1060  const auto numElements = tableShape.getNumElements();
1061  if (numElements != tableSize) {
1062  op->emitOpError() << "requires table size of " << tableSize << ", got "
1063  << numElements;
1064  return false;
1065  }
1066  }
1067 
1068  return true;
1069 }
1070 
1071 bool checkErrorIfRescale(Operation *op) {
1072  auto rescale = dyn_cast<tosa::RescaleOp>(op);
1073  if (!rescale)
1074  return true;
1075 
1076  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1077  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1078  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1079  !outputType.getElementType().isInteger())
1080  return true;
1081 
1082  auto inElemType = inputType.getElementType();
1083  auto outElemType = outputType.getElementType();
1084  auto inWidth = inElemType.getIntOrFloatBitWidth();
1085  auto outWidth = outElemType.getIntOrFloatBitWidth();
1086 
1087  bool inputUnsigned = rescale.getInputUnsigned();
1088  bool outputUnsigned = rescale.getOutputUnsigned();
1089 
1090  bool scale32 = rescale.getScale32();
1091  auto roundingMode = rescale.getRoundingMode();
1092 
1093  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
1094  if (scale32 && inWidth == 48) {
1095  op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1096  return false;
1097  }
1098 
1099  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1100  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
1101  op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
1102  return false;
1103  }
1104 
1105  // ERROR_IF(input_unsigned && output_unsigned)
1106  if (inputUnsigned && outputUnsigned) {
1107  op->emitOpError() << "input and output cannot be both unsigned.";
1108  return false;
1109  }
1110 
1111  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1112  if (outWidth == 32 && inputUnsigned) {
1113  op->emitOpError() << "i32 output type is not allowed with unsigned input.";
1114  return false;
1115  }
1116 
1117  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1118  if (inWidth == 32 && outputUnsigned) {
1119  op->emitOpError() << "i32 input type is not allowed with unsigned output.";
1120  return false;
1121  }
1122 
1123  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1124  if (inWidth == 48 && outputUnsigned) {
1125  op->emitOpError() << "i48 input type is not allowed with unsigned output.";
1126  return false;
1127  }
1128 
1129  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1130  if (inWidth == 48 && inputUnsigned) {
1131  op->emitOpError() << "i48 input type cannot be unsigned.";
1132  return false;
1133  }
1134 
1135  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1136  if (inWidth == 32 && inputUnsigned) {
1137  op->emitOpError() << "i32 input type cannot be unsigned.";
1138  return false;
1139  }
1140 
1141  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1142  if (outWidth == 32 && outputUnsigned) {
1143  op->emitOpError() << "i32 output type cannot be unsigned.";
1144  return false;
1145  }
1146 
1147  return true;
1148 }
1149 
1150 bool checkErrorIfPad(Operation *op) {
1151  auto pad = dyn_cast<tosa::PadOp>(op);
1152  if (!pad)
1153  return true;
1154 
1155  DenseIntElementsAttr paddingAttr;
1156  if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1157  // Pad verifier will catch this
1158  return true;
1159 
1160  for (const APInt &val : paddingAttr.getValues<APInt>()) {
1161  if (val.getSExtValue() < 0) {
1162  op->emitOpError() << "padding value must all be non-negative, got "
1163  << val.getSExtValue();
1164  return false;
1165  }
1166  }
1167 
1168  return true;
1169 }
1170 
1171 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1172  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1173  !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1174  !checkErrorIfPad(op))
1175  return failure();
1176  return success();
1177 }
1178 
1179 bool TosaValidation::isValidElementType(Type type) {
1180  if (isa<FloatType>(type)) {
1181  return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1182  Float8E5M2Type>(type);
1183  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
1184  if (intTy.isSignless()) {
1185  switch (intTy.getWidth()) {
1186  case 1:
1187  case 4:
1188  case 8:
1189  case 16:
1190  case 32:
1191  case 48:
1192  return true;
1193  }
1194  }
1195  } else if (mlir::isa<tosa::shapeType>(type)) {
1196  return true;
1197  }
1198  return false;
1199 }
1200 
/// Pass entry point: walks every operation under the root and applies the
/// configured validation stages — element-type legality, profile/extension
/// conformance, invalid data-type combinations, constant-operand, level,
/// attribute, variable and spec ERROR_IF checks. Any violation signals pass
/// failure with a diagnostic attached to the offending op.
void TosaValidation::runOnOperation() {
  configLevelAndProfile();

  // Nothing to validate if the TOSA dialect was never loaded.
  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
  if (!tosaDialect)
    return;

  getOperation().walk([&](Operation *op) {
    // Only ops belonging to the TOSA dialect are validated.
    if (op->getDialect() != tosaDialect)
      return;

    // perform valid element type check at the beginning to
    // protect rest of code against quantized element types
    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    // Profile and extension conformance are only enforced when strict
    // alignment with the TOSA specification is requested.
    if (strictOpSpecAlignment &&
        failed(profileComp.checkProfile(op, targetEnv)))
      return signalPassFailure();

    if (strictOpSpecAlignment &&
        failed(profileComp.checkExtension(op, targetEnv)))
      return signalPassFailure();

    if (!allowInvalidOpDatatypeCombinations &&
        failed(profileComp.checkInvalid(op))) {
      op->emitOpError("illegal: operand/result data types not supported");
      return signalPassFailure();
    }

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    // NOTE(review): from here on failures do NOT early-return, so several
    // of the remaining checks can each emit a diagnostic for the same op —
    // presumably intentional (accumulate diagnostics); confirm before
    // changing.
    if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // do level checks
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // check additional attribute restrictions
    if (failed(applyAttributeCheck(op)))
      signalPassFailure();

    // do variable type checks
    if (failed(applyVariableCheck(op)))
      signalPassFailure();

    // do error if checks
    if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
      signalPassFailure();
  });
}
1267 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:223
#define CHECK_RANKS_AND_SIZES(tosaOp)
#define CHECK_SIZES(tosaOp)
@ Gather
const float * table
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense integer vector or tensor object.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
bool allows(Profile prof) const
Definition: TargetEnv.h:43
NestedPattern If(const NestedPattern &child)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2336
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369