//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate that TOSA dialect input matches the specification for the given
// requirements.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include <string>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

static LogicalResult checkConstantOperandPad(Operation *op) {
  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
    DenseElementsAttr paddings;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
      return op->emitOpError("padding of pad is not constant");

    DenseElementsAttr padConst;
    // Assume the operation performs zero padding if pad_const is not present.
    if (padOp.getPadConst() &&
        !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
      return op->emitOpError("pad_const of pad is not constant");
  }
  return success();
}
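
// For example, the check above rejects a pad whose padding operand is not
// produced by a constant op (illustrative IR; the exact assembly format may
// differ):
//
//   %padding = "some.dialect.op"() : () -> tensor<2x2xi32>
//   %0 = tosa.pad %arg0, %padding
//            : (tensor<2x3xf32>, tensor<2x2xi32>) -> tensor<4x5xf32>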

static LogicalResult checkConstantOperandTranspose(Operation *op) {
  if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
    DenseElementsAttr perms;
    if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
      return op->emitOpError("perms of transpose is not constant");
  }
  return success();
}

static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
  if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
    DenseElementsAttr weight;
    if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
      return op->emitOpError("weight of fully_connected is not constant");

    DenseElementsAttr bias;
    if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
      return op->emitOpError("bias of fully_connected is not constant");
  }
  return success();
}
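
// Note: these constant-operand checkers are applied only when the
// strict-op-spec-alignment pass option is enabled; see
// applyConstantOperandCheck() and runOnOperation() below.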

struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // TODO: MAX_LOG2_SIZE value and checks

  bool operator==(const TosaLevel &rhs) {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
};

static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
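
// TOSA_LEVEL_EIGHTK mirrors the level parameter table of the TOSA
// specification (level 8K: MAX_RANK=6, MAX_KERNEL=8192, MAX_STRIDE=8192,
// MAX_SCALE=256); TOSA_LEVEL_NONE makes applyLevelCheck() a no-op.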

//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//

struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }
  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
    this->level = options.level;
  }
  void runOnOperation() final;

  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op)))
        return failure();
    }
    return success();
  }

  LogicalResult applyLevelCheck(Operation *op);

  // Check variable read/write data types against variable declarations.
  LogicalResult applyVariableCheck(Operation *op);

private:
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandTranspose);
    constCheckers.emplace_back(checkConstantOperandFullyConnected);
  }

  bool levelCheckKernel(Operation *op, int32_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_KERNEL) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckStride(Operation *op, int32_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_STRIDE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
    if (v > tosaLevel.MAX_SCALE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckRank(Operation *op, const Value &v,
                      const std::string &checkDesc) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > tosaLevel.MAX_RANK) {
        op->emitOpError() << "failed level check: " << checkDesc;
        return false;
      }
    }
    return true;
  }

  template <typename T>
  bool levelCheckRanksFor(Operation *op) {
    if (isa<T>(op)) {
      // Level check ranks of all operands and results.
      for (auto v : op->getOperands()) {
        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
          return false;
      }
      for (auto v : op->getResults()) {
        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
          return false;
      }
    }
    return true;
  }

  bool levelCheckRanks(Operation *op) {
#define CHECK_RANKS_FOR(tosaOp)                                                \
  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
    return false;

    // tensor operators:
    CHECK_RANKS_FOR(ArgMax);
    // all activation functions:
    CHECK_RANKS_FOR(Clamp);
    CHECK_RANKS_FOR(Sigmoid);
    CHECK_RANKS_FOR(Tanh);
    // all elementwise binary operators:
    CHECK_RANKS_FOR(Add);
    CHECK_RANKS_FOR(ArithmeticRightShift);
    CHECK_RANKS_FOR(BitwiseAnd);
    CHECK_RANKS_FOR(BitwiseOr);
    CHECK_RANKS_FOR(BitwiseXor);
    CHECK_RANKS_FOR(Div);
    CHECK_RANKS_FOR(LogicalAnd);
    CHECK_RANKS_FOR(LogicalLeftShift);
    CHECK_RANKS_FOR(LogicalRightShift);
    CHECK_RANKS_FOR(LogicalOr);
    CHECK_RANKS_FOR(LogicalXor);
    CHECK_RANKS_FOR(Maximum);
    CHECK_RANKS_FOR(Minimum);
    CHECK_RANKS_FOR(Mul);
    CHECK_RANKS_FOR(Pow);
    CHECK_RANKS_FOR(Sub);
    CHECK_RANKS_FOR(Table);
    // all elementwise unary operators:
    CHECK_RANKS_FOR(Abs);
    CHECK_RANKS_FOR(BitwiseNot);
    CHECK_RANKS_FOR(Ceil);
    CHECK_RANKS_FOR(Clz);
    CHECK_RANKS_FOR(Exp);
    CHECK_RANKS_FOR(Floor);
    CHECK_RANKS_FOR(Log);
    CHECK_RANKS_FOR(LogicalNot);
    CHECK_RANKS_FOR(Negate);
    CHECK_RANKS_FOR(Reciprocal);
    CHECK_RANKS_FOR(Rsqrt);
    // all elementwise ternary operators:
    CHECK_RANKS_FOR(Select);
    // all comparison operators:
    CHECK_RANKS_FOR(Equal);
    CHECK_RANKS_FOR(Greater);
    CHECK_RANKS_FOR(GreaterEqual);
    // all reduction operators:
    CHECK_RANKS_FOR(ReduceAll);
    CHECK_RANKS_FOR(ReduceAny);
    CHECK_RANKS_FOR(ReduceMax);
    CHECK_RANKS_FOR(ReduceMin);
    CHECK_RANKS_FOR(ReduceProd);
    CHECK_RANKS_FOR(ReduceSum);
    // all data layout operators:
    CHECK_RANKS_FOR(Concat);
    CHECK_RANKS_FOR(Pad);
    CHECK_RANKS_FOR(Reshape);
    CHECK_RANKS_FOR(Reverse);
    CHECK_RANKS_FOR(Slice);
    CHECK_RANKS_FOR(Tile);
    CHECK_RANKS_FOR(Transpose);
    // all type conversion operators:
    CHECK_RANKS_FOR(Cast);
    CHECK_RANKS_FOR(Rescale);
    // all data nodes operators:
    CHECK_RANKS_FOR(Const);
    CHECK_RANKS_FOR(Identity);

#undef CHECK_RANKS_FOR
    return true;
  }

  // Pool Op: level check kernel/stride/pad values
  template <typename T>
  bool levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : poolOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      for (auto p : poolOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
    }
    return true;
  }
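
  // For example, a tosa.max_pool2d with kernel = [16384, 1] exceeds
  // MAX_KERNEL (8192) at the 8K level and is rejected by the check above.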

  // Conv Op: level check dilation/stride/pad values
  template <typename T>
  bool levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {
      for (auto k : convOp.getDilation()) {
        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : convOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : convOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(op->getOperand(1).getType())) {
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_d * KD <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[2] * shape[3],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[0],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[1],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        }
      }
    }
    return true;
  }
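
  // The dilation checks above enforce that the dilated kernel extent stays
  // within MAX_KERNEL; the weight-shape indices differ per op because of the
  // weight layouts: [OC, KH, KW, IC] for conv2d, [OC, KD, KH, KW, IC] for
  // conv3d, and [KH, KW, IC, M] for depthwise_conv2d.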

  // FFT op: level check H, W in input shape [N,H,W]
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
            return false;
          }
        }
      }
    }
    return true;
  }

  // TransposeConv2d op: level check KH/KW, outpad, and stride
  bool levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getFilter().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // Level check kernel sizes for KH and KW.
        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : transpose.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
    }
    return true;
  }

  // Resize op: level check max scales
  bool levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      auto scale = resize.getScale();
      int16_t scaleYN = scale[0];
      int16_t scaleYD = scale[1];
      int16_t scaleXN = scale[2];
      int16_t scaleXD = scale[3];
      if (!levelCheckScale(op, scaleYN / scaleYD,
                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
          !levelCheckScale(op, scaleXN / scaleXD,
                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
        return false;
      }
    }
    return true;
  }
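
  // Note that the scale ratios above use truncating integer division of the
  // int16 values [scale_y_n, scale_y_d, scale_x_n, scale_x_d] taken from the
  // op's scale attribute.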

  // Configure the level parameters from the 'level' pass option; the
  // 'profile' option is consumed directly in isValidElementType().
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }
  }

  bool CheckVariable(Operation *op);
  bool CheckVariableReadOrWrite(Operation *op);

  bool isValidElementType(Type type);

  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
  TosaLevel tosaLevel;
  DenseMap<StringAttr, mlir::Type> variablesMap;
};

LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
  if (tosaLevel == TOSA_LEVEL_NONE) {
    // No need to do level checks.
    return success();
  }

  if (!levelCheckRanks(op)) {
    return failure();
  }

  // Additional level checks from spec 0.70.
  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
      !levelCheckConv<tosa::Conv2DOp>(op) ||
      !levelCheckConv<tosa::Conv3DOp>(op) ||
      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
      !levelCheckFFT<tosa::FFT2dOp>(op) ||
      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
      !levelCheckResize(op)) {
    return failure();
  }

  return success();
}

inline bool CompatibleTypes(const mlir::Type &type,
                            const mlir::Type &declaredType) {
  // For now, simply use type equality comparison.
  return type == declaredType;
}

bool TosaValidation::CheckVariable(Operation *op) {
  if (isa<mlir::tosa::VariableOp>(op)) {
    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has already been declared";
      return false;
    }

    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
    mlir::Type type = typeAttr.getValue();

    variablesMap[nameAttr] = type;
  }

  return true;
}

bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
  if (isa<mlir::tosa::VariableReadOp>(op) ||
      isa<mlir::tosa::VariableWriteOp>(op)) {
    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (!variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has not been declared";
      return false;
    }

    auto varType = variablesMap[nameAttr];

    for (auto v : op->getOperands()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "operand type does not equal variable type";
        return false;
      }
    }

    for (auto v : op->getResults()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "result type does not equal variable type";
        return false;
      }
    }
  }

  return true;
}
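
// For example (illustrative IR; the exact assembly format may differ):
//
//   tosa.variable @var : tensor<2x4x8xi32>
//   %0 = tosa.variable.read @var : tensor<2x4x8xi32>   // OK: types match
//   tosa.variable.write @var, %1 : tensor<2x4x8xf32>   // rejected: operand
//                                                      // type differs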

LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
    return failure();
  }
  return success();
}
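
// Returns true if 'type' is a legal element type under the configured
// profile: f64 is always rejected, any float type is rejected for the
// base-inference profile, unsigned integers are limited to i8/i16, and
// signless integers to widths 1/4/8/16/32/48/64.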
bool TosaValidation::isValidElementType(Type type) {
  if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
    return false;
  }
  if (type.isF64()) {
    return false;
  }
  if (auto intTy = dyn_cast<IntegerType>(type)) {
    if (intTy.isUnsigned()) {
      switch (intTy.getWidth()) {
      case 8:
      case 16:
        return true;
      default:
        return false;
      }
    } else {
      // Signless - treated as signed.
      switch (intTy.getWidth()) {
      case 1:
      case 4:
      case 8:
      case 16:
      case 32:
      case 48:
      case 64:
        return true;
      default:
        return false;
      }
    }
    return false;
  }
  return true;
}
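
// Walk all nested ops: first verify that every operand and result element
// type is profile-legal, then apply the strict constant-operand checks (if
// enabled), the level checks, and the variable checks.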
void TosaValidation::runOnOperation() {
  configLevelAndProfile();
  getOperation().walk([&](Operation *op) {
    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // Do level checks.
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // Do variable type checks.
    if (failed(applyVariableCheck(op)))
      signalPassFailure();
  });
}
} // namespace
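
// Typical invocation (illustrative; the authoritative option names are
// declared in the TOSA Passes.td entry for this pass):
//
//   mlir-opt --tosa-validate="profile=bi strict-op-spec-alignment level=8k" \
//       input.mlir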