MLIR  22.0.0git
TosaProfileCompliance.h
Go to the documentation of this file.
1 //===- TosaProfileCompliance.h - Tosa Profile-based Compliance Validation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
10 #define MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
11 
12 #include <unordered_map>
13 
16 
17 #include "mlir/Support/TypeID.h"
18 
19 using namespace mlir;
20 using namespace mlir::tosa;
21 
22 //===----------------------------------------------------------------------===//
// Type Compliance Definition
24 //===----------------------------------------------------------------------===//
25 
26 typedef struct {
28  uint32_t bitWidth;
29 } TypeInfo;
30 
  // Valid when any of the profile (extension) requirements is met.
  // Valid when all of the profile (extension) requirements are met.
36  allOf
37 };
38 
40  std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
41 
42 template <typename T>
44  // Certain operations require multiple modes enabled.
45  // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
49 };
50 
52  std::unordered_map<std::string, SmallVector<OpComplianceInfo<Profile>>>;
54  std::unordered_map<std::string, SmallVector<OpComplianceInfo<Extension>>>;
55 
56 //===----------------------------------------------------------------------===//
57 // Tosa Profile And Extension Information Depot
58 //===----------------------------------------------------------------------===//
59 
61 public:
63  if (failed(populatationDispatch(op)))
64  op->emitOpError() << "fail to populate the profile info\n";
65  }
66 
67  void addType(Type t) { tyInfo.push_back(convertTypeToInfo(t)); }
68  void addValue(Value v) { tyInfo.push_back(convertValueToInfo(v)); }
69  SmallVector<TypeInfo> getInfo() { return tyInfo; }
70 
71 private:
72  TypeInfo convertTypeToInfo(Type type) {
73  return {type.getTypeID(), type.getIntOrFloatBitWidth()};
74  }
75 
76  TypeInfo convertValueToInfo(Value value) {
77  return convertTypeToInfo(getElementTypeOrSelf(value.getType()));
78  }
79 
  // Entry point used by the constructor to fill `tyInfo` for `op`. On failure
  // the constructor additionally emits "fail to populate the profile info" on
  // the op. NOTE(review): the name is misspelled ("populatation"); it is kept
  // as is because the out-of-line definition uses it. Presumably dispatches on
  // the concrete op kind to the populateProfileInfo* overloads below; confirm
  // in the .cpp file.
  LogicalResult populatationDispatch(Operation *op);

  // Generic population from an op's `operands` and its single `output` value.
  // NOTE(review): which types are recorded, and in what order, is defined out
  // of line; presumably the element types of operands then output. Confirm.
  LogicalResult populateProfileInfo(ValueRange operands, Value output);
83 
84  // Base
85  template <typename T>
86  LogicalResult populateProfileInfo(T op) {
87  return op->emitOpError()
88  << "profile requirement for this op has not been defined";
89  }
  // For conv2d, conv3d, transpose_conv2d, and depthwise_conv2d.
  // Shared helper so these convolution-like ops populate their type info the
  // same way; defined out of line. NOTE(review): exactly which operands are
  // recorded is not visible here. See the definition.
  template <typename T>
  LogicalResult populateProfileInfoConv(T op);

  // For reshape, slice, tile, and transpose.
  // Shared helper for the data-layout ops; defined out of line.
  // NOTE(review): presumably records only the data input/output types, since
  // these ops rearrange rather than compute. Confirm against the definition.
  template <typename T>
  LogicalResult populateProfileInfoDataLayout(T op);
97 
private:
  // Type infos appended by addType()/addValue(), in insertion order; a copy is
  // returned by getInfo().
  SmallVector<TypeInfo> tyInfo;
