//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate that TOSA dialect input matches the specification for the given
// requirements.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include <functional>
#include <string>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

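// Checkers for operands that the specification requires to be compile-time
// constants; they are applied only under StrictOperationSpecAlignment. As an
// illustration (shapes chosen arbitrarily), a pad whose padding operand is a
// block argument rather than the result of a constant op is rejected:
//
//   func.func @f(%arg0: tensor<1x2xf32>, %pad: tensor<2x2xi32>) {
//     %0 = "tosa.pad"(%arg0, %pad)
//         : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
//     // error: 'tosa.pad' op padding of pad is not constant
//   }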
static LogicalResult checkConstantOperandPad(Operation *op) {
  if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
    DenseElementsAttr paddings;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
      return op->emitOpError("padding of pad is not constant");

    DenseElementsAttr padConst;
    // Assume this op is zero-padding if pad_const is not present.
    if (padOp.getPadConst() &&
        !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
      return op->emitOpError("pad_const of pad is not constant");
  }
  return success();
}

static LogicalResult checkConstantOperandTranspose(Operation *op) {
  if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
    DenseElementsAttr perms;
    if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
      return op->emitOpError("perms of transpose is not constant");
  }
  return success();
}

static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
  if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
    DenseElementsAttr weight;
    if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
      return op->emitOpError("weight of fully_connected is not constant");

    DenseElementsAttr bias;
    if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
      return op->emitOpError("bias of fully_connected is not constant");
  }
  return success();
}

struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // TODO: MAX_LOG2_SIZE value and checks

  bool operator==(const TosaLevel &rhs) {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
};

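// Bounds for the specification's "8K" checker level: the maximum tensor
// rank, kernel extent, stride, and resize scale a conforming implementation
// must support. TOSA_LEVEL_NONE disables level checking entirely.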
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};

//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//

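// Sketch of typical usage (option spellings assumed from the tablegen'd pass
// declaration; verify against Passes.td):
//
//   mlir-opt --tosa-validate="profile=mi strict-op-spec-alignment level=8k" \
//       input.mlir
//
// or programmatically, via the generated constructor that takes a
// TosaValidationOptions struct.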
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }
  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
    this->level = options.level;
  }
  void runOnOperation() final;

  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op)))
        return failure();
    }
    return success();
  }

  LogicalResult applyLevelCheck(Operation *op);

  // Check variable read/write data types against variable declarations.
  LogicalResult applyVariableCheck(Operation *op);

private:
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandTranspose);
    constCheckers.emplace_back(checkConstantOperandFullyConnected);
  }

  bool levelCheckKernel(Operation *op, int32_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_KERNEL) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckStride(Operation *op, int32_t v,
                        const std::string &checkDesc) {
    if (v > tosaLevel.MAX_STRIDE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
    if (v > tosaLevel.MAX_SCALE) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

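  // Unranked tensors always fail the rank check below: without a static
  // rank, the MAX_RANK bound cannot be established.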
  bool levelCheckRank(Operation *op, const Value &v,
                      const std::string &checkDesc) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > tosaLevel.MAX_RANK) {
        op->emitOpError() << "failed level check: " << checkDesc;
        return false;
      }
    }
    return true;
  }

  template <typename T>
  bool levelCheckRanksFor(Operation *op) {
    if (dyn_cast<T>(op)) {
      // Level check ranks of all operands and results.
      for (auto v : op->getOperands()) {
        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
          return false;
      }
      for (auto v : op->getResults()) {
        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
          return false;
      }
    }
    return true;
  }

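  // Each CHECK_RANKS_FOR line below probes one TOSA op kind; only the line
  // whose dyn_cast succeeds performs the rank checks, and the first failing
  // check short-circuits the whole function.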
  bool levelCheckRanks(Operation *op) {
#define CHECK_RANKS_FOR(tosaOp)                                                \
  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
    return false;

    // tensor operators:
    CHECK_RANKS_FOR(ArgMax);
    // all activation functions:
    CHECK_RANKS_FOR(Clamp);
    CHECK_RANKS_FOR(Sigmoid);
    CHECK_RANKS_FOR(Tanh);
    // all elementwise binary operators:
    CHECK_RANKS_FOR(Add);
    CHECK_RANKS_FOR(ArithmeticRightShift);
    CHECK_RANKS_FOR(BitwiseAnd);
    CHECK_RANKS_FOR(BitwiseOr);
    CHECK_RANKS_FOR(BitwiseXor);
    CHECK_RANKS_FOR(IntDiv);
    CHECK_RANKS_FOR(LogicalAnd);
    CHECK_RANKS_FOR(LogicalLeftShift);
    CHECK_RANKS_FOR(LogicalRightShift);
    CHECK_RANKS_FOR(LogicalOr);
    CHECK_RANKS_FOR(LogicalXor);
    CHECK_RANKS_FOR(Maximum);
    CHECK_RANKS_FOR(Minimum);
    CHECK_RANKS_FOR(Mul);
    CHECK_RANKS_FOR(Pow);
    CHECK_RANKS_FOR(Sub);
    CHECK_RANKS_FOR(Table);
    // all elementwise unary operators:
    CHECK_RANKS_FOR(Abs);
    CHECK_RANKS_FOR(BitwiseNot);
    CHECK_RANKS_FOR(Ceil);
    CHECK_RANKS_FOR(Clz);
    CHECK_RANKS_FOR(Exp);
    CHECK_RANKS_FOR(Floor);
    CHECK_RANKS_FOR(Log);
    CHECK_RANKS_FOR(LogicalNot);
    CHECK_RANKS_FOR(Negate);
    CHECK_RANKS_FOR(Reciprocal);
    CHECK_RANKS_FOR(Rsqrt);
    // all elementwise ternary operators:
    CHECK_RANKS_FOR(Select);
    // all comparison operators:
    CHECK_RANKS_FOR(Equal);
    CHECK_RANKS_FOR(Greater);
    CHECK_RANKS_FOR(GreaterEqual);
    // all reduction operators:
    CHECK_RANKS_FOR(ReduceAll);
    CHECK_RANKS_FOR(ReduceAny);
    CHECK_RANKS_FOR(ReduceMax);
    CHECK_RANKS_FOR(ReduceMin);
    CHECK_RANKS_FOR(ReduceProd);
    CHECK_RANKS_FOR(ReduceSum);
    // all data layout operators:
    CHECK_RANKS_FOR(Concat);
    CHECK_RANKS_FOR(Pad);
    CHECK_RANKS_FOR(Reshape);
    CHECK_RANKS_FOR(Reverse);
    CHECK_RANKS_FOR(Slice);
    CHECK_RANKS_FOR(Tile);
    CHECK_RANKS_FOR(Transpose);
    // all type conversion operators:
    CHECK_RANKS_FOR(Cast);
    CHECK_RANKS_FOR(Rescale);
    // all data node operators:
    CHECK_RANKS_FOR(Const);
    CHECK_RANKS_FOR(Identity);

#undef CHECK_RANKS_FOR
    return true;
  }

  // Pool Op: level check kernel/stride/pad values.
  template <typename T>
  bool levelCheckPool(Operation *op) {
    if (auto poolOp = dyn_cast<T>(op)) {
      for (auto k : poolOp.getKernel()) {
        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : poolOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      for (auto p : poolOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
    }
    return true;
  }

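  // For convolutions the kernel bound also applies to the effective
  // (dilated) extent. Under the 8K level, for example, a conv2d with KH = 64
  // and dilation_y = 200 fails, since 200 * 64 = 12800 > MAX_KERNEL = 8192.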
  // Conv Op: level check dilation/stride/pad values.
  template <typename T>
  bool levelCheckConv(Operation *op) {
    if (auto convOp = dyn_cast<T>(op)) {
      for (auto k : convOp.getDilation()) {
        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : convOp.getPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : convOp.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
      auto dilation = convOp.getDilation();
      if (ShapedType weightType =
              dyn_cast<ShapedType>(op->getOperand(1).getType())) {
        auto shape = weightType.getShape();
        if (isa<tosa::Conv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::Conv3DOp>(op)) {
          assert(shape.size() == 5);
          assert(dilation.size() == 3);
          if (!levelCheckKernel(op, dilation[0] * shape[1],
                                "dilation_d * KD <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[2],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[2] * shape[3],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
          assert(shape.size() == 4);
          assert(dilation.size() == 2);
          if (!levelCheckKernel(op, dilation[0] * shape[0],
                                "dilation_y * KH <= MAX_KERNEL") ||
              !levelCheckKernel(op, dilation[1] * shape[1],
                                "dilation_x * KW <= MAX_KERNEL"))
            return false;
        }
      }
    }
    return true;
  }

  // FFT op: level check H, W in input shape [N, H, W].
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (isa<T>(op)) {
      for (auto v : op->getOperands()) {
        if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
          auto shape = type.getShape();
          assert(shape.size() == 3);
          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
            return false;
          }
        }
      }
    }
    return true;
  }

  // TransposeConv2d op: level check KH/KW, outpad, and stride.
  bool levelCheckTransposeConv2d(Operation *op) {
    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
      if (ShapedType filterType =
              dyn_cast<ShapedType>(transpose.getFilter().getType())) {
        auto shape = filterType.getShape();
        assert(shape.size() == 4);
        // Level check kernel sizes for KH and KW.
        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto p : transpose.getOutPad()) {
        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
          return false;
        }
      }
      for (auto s : transpose.getStride()) {
        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
          return false;
        }
      }
    }
    return true;
  }

  // Resize op: level check max scales.
  bool levelCheckResize(Operation *op) {
    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
      auto scale = resize.getScale();
      int16_t scaleYN = scale[0];
      int16_t scaleYD = scale[1];
      int16_t scaleXN = scale[2];
      int16_t scaleXD = scale[3];
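      // Note these are integer divisions: the value checked against
      // MAX_SCALE is the truncated quotient scale_n / scale_d.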
      if (!levelCheckScale(op, scaleYN / scaleYD,
                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
          !levelCheckScale(op, scaleXN / scaleXD,
                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
        return false;
      }
    }
    return true;
  }

  // Configure the level and profile settings from the 'level' and 'profile'
  // pass options.
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }

    if (!profile.empty()) {
      for (std::string &prof : profile) {
        auto profSymbol = symbolizeTosaProfileEnum(prof);
        if (profSymbol) {
          enabled_profiles.push_back(profSymbol.value());
        }
      }
    }
  }

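  // Variable op bookkeeping: a tosa.variable declaration must precede any
  // read or write of that name, and the value types must match the declared
  // type (see CheckVariable / CheckVariableReadOrWrite below).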
  bool CheckVariable(Operation *op);
  bool CheckVariableReadOrWrite(Operation *op);

  bool isValidElementType(Type type);
  bool isEnabledProfile(TosaProfileEnum prof) {
    return llvm::is_contained(enabled_profiles, prof);
  }

  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
  SmallVector<TosaProfileEnum, 3> enabled_profiles;
  TosaLevel tosaLevel;
  DenseMap<StringAttr, mlir::Type> variablesMap;
};

LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
  if (tosaLevel == TOSA_LEVEL_NONE) {
    // No need to do level checks.
    return success();
  }

  if (!levelCheckRanks(op)) {
    return failure();
  }

  // Additional level checks from spec 0.70.
  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
      !levelCheckConv<tosa::Conv2DOp>(op) ||
      !levelCheckConv<tosa::Conv3DOp>(op) ||
      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
      !levelCheckFFT<tosa::FFT2dOp>(op) ||
      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
      !levelCheckResize(op)) {
    return failure();
  }

  return success();
}

inline bool CompatibleTypes(const mlir::Type &type,
                            const mlir::Type &declaredType) {
  // For now, simply use type equality comparison.
  return type == declaredType;
}

bool TosaValidation::CheckVariable(Operation *op) {
  if (isa<mlir::tosa::VariableOp>(op)) {
    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has already been declared";
      return false;
    }

    auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
    mlir::Type type = typeAttr.getValue();

    variablesMap[nameAttr] = type;
  }

  return true;
}

bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
  if (isa<mlir::tosa::VariableReadOp>(op) ||
      isa<mlir::tosa::VariableWriteOp>(op)) {
    auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));

    if (!variablesMap.count(nameAttr)) {
      op->emitOpError() << "name has not been declared";
      return false;
    }

    auto varType = variablesMap[nameAttr];

    for (auto v : op->getOperands()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "operand type does not equal variable type";
        return false;
      }
    }

    for (auto v : op->getResults()) {
      auto type = v.getType();
      if (!CompatibleTypes(type, varType)) {
        op->emitOpError() << "result type does not equal variable type";
        return false;
      }
    }
  }

  return true;
}

LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
    return failure();
  }
  return success();
}

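// Only element types permitted by the enabled profiles are accepted:
// floating-point types (f32/f16/bf16) require the main-inference profile,
// while the signless integer widths listed below are always legal.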
bool TosaValidation::isValidElementType(Type type) {
  if (isa<FloatType>(type)) {
    if (!isEnabledProfile(TosaProfileEnum::MainInference))
      return false;
    return type.isF32() || type.isF16() || type.isBF16();
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    if (intTy.isSignless()) {
      switch (intTy.getWidth()) {
      case 1:
      case 4:
      case 8:
      case 16:
      case 32:
      case 48:
        return true;
      }
    }
  }
  return false;
}

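// Walk every op in the TOSA dialect and apply, in order: profile-based
// element-type validation, the optional strict constant-operand checks,
// level checks, and variable declaration/usage checks. Any failure marks
// the pass as failed.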
void TosaValidation::runOnOperation() {
  configLevelAndProfile();
  getOperation().walk([&](Operation *op) {
    if (!op->getDialect() ||
        op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
      return;

    for (Value operand : op->getOperands()) {
      auto elementTy = getElementTypeOrSelf(operand);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }
    for (Type resultTy : op->getResultTypes()) {
      auto elementTy = getElementTypeOrSelf(resultTy);
      if (!isValidElementType(elementTy)) {
        op->emitOpError() << "is not profile-aligned: element type "
                          << elementTy << " is not legal";
        return signalPassFailure();
      }
    }

    // Some uses of TOSA rely on the constant operands of particular
    // operations.
    if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
      signalPassFailure();

    // Do level checks.
    if (failed(applyLevelCheck(op)))
      signalPassFailure();

    // Do variable type checks.
    if (failed(applyVariableCheck(op)))
      signalPassFailure();
  });
}
} // namespace