MLIR  21.0.0git
TosaProfileCompliance.cpp
Go to the documentation of this file.
1 //===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "llvm/ADT/StringExtras.h"
11 
12 using namespace mlir;
13 using namespace mlir::tosa;
14 
16  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
27 
28 // The profile-based compliance content below is auto-generated by a script
29 // in https://git.mlplatform.org/tosa/specification.git
31  // End of auto-generated metadata
32 }
33 
34 template <>
36  return profileComplianceMap;
37 }
38 
39 template <>
42  return extensionComplianceMap;
43 }
44 
45 // Base populating function
46 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
47  Value output) {
48  for (auto operand : operands)
49  addValue(operand);
50  addValue(output);
51  return success();
52 }
53 
54 template <>
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56  addValue(op.getInput1().front());
57  addValue(op.getOutput());
58  return success();
59 }
60 
61 template <>
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63  addValue(op.getInput());
64  addValue(op.getInputZp());
65  addValue(op.getOutputZp());
66  addType(op.getAccType());
67  addValue(op.getOutput());
68  return success();
69 }
70 
71 template <typename T>
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73  addValue(op.getInput());
74  addValue(op.getWeight());
75  addValue(op.getBias());
76  addValue(op.getInputZp());
77  addValue(op.getWeightZp());
78  addType(op.getAccType());
79  addValue(op.getOutput());
80  return success();
81 }
82 
83 template <>
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85  return populateProfileInfoConv(op);
86 }
87 
88 template <>
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90  return populateProfileInfoConv(op);
91 }
92 
93 template <>
94 LogicalResult
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96  return populateProfileInfoConv(op);
97 }
98 
99 template <>
100 LogicalResult
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102  return populateProfileInfoConv(op);
103 }
104 
105 template <>
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107  addValue(op.getInput1());
108  addValue(op.getPadConst());
109  addValue(op.getOutput());
110  return success();
111 }
112 
113 template <typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115  addValue(op.getInput1());
116  addValue(op.getOutput());
117  return success();
118 }
119 
120 template <>
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122  return populateProfileInfoDataLayout(op);
123 }
124 
125 template <>
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127  return populateProfileInfoDataLayout(op);
128 }
129 
130 template <>
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132  return populateProfileInfoDataLayout(op);
133 }
134 
135 template <>
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137  return populateProfileInfoDataLayout(op);
138 }
139 
140 template <>
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142  addValue(op.getValues());
143  addValue(op.getIndices());
144  addValue(op.getOutput());
145  return success();
146 }
147 
148 template <>
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150  addValue(op.getValuesIn());
151  addValue(op.getIndices());
152  addValue(op.getInput());
153  addValue(op.getValuesOut());
154  return success();
155 }
156 
157 template <>
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159  addValue(op.getInput1());
160  addValue(op.getInput2());
161  addValue(op.getOutput());
162  return success();
163 }
164 
165 template <>
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167  addValue(op.getInput());
168  addValue(op.getOutput());
169  return success();
170 }
171 
172 template <>
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174  addValue(op.getInputReal());
175  addValue(op.getInputImag());
176  addValue(op.getOutputReal());
177  addValue(op.getOutputImag());
178  return success();
179 }
180 
181 template <>
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183  addValue(op.getInputReal());
184  addValue(op.getOutputReal());
185  addValue(op.getOutputImag());
186  return success();
187 }
188 
189 template <>
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191  addValue(op.getInput2());
192  addValue(op.getInput3());
193  addValue(op.getOutput());
194  return success();
195 }
196 
197 template <>
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199  addValue(op.getInput());
200  addValue(op.getInputZp());
201  addValue(op.getOutputZp());
202  addValue(op.getOutput());
203  return success();
204 }
205 
206 template <>
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
208  addValue(op.getA());
209  addValue(op.getB());
210  addValue(op.getAZp());
211  addValue(op.getBZp());
212  addValue(op.getOutput());
213  return success();
214 }
215 
216 template <>
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218  ::mlir::Attribute attr = op.getInitialValueAttr();
219  if (attr == nullptr)
220  return failure();
221 
222  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
223  addType(getElementTypeOrSelf(typedAttr));
224  return success();
225  }
226  return failure();
227 }
228 
229 template <>
230 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
231  addValue(op.getInput1());
232  return success();
233 }
234 
235 template <>
236 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
237  addValue(op.getCondition());
238  return success();
239 }
240 
241 template <>
242 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
243  Block *block = &op.getCondGraph().front();
244  Operation *terminator = block->getTerminator();
245  addValue(terminator->getOperands().front());
246  return success();
247 }
248 
249 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
250 // This helper function only populates the info for the customised operands.
251 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
252  if (isa<tosa::tosaOp##Op>(op)) { \
253  return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
254  }
255 
256 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
257  if (isa<tosa::tosaOp##Op>(op)) \
258  return success();
259 
260 // This helper function populates the info for all operands.
261 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
262  if (isa<tosa::tosaOp##Op>(op)) { \
263  return populateProfileInfo(op->getOperands(), op->getResult(0)); \
264  }
265 
266  // Skip irrelevant operands when they are independent and not tied to any
267  // specific profile/extension.
269  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
272  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
289  POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
292 
293  // For the most of tosa operators, all operands are profile/extension related
294  // and hence are all considered in this profile-based compilance check.
307  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
308  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
309  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
311  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
312  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
313  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
314  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
315  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
317  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
330  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
336  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
339  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
343  POPULATE_PROFILE_INFO_COMMON(VariableRead)
344 
345  // Type Invariant Extension, a capability extension that is independent
346  // of the data type, meaning any compatible type can be used. No type
347  // constraint for those operations.
348  POPULATE_PROFILE_INFO_SKIP(ConstShape)
350 
351  return failure();
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // Tosa Profile And Extension Compliance Checker
356 //===----------------------------------------------------------------------===//
357 
358 template <typename T>
359 FailureOr<SmallVector<T>>
360 TosaProfileCompliance::getOperatorDefinition(Operation *op,
361  CheckCondition &condition) {
362  const std::string opName = op->getName().getStringRef().str();
363  const auto complianceMap = getProfileComplianceMap<T>();
364  const auto it = complianceMap.find(opName);
365  if (it == complianceMap.end())
366  return {};
367 
368  return findMatchedProfile<T>(op, it->second, condition);
369 }
370 
371 template <typename T>
373  Operation *op, const tosa::TargetEnv &targetEnv,
374  const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
375 
376  // None of profile requirement is set in the specification.
377  if (specRequiredModeSet.size() == 0)
378  return success();
379 
381  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
382  if (failed(maybeOpRequiredMode)) {
383  // Operators such as control-flow and shape ops do not have an operand type
384  // restriction. When the profile compliance information of operation is not
385  // found, confirm if the target have enabled the profile required from the
386  // specification.
387  int mode_count = 0;
388  for (const auto &cands : specRequiredModeSet) {
389  if (targetEnv.allowsAnyOf(cands))
390  return success();
391  mode_count += cands.size();
392  }
393 
394  op->emitOpError() << "illegal: requires"
395  << (mode_count > 1 ? " any of " : " ") << "["
396  << llvm::join(stringifyProfile<T>(specRequiredModeSet),
397  ", ")
398  << "] but not enabled in target\n";
399 
400  return failure();
401  }
402 
403  // Find the required profiles or extensions according to the operand type
404  // combination.
405  const auto opRequiredMode = maybeOpRequiredMode.value();
406  if (opRequiredMode.size() == 0) {
407  // No matched restriction found.
408  return success();
409  }
410 
411  if (condition == CheckCondition::allOf &&
412  !targetEnv.allowsAllOf(opRequiredMode)) {
413  op->emitOpError() << "illegal: requires"
414  << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
415  << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
416  << "] but not enabled in target\n";
417  return failure();
418  }
419 
420  if (condition == CheckCondition::anyOf &&
421  !targetEnv.allowsAnyOf(opRequiredMode)) {
422  op->emitOpError() << "illegal: requires"
423  << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
424  << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
425  << "] but not enabled in target\n";
426  return failure();
427  }
428 
429  // Each extension can contain a list of profiles that it works with, usually
430  // have the same data type.
431  if constexpr (std::is_same_v<T, Extension>) {
432  for (const auto &mode : opRequiredMode) {
433  SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
434  if (!targetEnv.allowsAnyOf(coProfs)) {
435  op->emitOpError() << "illegal: requires ["
436  << llvm::join(stringifyProfile<Profile>(coProfs),
437  ", ")
438  << "] to work with but not enabled in target\n";
439  return failure();
440  }
441  }
442  }
443 
444  // Ensure the profile inference match the profile knowledge of the
445  // specification.
446  for (const auto &cands : specRequiredModeSet) {
447  for (size_t i = 0; i < opRequiredMode.size(); i++) {
448  if (std::find(cands.begin(), cands.end(), opRequiredMode[i]) ==
449  cands.end()) {
450  op->emitOpError() << "illegal: requires ["
451  << llvm::join(stringifyProfile<T>(opRequiredMode),
452  ", ")
453  << "] but not included in the profile compliance ["
454  << llvm::join(
455  stringifyProfile<T>(specRequiredModeSet), ", ")
456  << "]\n";
457  return failure();
458  }
459  }
460  }
461 
462  return success();
463 }
464 
465 LogicalResult
467  const tosa::TargetEnv &targetEnv) {
468  if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
469  return checkProfileOrExtension<Profile>(op, targetEnv,
470  interface.getProfiles());
471 
472  return success();
473 }
474 
475 LogicalResult
477  const tosa::TargetEnv &targetEnv) {
478  if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
479  return checkProfileOrExtension<Extension>(op, targetEnv,
480  interface.getExtensions());
481 
482  return success();
483 }
484 
487  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
488  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
489  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
490  !maybeProfDef.value().size() && !maybeExtDef.value().size())
491  return failure();
492 
493  return success();
494 }
495 
// Find the profile or extension requirements that match the type signature
// of the operand list.
498 template <typename T>
500  Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
501  CheckCondition &condition) {
502  assert(compInfo.size() != 0 &&
503  "profile-based compliance information is empty");
504 
505  // Populate the type of profile/extension relevant operands.
506  ProfileInfoDepot depot(op);
507  SmallVector<TypeInfo> present = depot.getInfo();
508  if (present.size() == 0)
509  return {};
510 
511  for (size_t i = 0; i < compInfo.size(); i++) {
512  SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
513  for (SmallVector<TypeInfo> expected : sets) {
514  assert(present.size() == expected.size() &&
515  "the entries for profile-based compliance do not match between "
516  "the generated metadata and the type definition retrieved from "
517  " the operation");
518 
519  bool is_found = true;
520  // Compare the type signature between the given operation and the
521  // compliance metadata.
522  for (size_t j = 0; j < expected.size(); j++) {
523  if (!isSameTypeInfo(present[j], expected[j])) {
524  // Verify the next mode set from the list.
525  is_found = false;
526  break;
527  }
528  }
529 
530  if (is_found == true) {
531  condition = compInfo[i].condition;
532  return compInfo[i].mode;
533  }
534  }
535  }
536 
537  return {};
538 }
539 
// Debug utilities.
541 template <typename T>
544  SmallVector<StringRef> debugStrings;
545  for (const auto &profile : profiles) {
546  if constexpr (std::is_same_v<T, Profile>)
547  debugStrings.push_back(tosa::stringifyProfile(profile));
548  else
549  debugStrings.push_back(tosa::stringifyExtension(profile));
550  }
551  return debugStrings;
552 }
553 
554 template <typename T>
556  const SmallVector<ArrayRef<T>> &profileSet) {
557  SmallVector<StringRef> debugStrings;
558 
559  for (const auto &profiles : profileSet) {
560  auto tempStrings = stringifyProfile<T>(profiles);
561  llvm::append_range(debugStrings, tempStrings);
562  }
563 
564  return debugStrings;
565 }
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
@ Gather
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
bool allowsAllOf(ArrayRef< Profile > profs) const
Definition: TargetEnv.h:51
bool allowsAnyOf(ArrayRef< Profile > profs) const
Definition: TargetEnv.h:45
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.