MLIR 22.0.0git
TosaProfileCompliance.h
Go to the documentation of this file.
1//===- TosaProfileCompliance.h - Tosa Profile-based Compliance Validation -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
10#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
11
12#include <unordered_map>
13
16
17#include "mlir/Support/TypeID.h"
18
19using namespace mlir;
20using namespace mlir::tosa;
21
22//===----------------------------------------------------------------------===//
23// Type Compliance Definition
24//===----------------------------------------------------------------------===//
25
26typedef struct {
28 uint32_t bitWidth;
29} TypeInfo;
30
33 // Valid when any of the profile (extension) requirements is met.
35 // Valid when all of the profile (extension) requirements are met.
37};
38
40 std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
41
42template <typename T>
44 // Certain operations require multiple modes enabled.
45 // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
49};
50
52 std::unordered_map<std::string, SmallVector<OpComplianceInfo<Profile>>>;
54 std::unordered_map<std::string, SmallVector<OpComplianceInfo<Extension>>>;
55
56//===----------------------------------------------------------------------===//
57// Tosa Profile And Extension Information Depot
58//===----------------------------------------------------------------------===//
59
61public:
63 if (failed(populatationDispatch(op)))
64 op->emitOpError() << "fail to populate the profile info\n";
65 }
66
// Converts `t` with convertTypeToInfo() and appends the result to tyInfo.
67 void addType(Type t) { tyInfo.push_back(convertTypeToInfo(t)); }
// Converts `v` with convertValueToInfo() and appends the result to tyInfo.
68 void addValue(Value v) { tyInfo.push_back(convertValueToInfo(v)); }
// Returns the collected type info list. The return is by value, so the caller
// receives a copy of tyInfo.
69 SmallVector<TypeInfo> getInfo() { return tyInfo; }
70
71private:
// Builds a TypeInfo for `type` from its TypeID (type.getTypeID()) and the bit
// width reported by tosa::getBitWidth(type).
72 TypeInfo convertTypeToInfo(Type type) {
73 return {type.getTypeID(), tosa::getBitWidth(type)};
74 }
75
// Builds a TypeInfo for `value`: its type is first passed through
// getElementTypeOrSelf() (element type, or the type itself), then converted
// with convertTypeToInfo().
76 TypeInfo convertValueToInfo(Value value) {
77 return convertTypeToInfo(getElementTypeOrSelf(value.getType()));
78 }
79
80 LogicalResult populatationDispatch(Operation *op);
81
82 // Add input operands and output results to the profile type info list
83 LogicalResult populateProfileInfo(ValueRange operands, ValueRange results);
84
85 // Base template: emits an op error on `op` stating that its profile
// requirement has not been defined, and returns that diagnostic as the
// LogicalResult.
// NOTE(review): presumably specialized elsewhere for op types whose
// requirements are defined — the specializations are not visible here.
86 template <typename T>
87 LogicalResult populateProfileInfo(T op) {
88 return op->emitOpError()
89 << "profile requirement for this op has not been defined";
90 }
91 // For conv2d, conv3d, transpose_conv2d, and depthwise_conv2d.
92 template <typename T>
93 LogicalResult populateProfileInfoConv(T op);
94
95 // For reshape, slice, tile, and transpose.
96 template <typename T>
97 LogicalResult populateProfileInfoDataLayout(T op);
98
99private:
100 SmallVector<TypeInfo> tyInfo;
101};
102
103//===----------------------------------------------------------------------===//
104// Tosa Profile And Extension Compliance Checker
105//===----------------------------------------------------------------------===//
106
108public:
109 explicit TosaProfileCompliance();
110
111 // Accessor of the compliance info map.
112 template <typename T>
113 std::unordered_map<std::string, SmallVector<OpComplianceInfo<T>>>
115 // Only profile and extension compliance info are provided.
116 return {};
117 }
118
119 // Verify if the operation is allowed to be executed in the given target
120 // environment.
121 LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
122 LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
123 LogicalResult checkInvalid(Operation *op);
124
125 template <typename T>
126 LogicalResult checkProfileOrExtension(
127 Operation *op, const tosa::TargetEnv &targetEnv,
128 const SmallVector<ArrayRef<T>> &specDefinedProfileSet);
129
131 return a.typeID == b.typeID && a.bitWidth == b.bitWidth;
132 }
133
134 // Find the required profiles or extensions from the compliance info according
135 // to the operand type combination.
136 template <typename T>
139
141 switch (ext) {
142 case Extension::int16:
143 case Extension::int4:
144 case Extension::doubleround:
145 case Extension::inexactround:
146 return {Profile::pro_int};
147 case Extension::bf16:
148 case Extension::fp8e4m3:
149 case Extension::fp8e5m2:
150 case Extension::fft:
151 case Extension::mxfp:
152 return {Profile::pro_fp};
153 case Extension::variable:
154 case Extension::controlflow:
155 case Extension::dynamic:
156 case Extension::int64:
157 return {Profile::pro_fp, Profile::pro_int};
158 case Extension::none:
159 return {};
160 };
161 llvm_unreachable("bad Extension type");
162 }
163
164 // Debug utilities.
165 template <typename T>
167
168 template <typename T>
170 stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
171
172 static llvm::SmallString<7> stringifyTypeInfo(const TypeInfo &typeInfo);
173
174private:
175 template <typename T>
176 FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
177
178 OperationProfileComplianceMap profileComplianceMap;
179 OperationExtensionComplianceMap extensionComplianceMap;
180};
181
182#endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
ProfileInfoDepot(Operation *op)
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< Profile > getCooperativeProfiles(Extension ext)
OpComplianceInfo< T > findMatchedEntry(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:97
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:609
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< VersionedTypeInfo > operandTypeInfoSet
SmallVector< T > mode
mlir::TypeID typeID