MLIR  21.0.0git
TosaProfileCompliance.cpp
Go to the documentation of this file.
1 //===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "llvm/ADT/StringExtras.h"
11 
12 using namespace mlir;
13 using namespace mlir::tosa;
14 
16  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
27 
28 // The profile-based compliance content below is auto-generated by a script
29 // in https://git.mlplatform.org/tosa/specification.git
31  // End of auto-generated metadata
32 }
33 
34 template <>
36  return profileComplianceMap;
37 }
38 
39 template <>
42  return extensionComplianceMap;
43 }
44 
45 // Base populating function
46 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
47  Value output) {
48  for (auto operand : operands)
49  addValue(operand);
50  addValue(output);
51  return success();
52 }
53 
54 template <>
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56  addValue(op.getInput1().front());
57  addValue(op.getOutput());
58  return success();
59 }
60 
61 template <>
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63  addValue(op.getInput());
64  addValue(op.getInputZp());
65  addValue(op.getOutputZp());
66  addType(op.getAccType());
67  addValue(op.getOutput());
68  return success();
69 }
70 
71 template <typename T>
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73  addValue(op.getInput());
74  addValue(op.getWeight());
75  addValue(op.getBias());
76  addValue(op.getInputZp());
77  addValue(op.getWeightZp());
78  addType(op.getAccType());
79  addValue(op.getOutput());
80  return success();
81 }
82 
83 template <>
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85  return populateProfileInfoConv(op);
86 }
87 
88 template <>
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90  return populateProfileInfoConv(op);
91 }
92 
93 template <>
94 LogicalResult
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96  return populateProfileInfoConv(op);
97 }
98 
99 template <>
100 LogicalResult
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102  return populateProfileInfoConv(op);
103 }
104 
105 template <>
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107  addValue(op.getInput1());
108  addValue(op.getPadConst());
109  addValue(op.getOutput());
110  return success();
111 }
112 
113 template <typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115  addValue(op.getInput1());
116  addValue(op.getOutput());
117  return success();
118 }
119 
120 template <>
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122  return populateProfileInfoDataLayout(op);
123 }
124 
125 template <>
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127  return populateProfileInfoDataLayout(op);
128 }
129 
130 template <>
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132  return populateProfileInfoDataLayout(op);
133 }
134 
135 template <>
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137  return populateProfileInfoDataLayout(op);
138 }
139 
140 template <>
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142  addValue(op.getValues());
143  addValue(op.getIndices());
144  addValue(op.getOutput());
145  return success();
146 }
147 
148 template <>
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150  addValue(op.getValuesIn());
151  addValue(op.getIndices());
152  addValue(op.getInput());
153  addValue(op.getValuesOut());
154  return success();
155 }
156 
157 template <>
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159  addValue(op.getInput1());
160  addValue(op.getInput2());
161  addValue(op.getOutput());
162  return success();
163 }
164 
165 template <>
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167  addValue(op.getInput());
168  addValue(op.getOutput());
169  return success();
170 }
171 
172 template <>
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174  addValue(op.getInputReal());
175  addValue(op.getInputImag());
176  addValue(op.getOutputReal());
177  addValue(op.getOutputImag());
178  return success();
179 }
180 
181 template <>
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183  addValue(op.getInputReal());
184  addValue(op.getOutputReal());
185  addValue(op.getOutputImag());
186  return success();
187 }
188 
189 template <>
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191  addValue(op.getInput2());
192  addValue(op.getInput3());
193  addValue(op.getOutput());
194  return success();
195 }
196 
197 template <>
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199  addValue(op.getInput());
200  addValue(op.getInputZp());
201  addValue(op.getOutputZp());
202  addValue(op.getOutput());
203  return success();
204 }
205 
206 template <>
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
208  addValue(op.getA());
209  addValue(op.getB());
210  addValue(op.getAZp());
211  addValue(op.getBZp());
212  addValue(op.getOutput());
213  return success();
214 }
215 
216 template <>
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218  ::mlir::Attribute attr = op.getInitialValueAttr();
219  if (attr == nullptr)
220  return failure();
221 
222  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
223  addType(getElementTypeOrSelf(typedAttr));
224  return success();
225  }
226  return failure();
227 }
228 
229 template <>
230 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
231  addValue(op.getCondition());
232  return success();
233 }
234 
235 template <>
236 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
237  Block *block = &op.getCondGraph().front();
238  Operation *terminator = block->getTerminator();
239  addValue(terminator->getOperands().front());
240  return success();
241 }
242 
243 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
244 // This helper function only populates the info for the customised operands.
245 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
246  if (isa<tosa::tosaOp##Op>(op)) { \
247  return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
248  }
249 
250 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
251  if (isa<tosa::tosaOp##Op>(op)) \
252  return success();
253 
254 // This helper function populates the info for all operands.
255 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
256  if (isa<tosa::tosaOp##Op>(op)) { \
257  return populateProfileInfo(op->getOperands(), op->getResult(0)); \
258  }
259 
260  // Skip irrelevant operands when they are independent and not tied to any
261  // specific profile/extension.
263  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
266  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
285 
286  // For most TOSA operators, all operands are profile/extension related
287  // and hence are all considered in this profile-based compliance check.
300  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
301  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
302  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
304  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
305  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
306  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
307  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
308  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
310  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
323  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
329  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
332  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
336  POPULATE_PROFILE_INFO_COMMON(VariableRead)
337  POPULATE_PROFILE_INFO_COMMON(VariableWrite)
338 
339  // Type Invariant Extension, a capability extension that is independent
340  // of the data type, meaning any compatible type can be used. No type
341  // constraint for those operations.
342  POPULATE_PROFILE_INFO_SKIP(ConstShape)
344 
345  return failure();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // Tosa Profile And Extension Compliance Checker
350 //===----------------------------------------------------------------------===//
351 
352 template <typename T>
353 FailureOr<SmallVector<T>>
354 TosaProfileCompliance::getOperatorDefinition(Operation *op,
355  CheckCondition &condition) {
356  const std::string opName = op->getName().getStringRef().str();
357  const auto complianceMap = getProfileComplianceMap<T>();
358  const auto it = complianceMap.find(opName);
359  if (it == complianceMap.end())
360  return {};
361 
362  return findMatchedProfile<T>(op, it->second, condition);
363 }
364 
365 template <typename T>
367  Operation *op, const tosa::TargetEnv &targetEnv,
368  const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
369 
370  // No profile requirement is set in the specification.
371  if (specRequiredModeSet.size() == 0)
372  return success();
373 
375  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
376  if (failed(maybeOpRequiredMode)) {
377  // Operators such as control-flow and shape ops do not have an operand type
378  // restriction. When the profile compliance information of operation is not
379  // found, confirm that the target has enabled the profile required by the
380  // specification.
381  int mode_count = 0;
382  for (const auto &cands : specRequiredModeSet) {
383  if (targetEnv.allowsAnyOf(cands))
384  return success();
385  mode_count += cands.size();
386  }
387 
388  op->emitOpError() << "illegal: requires"
389  << (mode_count > 1 ? " any of " : " ") << "["
390  << llvm::join(stringifyProfile<T>(specRequiredModeSet),
391  ", ")
392  << "] but not enabled in target\n";
393 
394  return failure();
395  }
396 
397  // Find the required profiles or extensions according to the operand type
398  // combination.
399  const auto opRequiredMode = maybeOpRequiredMode.value();
400  if (opRequiredMode.size() == 0) {
401  // No matched restriction found.
402  return success();
403  }
404 
405  if (condition == CheckCondition::allOf &&
406  !targetEnv.allowsAllOf(opRequiredMode)) {
407  op->emitOpError() << "illegal: requires"
408  << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
409  << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
410  << "] but not enabled in target\n";
411  return failure();
412  }
413 
414  if (condition == CheckCondition::anyOf &&
415  !targetEnv.allowsAnyOf(opRequiredMode)) {
416  op->emitOpError() << "illegal: requires"
417  << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
418  << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
419  << "] but not enabled in target\n";
420  return failure();
421  }
422 
423  // Each extension can list the profiles it works with; these usually
424  // share the same data type.
425  if constexpr (std::is_same_v<T, Extension>) {
426  for (const auto &mode : opRequiredMode) {
427  SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
428  if (!targetEnv.allowsAnyOf(coProfs)) {
429  op->emitOpError() << "illegal: requires ["
430  << llvm::join(stringifyProfile<Profile>(coProfs),
431  ", ")
432  << "] to work with but not enabled in target\n";
433  return failure();
434  }
435  }
436  }
437 
438  // Ensure the inferred profiles match the profile knowledge of the
439  // specification.
440  for (const auto &cands : specRequiredModeSet) {
441  for (size_t i = 0; i < opRequiredMode.size(); i++) {
442  if (std::find(cands.begin(), cands.end(), opRequiredMode[i]) ==
443  cands.end()) {
444  op->emitOpError() << "illegal: requires ["
445  << llvm::join(stringifyProfile<T>(opRequiredMode),
446  ", ")
447  << "] but not included in the profile compliance ["
448  << llvm::join(
449  stringifyProfile<T>(specRequiredModeSet), ", ")
450  << "]\n";
451  return failure();
452  }
453  }
454  }
455 
456  return success();
457 }
458 
459 LogicalResult
461  const tosa::TargetEnv &targetEnv) {
462  if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
463  return checkProfileOrExtension<Profile>(op, targetEnv,
464  interface.getProfiles());
465 
466  return success();
467 }
468 
469 LogicalResult
471  const tosa::TargetEnv &targetEnv) {
472  if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
473  return checkProfileOrExtension<Extension>(op, targetEnv,
474  interface.getExtensions());
475 
476  return success();
477 }
478 
481  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
482  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
483  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
484  !maybeProfDef.value().size() && !maybeExtDef.value().size())
485  return failure();
486 
487  return success();
488 }
489 
490 // Find the profiles or extensions requirement according to the signature of
491 // type of the operand list.
492 template <typename T>
494  Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
495  CheckCondition &condition) {
496  assert(compInfo.size() != 0 &&
497  "profile-based compliance information is empty");
498 
499  // Populate the type of profile/extension relevant operands.
500  ProfileInfoDepot depot(op);
501  SmallVector<TypeInfo> present = depot.getInfo();
502  if (present.size() == 0)
503  return {};
504 
505  for (size_t i = 0; i < compInfo.size(); i++) {
506  SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
507  for (SmallVector<TypeInfo> expected : sets) {
508  assert(present.size() == expected.size() &&
509  "the entries for profile-based compliance do not match between "
510  "the generated metadata and the type definition retrieved from "
511  " the operation");
512 
513  bool is_found = true;
514  // Compare the type signature between the given operation and the
515  // compliance metadata.
516  for (size_t j = 0; j < expected.size(); j++) {
517  if (!isSameTypeInfo(present[j], expected[j])) {
518  // Verify the next mode set from the list.
519  is_found = false;
520  break;
521  }
522  }
523 
524  if (is_found == true) {
525  condition = compInfo[i].condition;
526  return compInfo[i].mode;
527  }
528  }
529  }
530 
531  return {};
532 }
533 
534 // Debug utilities.
535 template <typename T>
538  SmallVector<StringRef> debugStrings;
539  for (const auto &profile : profiles) {
540  if constexpr (std::is_same_v<T, Profile>)
541  debugStrings.push_back(tosa::stringifyProfile(profile));
542  else
543  debugStrings.push_back(tosa::stringifyExtension(profile));
544  }
545  return debugStrings;
546 }
547 
548 template <typename T>
550  const SmallVector<ArrayRef<T>> &profileSet) {
551  SmallVector<StringRef> debugStrings;
552 
553  for (const auto &profiles : profileSet) {
554  auto tempStrings = stringifyProfile<T>(profiles);
555  debugStrings.insert(debugStrings.end(), tempStrings.begin(),
556  tempStrings.end());
557  }
558 
559  return debugStrings;
560 }
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
@ Gather
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
bool allowsAllOf(ArrayRef< Profile > profs) const
Definition: TargetEnv.h:51
bool allowsAnyOf(ArrayRef< Profile > profs) const
Definition: TargetEnv.h:45
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.