//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate that TOSA dialect input matches the specification for the given
// requirements.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include <string>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

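// Constant-operand checkers, applied only under strict operation spec
// alignment: the operands below must be produced by constant operations so
// their values are known at compile time.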
static LogicalResult checkConstantOperandPad(Operation *op) {
  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
    DenseElementsAttr paddings;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
      return op->emitOpError("padding of pad is not constant");

    DenseElementsAttr padConst;
    // If pad_const is absent, the op is assumed to perform zero padding.
    if (padOp.getPadConst() &&
        !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
      return op->emitOpError("pad_const of pad is not constant");
  }
  return success();
}

static LogicalResult checkConstantOperandTranspose(Operation *op) {
  if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
    DenseElementsAttr perms;
    if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
      return op->emitOpError("perms of transpose is not constant");
  }
  return success();
}

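// Bounds that a TOSA level places on operator parameters: tensor rank, kernel
// extents, strides, and resize scale.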
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // TODO: MAX_LOG2_SIZE value and checks

  bool operator==(const TosaLevel &rhs) {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
};

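// Limits for the specification's 8K level; TOSA_LEVEL_NONE disables all level
// checking.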
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};

//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//

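// Validates each TOSA operation against the enabled profiles and the selected
// level. The pass options (a `profile` list, a `level` choice, and the
// strict-alignment toggle) are defined in the pass's TableGen declaration and
// arrive here through TosaValidationOptions.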
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }
  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
    this->level = options.level;
  }
  void runOnOperation() final;

  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op)))
        return failure();
    }
    return success();
  }

  LogicalResult applyLevelCheck(Operation *op);

  // Check variable read/write data types against variable declarations.
  LogicalResult applyVariableCheck(Operation *op);

private:
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandTranspose);
  }

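  // Scalar level-check helpers: each compares one parameter against the
  // corresponding level maximum, emitting an op error and returning false on
  // violation.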
  bool levelCheckKernel(Operation *op, int32_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_KERNEL) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckStride(Operation *op, int32_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_STRIDE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
    if (v > tosaLevel.MAX_SCALE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

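  // Rank check for a single value: unranked tensors always fail, ranked
  // tensors must not exceed MAX_RANK, and non-shaped types pass trivially.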
  bool levelCheckRank(Operation *op, const Value &v,
                      const std::string &checkDesc) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > tosaLevel.MAX_RANK) {
        op->emitOpError() << "failed level check: " << checkDesc;
        return false;
      }
    }
    return true;
  }

  template <typename T>
  bool levelCheckRanksFor(Operation *op) {
    if (dyn_cast<T>(op)) {
      // level check ranks of all operands and results
      for (auto v : op->getOperands()) {
        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
          return false;
      }
      for (auto v : op->getResults()) {
        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
          return false;
      }
    }
    return true;
  }

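  // CHECK_RANKS_FOR(Add) expands to levelCheckRanksFor<AddOp>(op), so the list
  // below enumerates every TOSA op whose operand/result ranks are bounded by
  // the level.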
  bool levelCheckRanks(Operation *op) {
#define CHECK_RANKS_FOR(tosaOp)                                                \
  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
    return false;

    // tensor operators:
    CHECK_RANKS_FOR(ArgMax);
    // all activation functions:
    CHECK_RANKS_FOR(Clamp);
    CHECK_RANKS_FOR(Sigmoid);
    CHECK_RANKS_FOR(Tanh);
    // all elementwise binary operators:
    CHECK_RANKS_FOR(Add);
    CHECK_RANKS_FOR(ArithmeticRightShift);
    CHECK_RANKS_FOR(BitwiseAnd);
    CHECK_RANKS_FOR(BitwiseOr);
    CHECK_RANKS_FOR(BitwiseXor);
    CHECK_RANKS_FOR(IntDiv);
    CHECK_RANKS_FOR(LogicalAnd);
    CHECK_RANKS_FOR(LogicalLeftShift);
    CHECK_RANKS_FOR(LogicalRightShift);
    CHECK_RANKS_FOR(LogicalOr);
    CHECK_RANKS_FOR(LogicalXor);
    CHECK_RANKS_FOR(Maximum);
    CHECK_RANKS_FOR(Minimum);
    CHECK_RANKS_FOR(Mul);
    CHECK_RANKS_FOR(Pow);
    CHECK_RANKS_FOR(Sub);
    CHECK_RANKS_FOR(Table);
    // all elementwise unary operators:
    CHECK_RANKS_FOR(Abs);
    CHECK_RANKS_FOR(BitwiseNot);
    CHECK_RANKS_FOR(Ceil);
    CHECK_RANKS_FOR(Clz);
    CHECK_RANKS_FOR(Exp);
    CHECK_RANKS_FOR(Floor);
    CHECK_RANKS_FOR(Log);
    CHECK_RANKS_FOR(LogicalNot);
    CHECK_RANKS_FOR(Negate);
    CHECK_RANKS_FOR(Reciprocal);
    CHECK_RANKS_FOR(Rsqrt);
    // all elementwise ternary operators:
    CHECK_RANKS_FOR(Select);
    // all comparison operators:
    CHECK_RANKS_FOR(Equal);
    CHECK_RANKS_FOR(Greater);
    CHECK_RANKS_FOR(GreaterEqual);
    // all reduction operators:
    CHECK_RANKS_FOR(ReduceAll);
    CHECK_RANKS_FOR(ReduceAny);
    CHECK_RANKS_FOR(ReduceMax);
    CHECK_RANKS_FOR(ReduceMin);
    CHECK_RANKS_FOR(ReduceProd);
    CHECK_RANKS_FOR(ReduceSum);
    // all data layout operators:
    CHECK_RANKS_FOR(Concat);
    CHECK_RANKS_FOR(Pad);
    CHECK_RANKS_FOR(Reshape);
    CHECK_RANKS_FOR(Reverse);
    CHECK_RANKS_FOR(Slice);
    CHECK_RANKS_FOR(Tile);
    CHECK_RANKS_FOR(Transpose);
    // all type conversion operators:
    CHECK_RANKS_FOR(Cast);
    CHECK_RANKS_FOR(Rescale);
    // all data node operators:
    CHECK_RANKS_FOR(Const);
    CHECK_RANKS_FOR(Identity);

#undef CHECK_RANKS_FOR
    return true;
  }

  // Pool Op: level check kernel/stride/pad values
  template <typename T>
  bool levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : poolOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      for (auto p : poolOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
    }
    return true;
  }

  // Conv Op: level check dilation/stride/pad values, plus the effective
  // kernel size after dilation.
  template <typename T>
  bool levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {

      for (auto k : convOp.getDilation()) {
        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : convOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : convOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(op->getOperand(1).getType())) {
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_d * KD <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[2] * shape[3],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[0],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[1],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        }
      }
    }
    return true;
  }

  // FFT op: level check H, W in input shape [N, H, W]
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
            return false;
          }
        }
      }
    }
    return true;
  }

  // TransposeConv2D op: level check KH/KW, out_pad, and stride
  bool levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getWeight().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // level check kernel sizes for KH and KW
        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : transpose.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
    }
    return true;
  }

  // Resize op: level check max scales
  bool levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      auto scale = resize.getScale();
      int16_t scaleYN = scale[0];
      int16_t scaleYD = scale[1];
      int16_t scaleXN = scale[2];
      int16_t scaleXD = scale[3];
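      // The scale is given as rational pairs (n, d); the level bounds the
      // ratio n/d, compared here with truncating integer division.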
      if (!levelCheckScale(op, scaleYN / scaleYD,
                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
          !levelCheckScale(op, scaleXN / scaleXD,
                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
        return false;
      }
    }
    return true;
  }

  // Configure the level and profile settings from the `level` and `profile`
  // pass options.
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }

    if (!profile.empty()) {
      for (std::string &prof : profile) {
        auto profSymbol = symbolizeTosaProfileEnum(prof);
        if (profSymbol) {
          enabled_profiles.push_back(profSymbol.value());
        }
      }
    }
  }

  bool CheckVariable(Operation *op);
  bool CheckVariableReadOrWrite(Operation *op);

  bool isValidElementType(Type type);
  bool isEnabledProfile(TosaProfileEnum prof) {
    return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
           std::end(enabled_profiles);
  }

  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
  SmallVector<TosaProfileEnum, 3> enabled_profiles;
  TosaLevel tosaLevel;
  DenseMap<StringAttr, mlir::Type> variablesMap;
};

LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
  if (tosaLevel == TOSA_LEVEL_NONE) {
    // no need to do level checks
    return success();
  }

  if (!levelCheckRanks(op)) {
    return failure();
  }

  // additional level checks from spec 0.70
  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
      !levelCheckConv<tosa::Conv2DOp>(op) ||
      !levelCheckConv<tosa::Conv3DOp>(op) ||
      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
      !levelCheckFFT<tosa::FFT2dOp>(op) ||
      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
      !levelCheckResize(op)) {
    return failure();
  }

  return success();
}

inline bool CompatibleTypes(const mlir::Type &type,
                            const mlir::Type &declaredType) {
  // For now, simply use type equality comparison.
  return type == declaredType;
}

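// Record a variable declaration: each variable name may be declared only
// once, and its declared type is remembered for later read/write checks.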
bool TosaValidation::CheckVariable(Operation *op) {
  if (isa<mlir::tosa::VariableOp>(op)) {
    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has already been declared";
      return false;
    }

    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
    mlir::Type type = typeAttr.getValue();

    variablesMap[nameAttr] = type;
  }

  return true;
}

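// Verify that every operand and result of a variable read or write matches
// the declared type of the referenced variable.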
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
  if (isa<mlir::tosa::VariableReadOp>(op) ||
      isa<mlir::tosa::VariableWriteOp>(op)) {
    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (!variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has not been declared";
      return false;
    }

    auto varType = variablesMap[nameAttr];

    for (auto v : op->getOperands()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "operand type does not equal variable type";
        return false;
      }
    }

    for (auto v : op->getResults()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "result type does not equal variable type";
        return false;
      }
    }
  }

  return true;
}

LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
    return failure();
  }
  return success();
}

bool TosaValidation::isValidElementType(Type type) {
  if (isa<FloatType>(type)) {
    if (!isEnabledProfile(TosaProfileEnum::MainInference))
      return false;
    return type.isF32() || type.isF16() || type.isBF16();
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    if (intTy.isSignless()) {
      switch (intTy.getWidth()) {
      case 1:
      case 4:
      case 8:
      case 16:
      case 32:
      case 48:
        return true;
      }
    }
  } else if (mlir::isa<tosa::shapeType>(type)) {
    return true;
  }
  return false;
}

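// Walk every TOSA op under the root operation: first gate element types on
// the enabled profiles, then run the constant-operand, level, and variable
// checks. Any failure marks the whole pass as failed.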
void TosaValidation::runOnOperation() {
  configLevelAndProfile();

  TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
  if (!tosaDialect)
    return;

  getOperation().walk([&](Operation *op) {
    if (op->getDialect() != tosaDialect)
      return;

    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // do level checks
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // do variable type checks
    if (failed(applyVariableCheck(op)))
      signalPassFailure();
  });
}
} // namespace