100 };
101 
102 //===----------------------------------------------------------------------===//
103 // Tosa Profile And Extension Compliance Checker
104 //===----------------------------------------------------------------------===//
105 
public:
  // Defined out of line. NOTE(review): presumably fills profileComplianceMap
  // and extensionComplianceMap from the generated TOSA compliance tables;
  // confirm in the .cpp file. `explicit` has no effect on a default
  // constructor except to forbid copy-list-initialization from `{}`.
  explicit TosaProfileCompliance();
109 
110  // Accessor of the compliance info map.
111  template <typename T>
112  std::unordered_map<std::string, SmallVector<OpComplianceInfo<T>>>
114  // Only profile and extension compliance info are provided.
115  return {};
116  }
117 
  // Verify if the operation is allowed to be executed in the given target
  // environment.
  // checkProfile and checkExtension test `op` against the profiles and the
  // extensions of `targetEnv`, respectively. checkInvalid takes no target
  // environment. NOTE(review): presumably it rejects type combinations that no
  // profile or extension allows; confirm in the .cpp file.
  LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
  LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
  LogicalResult checkInvalid(Operation *op);
123 
  // Shared check for one kind of mode, with T being Profile or Extension.
  // NOTE(review): presumably each ArrayRef in `specDefinedProfileSet` is one
  // alternative set of modes that must all be enabled in `targetEnv` (some ops
  // need several modes at once, e.g. bf16 -> fp8e4m3 casts). Confirm against
  // the definition.
  template <typename T>
  LogicalResult checkProfileOrExtension(
      Operation *op, const tosa::TargetEnv &targetEnv,
      const SmallVector<ArrayRef<T>> &specDefinedProfileSet);
128 
130  return a.typeID == b.typeID && a.bitWidth == b.bitWidth;
131  }
132 
133  // Find the required profiles or extensions from the compliance info according
134  // to the operand type combination.
135  template <typename T>
137  findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
138 
139  SmallVector<Profile> getCooperativeProfiles(Extension ext) {
140  switch (ext) {
141  case Extension::int16:
142  case Extension::int4:
143  case Extension::doubleround:
144  case Extension::inexactround:
145  return {Profile::pro_int};
146  case Extension::bf16:
147  case Extension::fp8e4m3:
148  case Extension::fp8e5m2:
149  case Extension::fft:
150  return {Profile::pro_fp};
151  case Extension::variable:
152  case Extension::controlflow:
153  case Extension::dynamic:
154  return {Profile::pro_fp, Profile::pro_int};
155  case Extension::none:
156  return {};
157  };
158  llvm_unreachable("bad Extension type");
159  }
160 
  // Debug utilities.
  // Converts each element of `profiles` (T is Profile or Extension) to a
  // StringRef. NOTE(review): presumably the mode's spelled name; confirm.
  template <typename T>
  SmallVector<StringRef> stringifyProfile(ArrayRef<T> profiles);
164 
165  template <typename T>
167  stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
168 
  // Renders `typeInfo` as a short string for diagnostics. It is returned in a
  // SmallString with 7 characters of inline storage. NOTE(review): presumably
  // a type spelling such as "i8" or "bf16"; confirm in the .cpp file.
  static llvm::SmallString<7> stringifyTypeInfo(const TypeInfo &typeInfo);
170 
private:
  // Looks up the compliance entry for `op` in the profile or extension table
  // (selected by T), returning failure when there is none.
  // NOTE(review): presumably keyed by the op's name; confirm in the .cpp file.
  template <typename T>
  FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);

  // Per-operation compliance tables, keyed by std::string (presumably the op
  // name), for profiles and extensions respectively.
  OperationProfileComplianceMap profileComplianceMap;
  OperationExtensionComplianceMap extensionComplianceMap;
177 };
178 
179 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
ProfileInfoDepot(Operation *op)
SmallVector< TypeInfo > getInfo()
SmallVector< Profile > getCooperativeProfiles(Extension ext)
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition: Types.h:101
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:91
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< VersionedTypeInfo > operandTypeInfoSet
SmallVector< T > mode
mlir::TypeID typeID