10 #include "llvm/ADT/StringExtras.h"
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// Fragment: only this return statement is visible. It returns the
// profileComplianceMap object; the enclosing function is not shown.
// NOTE(review): if this is the body of getProfileComplianceMap<Profile>(),
// which is declared to return std::unordered_map by value, every call copies
// the whole compliance table. Consider returning a const reference — TODO
// confirm against the declaration.
36 return profileComplianceMap;
// Fragment: only this return statement is visible. It returns the
// extensionComplianceMap object; the enclosing function is not shown.
// NOTE(review): if this is the body of getProfileComplianceMap<Extension>(),
// which is declared to return std::unordered_map by value, every call copies
// the whole compliance table. Consider returning a const reference — TODO
// confirm against the declaration.
42 return extensionComplianceMap;
// Generic populator: iterates over every Value in `operands`. The rest of the
// signature and the loop body are not visible in this view. The
// POPULATE_PROFILE_INFO_COMMON dispatch macro calls it as
// populateProfileInfo(op->getOperands(), op->getResult(0)).
46 LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
48 for (
auto operand : operands)
// tosa.concat: passes only the first element of getInput1() and the output to
// addValue.
// NOTE(review): this presumably relies on all concatenated inputs sharing one
// element type — TODO confirm. front() also assumes the input list is
// non-empty (presumably guaranteed by the op verifier).
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56 addValue(op.getInput1().front());
57 addValue(op.getOutput());
// tosa.avg_pool2d: records, in this order, the input, the input zero point,
// the output zero point, the accumulator type and the output. The accumulator
// type goes through addType rather than addValue.
// findMatchedProfile compares the collected entries position by position
// against the compliance tables, so the order of these calls matters.
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63 addValue(op.getInput());
64 addValue(op.getInputZp());
65 addValue(op.getOutputZp());
66 addType(op.getAccType());
67 addValue(op.getOutput());
// Shared populator for the convolution family. It records, in this order:
// input, weight, bias, input zero point, weight zero point, the accumulator
// type (via addType) and the output. The order matters because
// findMatchedProfile compares entries positionally.
// The Conv2D, Conv3D, TransposeConv2D and DepthwiseConv2D overloads below
// simply forward to this template.
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85 return populateProfileInfoConv(op);
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90 return populateProfileInfoConv(op);
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96 return populateProfileInfoConv(op);
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102 return populateProfileInfoConv(op);
// tosa.pad: records the input, the pad-constant operand and the output, in
// that order.
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107 addValue(op.getInput1());
108 addValue(op.getPadConst());
109 addValue(op.getOutput());
// Shared populator for data-layout ops that expose getInput1()/getOutput().
// It records only the input and the output. The Reshape, Slice, Tile and
// Transpose overloads below forward to this template; any other operands of
// those ops are not recorded here.
113 template <
typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122 return populateProfileInfoDataLayout(op);
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127 return populateProfileInfoDataLayout(op);
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132 return populateProfileInfoDataLayout(op);
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137 return populateProfileInfoDataLayout(op);
// tosa.gather: records the values tensor, the indices and the output, in that
// order.
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142 addValue(op.getValues());
143 addValue(op.getIndices());
144 addValue(op.getOutput());
// tosa.scatter: records values_in, the indices, the input and values_out, in
// that order.
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150 addValue(op.getValuesIn());
151 addValue(op.getIndices());
152 addValue(op.getInput());
153 addValue(op.getValuesOut());
// tosa.mul: records only input1, input2 and the output. Any other operand of
// the op is not part of the compliance signature built here.
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159 addValue(op.getInput1());
160 addValue(op.getInput2());
161 addValue(op.getOutput());
// tosa.resize: records only the input and the output.
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167 addValue(op.getInput());
168 addValue(op.getOutput());
// tosa.fft2d: records the real and imaginary inputs, then the real and
// imaginary outputs.
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174 addValue(op.getInputReal());
175 addValue(op.getInputImag());
176 addValue(op.getOutputReal());
177 addValue(op.getOutputImag());
// tosa.rfft2d: records the real input, then the real and imaginary outputs.
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183 addValue(op.getInputReal());
184 addValue(op.getOutputReal());
185 addValue(op.getOutputImag());
// tosa.select: records input2, input3 and the output. input1 is deliberately
// not recorded; it is presumably the boolean predicate, whose type does not
// select a profile — TODO confirm.
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191 addValue(op.getInput2());
192 addValue(op.getInput3());
193 addValue(op.getOutput());
// tosa.rescale: records the input, the input zero point, the output zero
// point and the output, in that order.
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199 addValue(op.getInput());
200 addValue(op.getInputZp());
201 addValue(op.getOutputZp());
202 addValue(op.getOutput());
// tosa.matmul: the visible lines record the A zero point, the B zero point and
// the output. The lines right after the signature are not visible here; they
// presumably record the A and B operands — TODO confirm.
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
210 addValue(op.getAZp());
211 addValue(op.getBZp());
212 addValue(op.getOutput());
// tosa.variable: only part of the body is visible. It tries to dyn_cast some
// attribute `attr` (defined in lines not shown) to TypedAttr and continues
// only when that succeeds.
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
222 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
// tosa.cond_if: records the condition operand. The rest of the body is not
// visible in this view.
230 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
231 addValue(op.getCondition());
// tosa.while_loop populator. Only the signature is visible; the body is not
// shown in this view.
236 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
// Dispatches `op` to the populator matching its concrete TOSA op class, using
// three helper macros:
//   POPULATE_PROFILE_INFO_CUSTOM(X): if op is tosa::XOp, calls the
//     op-specific populateProfileInfo(cast<tosa::XOp>(op)) overload.
//   POPULATE_PROFILE_INFO_SKIP(X): tests isa<tosa::XOp>; its action is not
//     fully visible in this view.
//   POPULATE_PROFILE_INFO_COMMON(X): if op is tosa::XOp, calls the generic
//     populateProfileInfo(op->getOperands(), op->getResult(0)).
// The macro expansions and the end of the function are not visible here.
// NOTE(review): "populatationDispatch" is misspelled. Renaming it requires
// updating the declaration and all callers together.
243 LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
245 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
246 if (isa<tosa::tosaOp##Op>(op)) { \
247 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
250 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
251 if (isa<tosa::tosaOp##Op>(op)) \
255 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
256 if (isa<tosa::tosaOp##Op>(op)) { \
257 return populateProfileInfo(op->getOperands(), op->getResult(0)); \
// Looks up `opName` in the profile (or extension) compliance map for T.
// - If the name is present, it delegates to findMatchedProfile<T>, which
//   picks the entry matching op's operand types and reports its condition.
// - If the name is absent, the handling lives in lines not visible here
//   (presumably it returns failure).
// `opName` and the remaining parameters are declared in lines not shown.
// NOTE(review): `const auto complianceMap` (not `const auto &`) gives this
// function its own copy of the whole map on every call, and
// getProfileComplianceMap is declared to return by value. Returning and
// binding a const reference would avoid copying the table per operation.
352 template <
typename T>
353 FailureOr<SmallVector<T>>
354 TosaProfileCompliance::getOperatorDefinition(
Operation *op,
357 const auto complianceMap = getProfileComplianceMap<T>();
358 const auto it = complianceMap.find(opName);
359 if (it == complianceMap.end())
362 return findMatchedProfile<T>(op, it->second, condition);
// Checks op against the target's enabled profiles/extensions (T selects which
// kind). This is presumably checkProfileOrExtension; its name line is not
// visible. `specRequiredModeSet` holds the candidate mode lists the op
// declares.
// Only fragments of the body are visible. They show the following steps:
//  1. An early branch when specRequiredModeSet is empty.
//  2. A lookup of the op's type-derived requirements via
//     getOperatorDefinition<T>.
//  3. If that lookup fails, a count of all candidate modes and a
//     "... but not enabled in target" diagnostic listing them.
//  4. Diagnostics for required modes not enabled in the target.
//  5. For extensions, a diagnostic for missing cooperative profiles.
//  6. A cross-check that every required mode appears in each candidate list
//     ("... but not included in the profile compliance [...]").
// NOTE(review): `.size() == 0` reads better as `.empty()`. The
// std::find(...) == end() test can be written as !llvm::is_contained(cands, x).
// `mode_count` does not follow LLVM's CamelCase naming for locals.
365 template <
typename T>
368 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
// Early branch when no candidate modes are given (body not visible).
371 if (specRequiredModeSet.size() == 0)
// Requirements derived from the op's actual operand types.
375 const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
376 if (failed(maybeOpRequiredMode)) {
// No type-based definition was found: count every candidate mode in the spec
// sets to word the diagnostic (" any of " vs " ").
382 for (
const auto &cands : specRequiredModeSet) {
385 mode_count += cands.size();
389 << (mode_count > 1 ?
" any of " :
" ") <<
"["
390 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
392 <<
"] but not enabled in target\n";
// The lookup succeeded; opRequiredMode is the matched mode list.
399 const auto opRequiredMode = maybeOpRequiredMode.value();
// Branch for an empty requirement list (body not visible).
400 if (opRequiredMode.size() == 0) {
// Diagnostic variant worded " all of " when more than one mode is listed.
408 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
409 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
410 <<
"] but not enabled in target\n";
// Diagnostic variant worded " any of " when more than one mode is listed.
417 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
418 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
419 <<
"] but not enabled in target\n";
// Extensions only: report the cooperative profiles each required extension
// needs.
425 if constexpr (std::is_same_v<T, Extension>) {
426 for (
const auto &mode : opRequiredMode) {
427 SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
430 << llvm::join(stringifyProfile<Profile>(coProfs),
432 <<
"] to work with but not enabled in target\n";
// Cross-check: every required mode must appear in each candidate list
// declared by the op.
440 for (
const auto &cands : specRequiredModeSet) {
441 for (
size_t i = 0; i < opRequiredMode.size(); i++) {
442 if (std::find(cands.begin(), cands.end(), opRequiredMode[i]) ==
445 << llvm::join(stringifyProfile<T>(opRequiredMode),
447 <<
"] but not included in the profile compliance ["
449 stringifyProfile<T>(specRequiredModeSet),
", ")
// Fragment of the profile check. If op implements
// tosa::QueryProfileInterface, it delegates to
// checkProfileOrExtension<Profile> with the profiles the interface reports.
// The rest of the function is not visible here.
462 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
463 return checkProfileOrExtension<Profile>(op, targetEnv,
464 interface.getProfiles());
// Fragment of the extension check. If op implements
// tosa::QueryExtensionInterface, it delegates to
// checkProfileOrExtension<Extension> with the extensions the interface
// reports. The rest of the function is not visible here.
472 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
473 return checkProfileOrExtension<Extension>(op, targetEnv,
474 interface.getExtensions());
// Looks up both the profile and the extension definitions for op. The
// condition is true when both lookups succeed but both return empty mode
// lists; the body of that branch is not visible here.
// NOTE(review): both calls receive the same `condition` object. If it is an
// out-parameter (as in findMatchedProfile), the second lookup may overwrite
// the first's value. `!failed(x)` reads better as `succeeded(x)`, and
// `!x.value().size()` reads better as `x->empty()`.
481 const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
482 const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
483 if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
484 !maybeProfDef.value().size() && !maybeExtDef.value().size())
// Finds the first compliance entry in `compInfo` whose operand type signature
// matches the types collected by `depot.getInfo()`. This is presumably
// findMatchedProfile; its signature lines are not visible. On a match it
// writes that entry's condition to `condition` and returns the entry's mode
// list. The non-match paths are in lines not shown.
// NOTE(review): `sets` and each loop variable `expected` are SmallVector
// copies made on every iteration. `const auto &` for both would avoid the
// copies without changing behavior. `is_found == true` can simply be
// `is_found`.
492 template <
typename T>
496 assert(compInfo.size() != 0 &&
497 "profile-based compliance information is empty");
501 SmallVector<TypeInfo> present = depot.
getInfo();
// Early branch when no type info was collected (body not visible).
502 if (present.size() == 0)
505 for (
size_t i = 0; i < compInfo.size(); i++) {
506 SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
507 for (SmallVector<TypeInfo> expected : sets) {
// Every table row must list exactly one TypeInfo per collected entry.
508 assert(present.size() == expected.size() &&
509 "the entries for profile-based compliance do not match between "
510 "the generated metadata and the type definition retrieved from "
513 bool is_found =
true;
// Positional comparison: collected entry j against table entry j.
516 for (
size_t j = 0;
j < expected.size();
j++) {
517 if (!isSameTypeInfo(present[
j], expected[
j])) {
// The first matching row wins.
524 if (is_found ==
true) {
525 condition = compInfo[i].condition;
526 return compInfo[i].mode;
// Converts each value in `profiles` to its string name. When T is Profile it
// uses tosa::stringifyProfile; on the other path (the `else` line is not
// visible) it uses tosa::stringifyExtension. The branch is chosen at compile
// time with `if constexpr`. Names are collected in input order.
535 template <
typename T>
538 SmallVector<StringRef> debugStrings;
539 for (
const auto &profile : profiles) {
540 if constexpr (std::is_same_v<T, Profile>)
541 debugStrings.push_back(tosa::stringifyProfile(profile));
543 debugStrings.push_back(tosa::stringifyExtension(profile));
// Flattens a set of profile/extension lists into one list of names. Each
// inner list is stringified and its names are appended to debugStrings. The
// remainder of the function is not visible here.
// NOTE(review): the insert-at-end could be written as
// llvm::append_range(debugStrings, tempStrings).
548 template <
typename T>
550 const SmallVector<ArrayRef<T>> &profileSet) {
551 SmallVector<StringRef> debugStrings;
553 for (
const auto &profiles : profileSet) {
554 auto tempStrings = stringifyProfile<T>(profiles);
555 debugStrings.insert(debugStrings.end(), tempStrings.begin(),
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.