MLIR  22.0.0git
TosaProfileCompliance.cpp
Go to the documentation of this file.
1 //===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "llvm/ADT/StringExtras.h"
11 
12 using namespace mlir;
13 using namespace mlir::tosa;
14 
// Shorthand TypeInfo descriptors, each a (type ID, bit width) pair, for the
// element types named in the TOSA compliance metadata. Note that boolT is
// modelled as a 1-bit IntegerType, so it stringifies as "i1".
// NOTE(review): the opening line of the enclosing definition is not visible in
// this view; these locals are presumably consumed by the auto-generated
// metadata that follows.
16  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
27
28 // The profile-based compliance content below is auto-generated by a script
29 // in https://git.mlplatform.org/tosa/specification.git
31  // End of auto-generated metadata
32 }
33 
// Specialization that returns profileComplianceMap (presumably the
// getProfileComplianceMap<Profile>() case; its signature line is not visible).
// NOTE(review): the file index lists getProfileComplianceMap() as returning
// std::unordered_map by value, so every call copies the whole table. It is
// called per operation by getOperatorDefinition and twice by checkInvalid;
// consider returning a const reference instead.
34 template <>
36  return profileComplianceMap;
37 }
38
// Specialization that returns extensionComplianceMap (presumably the
// getProfileComplianceMap<Extension>() case; signature lines not visible).
39 template <>
42  return extensionComplianceMap;
43 }
44 
45 // Base populating function
46 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
47  Value output) {
48  for (auto operand : operands)
49  addValue(operand);
50  addValue(output);
51  return success();
52 }
53 
54 template <>
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56  addValue(op.getInput1().front());
57  addValue(op.getOutput());
58  return success();
59 }
60 
61 template <>
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63  addValue(op.getInput());
64  addValue(op.getInputZp());
65  addValue(op.getOutputZp());
66  addType(op.getAccType());
67  addValue(op.getOutput());
68  return success();
69 }
70 
71 template <typename T>
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73  addValue(op.getInput());
74  addValue(op.getWeight());
75  addValue(op.getBias());
76  addValue(op.getInputZp());
77  addValue(op.getWeightZp());
78  addType(op.getAccType());
79  addValue(op.getOutput());
80  return success();
81 }
82 
83 template <>
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85  return populateProfileInfoConv(op);
86 }
87 
88 template <>
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90  return populateProfileInfoConv(op);
91 }
92 
93 template <>
94 LogicalResult
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96  return populateProfileInfoConv(op);
97 }
98 
99 template <>
100 LogicalResult
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102  return populateProfileInfoConv(op);
103 }
104 
105 template <>
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107  addValue(op.getInput1());
108  addValue(op.getPadConst());
109  addValue(op.getOutput());
110  return success();
111 }
112 
113 template <typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115  addValue(op.getInput1());
116  addValue(op.getOutput());
117  return success();
118 }
119 
120 template <>
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122  return populateProfileInfoDataLayout(op);
123 }
124 
125 template <>
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127  return populateProfileInfoDataLayout(op);
128 }
129 
130 template <>
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132  return populateProfileInfoDataLayout(op);
133 }
134 
135 template <>
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137  return populateProfileInfoDataLayout(op);
138 }
139 
140 template <>
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142  addValue(op.getValues());
143  addValue(op.getIndices());
144  addValue(op.getOutput());
145  return success();
146 }
147 
148 template <>
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150  addValue(op.getValuesIn());
151  addValue(op.getIndices());
152  addValue(op.getInput());
153  addValue(op.getValuesOut());
154  return success();
155 }
156 
157 template <>
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159  addValue(op.getInput1());
160  addValue(op.getInput2());
161  addValue(op.getOutput());
162  return success();
163 }
164 
165 template <>
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167  addValue(op.getInput());
168  addValue(op.getOutput());
169  return success();
170 }
171 
172 template <>
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174  addValue(op.getInputReal());
175  addValue(op.getInputImag());
176  addValue(op.getOutputReal());
177  addValue(op.getOutputImag());
178  return success();
179 }
180 
181 template <>
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183  addValue(op.getInputReal());
184  addValue(op.getOutputReal());
185  addValue(op.getOutputImag());
186  return success();
187 }
188 
189 template <>
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191  addValue(op.getOnTrue());
192  addValue(op.getOnFalse());
193  addValue(op.getOutput());
194  return success();
195 }
196 
197 template <>
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199  addValue(op.getInput());
200  addValue(op.getInputZp());
201  addValue(op.getOutputZp());
202  addValue(op.getOutput());
203  return success();
204 }
205 
206 template <>
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
208  addValue(op.getA());
209  addValue(op.getB());
210  addValue(op.getAZp());
211  addValue(op.getBZp());
212  addValue(op.getOutput());
213  return success();
214 }
215 
216 template <>
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218  addType(op.getType());
219  return success();
220 }
221 
222 template <>
223 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
224  addValue(op.getInput1());
225  return success();
226 }
227 
// Dispatches on the concrete TOSA op class and fills the depot with the
// operand/result types that take part in profile/extension matching. Op
// classes handled by none of the cases below yield failure().
// NOTE(review): the function name is misspelled ("populatation"); fixing it
// requires touching the declaration and every caller, so it is left as-is.
// NOTE(review): several case lines are not visible in this view.
228 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
229 // This helper macro delegates to the op's dedicated populateProfileInfo
// specialization, which records only the operands relevant to compliance.
230 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
231  if (isa<tosa::tosaOp##Op>(op)) { \
232  return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
233  }
234
// This helper macro records nothing for the op and reports success.
235 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
236  if (isa<tosa::tosaOp##Op>(op)) \
237  return success();
238
239 // This helper macro records every operand and then the first result.
240 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
241  if (isa<tosa::tosaOp##Op>(op)) { \
242  return populateProfileInfo(op->getOperands(), op->getResult(0)); \
243  }
244
245  // Ops with custom handlers: operands that are independent of, and not tied
246  // to, any specific profile/extension are skipped.
248  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
251  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
268  POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
269
270  // For most TOSA operators, all operands are profile/extension related
271  // and hence are all considered in this profile-based compliance check.
284  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
285  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
286  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
288  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
289  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
290  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
291  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
292  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
294  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
307  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
313  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
316  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
320  POPULATE_PROFILE_INFO_COMMON(VariableRead)
321
322  // Type Invariant Extension, a capability extension that is independent
323  // of the data type, meaning any compatible type can be used. No type
324  // constraint for those operations.
325  POPULATE_PROFILE_INFO_SKIP(ConstShape)
329
// Unknown op class: nothing was recorded.
330  return failure();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // Tosa Profile And Extension Compliance Checker
335 //===----------------------------------------------------------------------===//
336 
337 template <typename T>
338 FailureOr<SmallVector<T>>
339 TosaProfileCompliance::getOperatorDefinition(Operation *op,
340  CheckCondition &condition) {
341  const std::string opName = op->getName().getStringRef().str();
342  const auto complianceMap = getProfileComplianceMap<T>();
343  const auto it = complianceMap.find(opName);
344  if (it == complianceMap.end())
345  return {};
346 
347  return findMatchedProfile<T>(op, it->second, condition);
348 }
349 
// Checks that `targetEnv` enables the profiles (T = Profile) or extensions
// (T = Extension) that `op` needs. `specRequiredModeSet` is the op-level
// requirement passed in by checkProfile / checkExtension. On the first
// violation it emits an op error and returns failure.
// NOTE(review): the line naming this function and the declaration of
// `condition` are not visible in this view. `condition` is presumably a local
// CheckCondition that getOperatorDefinition fills in on a match.
350 template <typename T>
352  Operation *op, const tosa::TargetEnv &targetEnv,
353  const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
354
355  // No profile requirement is set in the specification.
356  if (specRequiredModeSet.size() == 0)
357  return success();
358
360  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
361  if (failed(maybeOpRequiredMode)) {
362  // Operators such as control-flow and shape ops do not have an operand type
363  // restriction. When the profile compliance information of the operation is
364  // not found, confirm that the target has enabled a mode required by the
365  // specification.
// mode_count only chooses between " any of " and singular wording below.
366  int mode_count = 0;
367  for (const auto &cands : specRequiredModeSet) {
368  if (targetEnv.allowsAnyOf(cands))
369  return success();
370  mode_count += cands.size();
371  }
372
373  op->emitOpError() << "illegal: requires"
374  << (mode_count > 1 ? " any of " : " ") << "["
375  << llvm::join(stringifyProfile<T>(specRequiredModeSet),
376  ", ")
377  << "] but not enabled in target\n";
378
379  return failure();
380  }
381
382  // Find the required profiles or extensions according to the operand type
383  // combination.
384  const auto opRequiredMode = maybeOpRequiredMode.value();
385  if (opRequiredMode.size() == 0) {
386  // No matched restriction found.
387  return success();
388  }
389
// allOf: every inferred mode must be enabled in the target.
390  if (condition == CheckCondition::allOf &&
391  !targetEnv.allowsAllOf(opRequiredMode)) {
392  op->emitOpError() << "illegal: requires"
393  << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
394  << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
395  << "] but not enabled in target\n";
396  return failure();
397  }
398
// anyOf: at least one inferred mode must be enabled in the target.
399  if (condition == CheckCondition::anyOf &&
400  !targetEnv.allowsAnyOf(opRequiredMode)) {
401  op->emitOpError() << "illegal: requires"
402  << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
403  << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
404  << "] but not enabled in target\n";
405  return failure();
406  }
407
408  // Each extension can list the profiles it works with, which usually share
409  // the same data type; at least one of them must be enabled in the target.
410  if constexpr (std::is_same_v<T, Extension>) {
411  for (const auto &mode : opRequiredMode) {
412  SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
413  if (!targetEnv.allowsAnyOf(coProfs)) {
414  op->emitOpError() << "illegal: requires ["
415  << llvm::join(stringifyProfile<Profile>(coProfs),
416  ", ")
417  << "] to work with but not enabled in target\n";
418  return failure();
419  }
420  }
421  }
422
423  // Ensure the inferred modes match the profile knowledge of the
424  // specification.
// NOTE(review): this requires every inferred mode to appear in *every*
// candidate list of specRequiredModeSet. If the spec can ever yield more than
// one list, confirm that this is intended rather than "in at least one list".
425  for (const auto &cands : specRequiredModeSet) {
426  for (const auto &mode : opRequiredMode) {
427  if (!llvm::is_contained(cands, mode)) {
428  op->emitOpError() << "illegal: requires ["
429  << llvm::join(stringifyProfile<T>(opRequiredMode),
430  ", ")
431  << "] but not included in the profile compliance ["
432  << llvm::join(
433  stringifyProfile<T>(specRequiredModeSet), ", ")
434  << "]\n";
435  return failure();
436  }
437  }
438  }
439
440  return success();
441 }
442 
// Profile check entry point. Ops implementing tosa::QueryProfileInterface are
// validated against the profiles they report; all other ops pass.
// NOTE(review): the line carrying this function's name and first parameter is
// not visible in this view.
443 LogicalResult
445  const tosa::TargetEnv &targetEnv) {
446  if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
447  return checkProfileOrExtension<Profile>(op, targetEnv,
448  interface.getProfiles());
449
450  return success();
451 }
452 
// Extension check entry point. Ops implementing tosa::QueryExtensionInterface
// are validated against the extensions they report; all other ops pass.
// NOTE(review): the line carrying this function's name and first parameter is
// not visible in this view.
453 LogicalResult
455  const tosa::TargetEnv &targetEnv) {
456  if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
457  return checkProfileOrExtension<Extension>(op, targetEnv,
458  interface.getExtensions());
459
460  return success();
461 }
462 
// Body of checkInvalid. It reports an error when the op is listed in a
// compliance map but its recorded operand/result types matched no entry in
// either the Profile or the Extension map. The message suggests the
// closest-matching signature.
// NOTE(review): the signature and the declaration of `condition` are not
// visible in this view.
465  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
466  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
// Op listed in neither map: there is nothing to validate.
467  if (failed(maybeProfDef) && failed(maybeExtDef))
468  return success();
469
470  const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
471  (succeeded(maybeExtDef) && !maybeExtDef->empty());
472  if (!hasEntry) {
473  std::string message;
474  llvm::raw_string_ostream os(message);
475  os << "illegal: operation operand/result data types did not align with any "
476  "profile or extension, got (";
477
478  ProfileInfoDepot depot(op);
479  SmallVector<TypeInfo> current = depot.getInfo();
480  for (const auto &typeInfo : llvm::drop_end(current))
481  os << stringifyTypeInfo(typeInfo) << ",";
// NOTE(review): current.back() is undefined when the depot recorded nothing.
// findMatchedProfile returns an empty (successful) list in exactly that case,
// which leads here with hasEntry == false. Consider guarding against it.
482  os << stringifyTypeInfo(current.back()) << ")";
483
484  // Avoid polluting the error message output by outputting only
485  // the best match.
486  const std::string opName = op->getName().getStringRef().str();
487  int maxMatches = -1;
488  SmallVector<TypeInfo> bestTypeInfo;
// NOTE(review): `auto map` takes each compliance map by value (another full
// copy), and map[opName] default-inserts into that copy when the name is
// absent. llvm::zip_equal below asserts that `current` and each candidate
// signature have the same length.
489  const auto searchBestMatch = [&](auto map) {
490  for (const auto &complianceInfos : map[opName]) {
491  for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
492  const int matches = llvm::count_if(
493  llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
494  return isSameTypeInfo(std::get<0>(zipType),
495  std::get<1>(zipType));
496  });
497  if (matches > maxMatches) {
498  maxMatches = matches;
499  bestTypeInfo = typeInfos;
500  }
501  }
502  }
503  };
504  searchBestMatch(getProfileComplianceMap<Profile>());
505  searchBestMatch(getProfileComplianceMap<Extension>());
506
// NOTE(review): bestTypeInfo.back() is also undefined if no candidate
// signature was visited above.
507  os << ", did you mean (";
508  for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
509  os << stringifyTypeInfo(typeInfo) << ",";
510  os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
511  os << "Otherwise, please refer to the 'supported data types' for '"
512  << opName << "' in the specification.";
513  op->emitOpError(message);
514  return failure();
515  }
516
517  return success();
518 }
519 
520 // Find the profile/extension requirement whose operand type signature matches
521 // the types recorded for `op`. On the first match it stores that entry's
// condition in `condition` and returns its mode list. It returns an empty list
// when nothing matches or when no types were recorded for the op.
// NOTE(review): the signature line is not visible in this view. `compInfo` is
// taken by value, and `sets` / `expected` below are copied on every iteration;
// binding them as `const auto &` would avoid those copies. The concatenated
// assert message below contains a double space ("from  the").
522 template <typename T>
524  Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
525  CheckCondition &condition) {
526  assert(compInfo.size() != 0 &&
527  "profile-based compliance information is empty");
528
529  // Populate the type of profile/extension relevant operands.
530  ProfileInfoDepot depot(op);
531  SmallVector<TypeInfo> present = depot.getInfo();
532  if (present.size() == 0)
533  return {};
534
535  for (size_t i = 0; i < compInfo.size(); i++) {
536  SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
537  for (SmallVector<TypeInfo> expected : sets) {
538  assert(present.size() == expected.size() &&
539  "the entries for profile-based compliance do not match between "
540  "the generated metadata and the type definition retrieved from "
541  " the operation");
542
543  bool is_found = true;
544  // Compare the type signature between the given operation and the
545  // compliance metadata, position by position.
546  for (size_t j = 0; j < expected.size(); j++) {
547  if (!isSameTypeInfo(present[j], expected[j])) {
548  // Verify the next mode set from the list.
549  is_found = false;
550  break;
551  }
552  }
553
554  if (is_found == true) {
555  condition = compInfo[i].condition;
556  return compInfo[i].mode;
557  }
558  }
559  }
560
561  return {};
562 }
563 
564 // Debug utilities.
// Converts each Profile (or, when T is Extension, each Extension) in
// `profiles` to its string name, preserving order.
// NOTE(review): the signature lines are not visible in this view; the file
// index lists this as stringifyProfile(ArrayRef<T> profiles).
565 template <typename T>
568  SmallVector<StringRef> debugStrings;
569  for (const auto &profile : profiles) {
570  if constexpr (std::is_same_v<T, Profile>)
571  debugStrings.push_back(tosa::stringifyProfile(profile));
572  else
573  debugStrings.push_back(tosa::stringifyExtension(profile));
574  }
575  return debugStrings;
576 }
577 
// Flattens a set of profile/extension lists into one list of names, in order.
// Names that appear in more than one list are kept as duplicates.
// NOTE(review): the line naming this overload is not visible in this view.
578 template <typename T>
580  const SmallVector<ArrayRef<T>> &profileSet) {
581  SmallVector<StringRef> debugStrings;
582
583  for (const auto &profiles : profileSet) {
584  auto tempStrings = stringifyProfile<T>(profiles);
585  llvm::append_range(debugStrings, tempStrings);
586  }
587
588  return debugStrings;
589 }
590 
// Body of stringifyTypeInfo. It maps a TypeInfo to a short name:
// - IntegerType prints as "i<bitWidth>", so the 1-bit bool descriptor prints
//   as "i1".
// - The supported float types use fixed names.
// - Any other type ID is treated as unreachable.
// NOTE(review): the signature lines are not visible in this view. The file
// index lists the return type as llvm::SmallString<7>, and the longest fixed
// name here ("fp8e4m3") is exactly 7 characters.
593  if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
594  return {"i" + llvm::utostr(typeInfo.bitWidth)};
595  } else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
596  return {"f16"};
597  } else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
598  return {"f32"};
599  } else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
600  return {"bf16"};
601  } else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
602  return {"fp8e4m3"};
603  } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
604  return {"fp8e5m2"};
605  }
606  llvm_unreachable("unknown type");
607 }
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
@ Gather
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
bool allowsAllOf(ArrayRef< Profile > profs) const
Definition: TargetEnv.h:49
bool allowsAnyOf(ArrayRef< Profile > profs) const
Definition: TargetEnv.h:45
NestedPattern If(const NestedPattern &child)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
mlir::TypeID typeID
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.