MLIR  21.0.0git
TosaProfileCompliance.h
Go to the documentation of this file.
1 //===- TosaProfileCompliance.h - Tosa Profile-based Compliance Validation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
10 #define MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
11 
12 #include <unordered_map>
13 
16 
17 #include "mlir/Support/TypeID.h"
18 
19 using namespace mlir;
20 using namespace mlir::tosa;
21 
22 //===----------------------------------------------------------------------===//
23 // Type Compliance Definition
24 //===----------------------------------------------------------------------===//
25 
26 typedef struct {
28  uint32_t bitWidth;
29 } TypeInfo;
30 
33  // Valid when any of the profile (extension) requirements is met.
35  // Valid when all of the profile (extension) requirements are met.
36  allOf
37 };
38 
39 template <typename T>
41  // Certain operations require multiple modes enabled.
42  // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
46 };
47 
49  std::unordered_map<std::string, SmallVector<OpComplianceInfo<Profile>>>;
51  std::unordered_map<std::string, SmallVector<OpComplianceInfo<Extension>>>;
52 
53 //===----------------------------------------------------------------------===//
54 // Tosa Profile And Extension Information Depot
55 //===----------------------------------------------------------------------===//
56 
58 public:
60  if (failed(populatationDispatch(op)))
61  op->emitOpError() << "fail to populate the profile info\n";
62  }
63 
64  void addType(Type t) { tyInfo.push_back(convertTypeToInfo(t)); } // Converts `t` to a TypeInfo and appends it to tyInfo.
65  void addValue(Value v) { tyInfo.push_back(convertValueToInfo(v)); } // Converts `v` to a TypeInfo (via its element type) and appends it to tyInfo.
66  SmallVector<TypeInfo> getInfo() { return tyInfo; } // Returns the collected TypeInfo entries by value (a copy of tyInfo).
67 
68 private:
69  TypeInfo convertTypeToInfo(Type type) {
    // Builds a TypeInfo from the type's TypeID and its bit width.
    // getIntOrFloatBitWidth() asserts on types that are neither integer nor
    // float, so `type` must be an integer or float type.
70  return {type.getTypeID(), type.getIntOrFloatBitWidth()};
71  }
72 
73  TypeInfo convertValueToInfo(Value value) {
    // Describes a value through getElementTypeOrSelf of its type, i.e. the
    // element type, or the type itself, and converts that with
    // convertTypeToInfo.
74  return convertTypeToInfo(getElementTypeOrSelf(value.getType()));
75  }
76 
77  LogicalResult populatationDispatch(Operation *op);
78 
79  LogicalResult populateProfileInfo(ValueRange operands, Value output);
80 
81  // Base: generic fallback that reports the profile requirement as undefined.
82  template <typename T>
83  LogicalResult populateProfileInfo(T op) {
    // Emits an op error on `op` and returns that diagnostic as the
    // LogicalResult.
    // NOTE(review): per-op specializations are presumably defined elsewhere,
    // leaving this body for ops without one — confirm.
84  return op->emitOpError()
85  << "profile requirement for this op has not been defined";
86  }
87  // For conv2d, conv3d, transpose_conv2d, and depthwise_conv2d.
88  template <typename T>
89  LogicalResult populateProfileInfoConv(T op);
90 
91  // For reshape, slice, tile, and transpose.
92  template <typename T>
93  LogicalResult populateProfileInfoDataLayout(T op);
94 
95 private:
96  SmallVector<TypeInfo> tyInfo;
97 };
98 
99 //===----------------------------------------------------------------------===//
100 // Tosa Profile And Extension Compliance Checker
101 //===----------------------------------------------------------------------===//
102 
104 public:
105  explicit TosaProfileCompliance();
106 
107  // Accessor of the compliance info map.
108  template <typename T>
109  std::unordered_map<std::string, SmallVector<OpComplianceInfo<T>>>
111  // Only profile and extension compliance info are provided.
112  return {};
113  }
114 
115  // Verify if the operation is allowed to be executed in the given target
116  // environment.
117  LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
118  LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
119  LogicalResult checkInvalid(Operation *op);
120 
121  template <typename T>
122  LogicalResult checkProfileOrExtension(
123  Operation *op, const tosa::TargetEnv &targetEnv,
124  const SmallVector<ArrayRef<T>> &specDefinedProfileSet);
125 
127  return a.typeID == b.typeID && a.bitWidth == b.bitWidth;
128  }
129 
130  // Find the required profiles or extensions from the compliance info according
131  // to the operand type combination.
132  template <typename T>
133  SmallVector<T> findMatchedProfile(Operation *op,
135  CheckCondition &condition);
136 
137  SmallVector<Profile> getCooperativeProfiles(Extension ext) {
     // Returns the profile(s) associated with the extension `ext`; the result
     // is empty for Extension::none.
     // NOTE(review): "cooperative" presumably means the profiles an extension
     // can be enabled together with — confirm against the TOSA specification.
138  switch (ext) {
     // Extensions mapped to Profile::pro_int only.
139  case Extension::int16:
140  case Extension::int4:
141  case Extension::doubleround:
142  case Extension::inexactround:
143  return {Profile::pro_int};
     // Extensions mapped to Profile::pro_fp only.
144  case Extension::bf16:
145  case Extension::fp8e4m3:
146  case Extension::fp8e5m2:
147  case Extension::fft:
148  return {Profile::pro_fp};
     // Extensions mapped to both profiles.
149  case Extension::variable:
150  case Extension::controlflow:
151  case Extension::dynamic:
152  return {Profile::pro_fp, Profile::pro_int};
153  case Extension::none:
154  return {};
155  };
     // Reached only when `ext` matches none of the cases above.
156  llvm_unreachable("bad Extension type");
157  }
158 
159  // Debug utilities.
160  template <typename T>
161  SmallVector<StringRef> stringifyProfile(ArrayRef<T> profiles);
162 
163  template <typename T>
165  stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
166 
167 private:
168  template <typename T>
169  FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
170  CheckCondition &condition);
171 
172  OperationProfileComplianceMap profileComplianceMap;
173  OperationExtensionComplianceMap extensionComplianceMap;
174 };
175 
176 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
ProfileInfoDepot(Operation *op)
SmallVector< TypeInfo > getInfo()
SmallVector< Profile > getCooperativeProfiles(Extension ext)
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition: Types.h:101
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
This class represents the capability enabled in the target implementation such as profile,...
Definition: TargetEnv.h:25
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< SmallVector< TypeInfo > > operandTypeInfoSet
SmallVector< T > mode
mlir::TypeID typeID