10 #include "llvm/ADT/StringExtras.h"
// Shorthand TypeInfo constants for the element types referenced in this file.
// Each one is brace-initialised from an MLIR type's TypeID followed by a
// number (presumably the bit width; stringifyTypeInfo reads `typeID` and
// `bitWidth` members). All integer constants share IntegerType's TypeID and
// differ only in that second value.
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
36 return profileComplianceMap;
42 return extensionComplianceMap;
// Generic overload that receives the operation's operand range directly; the
// visible loop walks every entry of `operands` in order.
// NOTE(review): the embedded numbering skips 47 and 49 onward, so the rest of
// the parameter list and the loop body are not shown here.
46 LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
48 for (
auto operand : operands)
// ConcatOp overload: only the first element of the input1 list (`front()`)
// is passed to addValue, followed by the output.
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56 addValue(op.getInput1().front());
57 addValue(op.getOutput());
// AvgPool2dOp overload: passes input, input zero point and output zero point
// to addValue, the accumulator type to addType, and finally the output to
// addValue, in that order.
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63 addValue(op.getInput());
64 addValue(op.getInputZp());
65 addValue(op.getOutputZp());
66 addType(op.getAccType());
67 addValue(op.getOutput());
// Shared helper for the convolution-style ops. `T` only needs the accessors
// used below. Entries are passed in this order: input, weight, bias, input
// zero point, weight zero point (all via addValue), the accumulator type
// (via addType), then the output (via addValue).
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
// Conv2DOp overload: forwards to the shared populateProfileInfoConv helper.
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85 return populateProfileInfoConv(op);
// Conv3DOp overload: forwards to the shared populateProfileInfoConv helper.
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90 return populateProfileInfoConv(op);
// TransposeConv2DOp overload: forwards to the shared populateProfileInfoConv
// helper.
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96 return populateProfileInfoConv(op);
// DepthwiseConv2DOp overload: forwards to the shared populateProfileInfoConv
// helper.
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102 return populateProfileInfoConv(op);
// PadOp overload: passes input1, the pad constant and the output to addValue.
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107 addValue(op.getInput1());
108 addValue(op.getPadConst());
109 addValue(op.getOutput());
// Shared helper for the data-layout ops in this file (reshape, slice, tile,
// transpose all forward here): only input1 and the output are passed to
// addValue.
113 template <
typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
// ReshapeOp overload: forwards to populateProfileInfoDataLayout.
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122 return populateProfileInfoDataLayout(op);
// SliceOp overload: forwards to populateProfileInfoDataLayout.
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127 return populateProfileInfoDataLayout(op);
// TileOp overload: forwards to populateProfileInfoDataLayout.
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132 return populateProfileInfoDataLayout(op);
// TransposeOp overload: forwards to populateProfileInfoDataLayout.
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137 return populateProfileInfoDataLayout(op);
// GatherOp overload: passes values, indices and the output to addValue.
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142 addValue(op.getValues());
143 addValue(op.getIndices());
144 addValue(op.getOutput());
// ScatterOp overload: passes values_in, indices, input and values_out to
// addValue, in that order.
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150 addValue(op.getValuesIn());
151 addValue(op.getIndices());
152 addValue(op.getInput());
153 addValue(op.getValuesOut());
// MulOp overload: passes input1, input2 and the output to addValue. No other
// operand accessor is called in the visible lines.
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159 addValue(op.getInput1());
160 addValue(op.getInput2());
161 addValue(op.getOutput());
// ResizeOp overload: only the input and the output are passed to addValue.
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167 addValue(op.getInput());
168 addValue(op.getOutput());
// FFT2dOp overload: passes the real and imaginary inputs, then the real and
// imaginary outputs, to addValue.
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174 addValue(op.getInputReal());
175 addValue(op.getInputImag());
176 addValue(op.getOutputReal());
177 addValue(op.getOutputImag());
// RFFT2dOp overload: passes the single real input, then the real and
// imaginary outputs, to addValue.
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183 addValue(op.getInputReal());
184 addValue(op.getOutputReal());
185 addValue(op.getOutputImag());
// SelectOp overload: passes the on-true value, the on-false value and the
// output to addValue. No other accessor is called in the visible lines.
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191 addValue(op.getOnTrue());
192 addValue(op.getOnFalse());
193 addValue(op.getOutput());
// RescaleOp overload: passes input, input zero point, output zero point and
// the output to addValue, in that order.
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199 addValue(op.getInput());
200 addValue(op.getInputZp());
201 addValue(op.getOutputZp());
202 addValue(op.getOutput());
// MatMulOp overload: the visible lines pass the A and B zero points and the
// output to addValue.
// NOTE(review): the embedded numbering jumps from 207 to 210, so two
// statements are missing from this view (presumably the A and B operands
// themselves); confirm against the full file before editing.
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
210 addValue(op.getAZp());
211 addValue(op.getBZp());
212 addValue(op.getOutput());
// VariableOp overload: passes the op's declared type (getType()) to addType;
// no addValue call is visible for this op.
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218 addType(op.getType());
// VariableWriteOp overload: passes input1 to addValue.
223 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
224 addValue(op.getInput1());
// Routes a generic Operation* to the matching populateProfileInfo overload
// using three function-local macros, each keyed on isa<tosa::<tosaOp>Op>(op):
//  - POPULATE_PROFILE_INFO_CUSTOM casts to the concrete op type and calls the
//    typed overload.
//  - POPULATE_PROFILE_INFO_SKIP only tests the type; its action is not
//    visible in this view.
//  - POPULATE_PROFILE_INFO_COMMON calls the generic overload with all
//    operands and result 0 of the op.
// NOTE(review): the macro tails and the list of ops they are applied to are
// not shown here.
228 LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
230 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
231 if (isa<tosa::tosaOp##Op>(op)) { \
232 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
235 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
236 if (isa<tosa::tosaOp##Op>(op)) \
240 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
241 if (isa<tosa::tosaOp##Op>(op)) { \
242 return populateProfileInfo(op->getOperands(), op->getResult(0)); \
// Looks up `opName` in the compliance map for T (Profile or Extension). When
// an entry exists, the result comes from findMatchedProfile<T>, which is given
// the op, the map entry and `condition`.
// NOTE(review): how `opName` is derived and what is returned when the lookup
// misses are not visible here; the FailureOr return type suggests failure().
337 template <
typename T>
338 FailureOr<SmallVector<T>>
339 TosaProfileCompliance::getOperatorDefinition(
Operation *op,
342 const auto complianceMap = getProfileComplianceMap<T>();
343 const auto it = complianceMap.find(opName);
344 if (it == complianceMap.end())
347 return findMatchedProfile<T>(op, it->second, condition);
// Fragment of the templated check (called below as checkProfileOrExtension)
// that compares the modes an operation requires against the candidate sets in
// `specRequiredModeSet`. T is Profile or Extension. Visible steps:
//  1. an early test for an empty `specRequiredModeSet`;
//  2. a lookup of the op's required modes via getOperatorDefinition<T>; when
//     it fails, the candidates are counted and a "... but not enabled in
//     target" message listing them is streamed;
//  3. a test for an empty required-mode list, then "all of" / "any of"
//     messages built from the required modes;
//  4. for extensions only, a message listing the cooperating profiles from
//     getCooperativeProfiles;
//  5. a final check that every required mode is contained in every candidate
//     set, with a "not included in the profile compliance" message otherwise.
// NOTE(review): the conditions guarding steps 3 and 4, the stream object and
// all return statements are not visible in this view.
350 template <
typename T>
353 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
// Step 1: nothing to check against when the spec lists no candidate sets.
356 if (specRequiredModeSet.size() == 0)
// Step 2: fetch the modes this op's operand/result types require.
360 const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
361 if (
failed(maybeOpRequiredMode)) {
367 for (
const auto &cands : specRequiredModeSet) {
370 mode_count += cands.size();
374 << (mode_count > 1 ?
" any of " :
" ") <<
"["
375 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
377 <<
"] but not enabled in target\n";
// Step 3: the lookup succeeded; inspect the required modes themselves.
384 const auto opRequiredMode = maybeOpRequiredMode.value();
385 if (opRequiredMode.size() == 0) {
393 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
394 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
395 <<
"] but not enabled in target\n";
402 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
403 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
404 <<
"] but not enabled in target\n";
// Step 4: extension-only branch, compiled out when T is Profile.
410 if constexpr (std::is_same_v<T, Extension>) {
411 for (
const auto &mode : opRequiredMode) {
412 SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
415 << llvm::join(stringifyProfile<Profile>(coProfs),
417 <<
"] to work with but not enabled in target\n";
// Step 5: every required mode must appear in every candidate set.
425 for (
const auto &cands : specRequiredModeSet) {
426 for (
const auto &mode : opRequiredMode) {
427 if (!llvm::is_contained(cands, mode)) {
429 << llvm::join(stringifyProfile<T>(opRequiredMode),
431 <<
"] but not included in the profile compliance ["
433 stringifyProfile<T>(specRequiredModeSet),
", ")
// Body fragment (presumably TosaProfileCompliance::checkProfile; the
// signature is not visible). When the op implements QueryProfileInterface,
// the profile sets it reports are checked against `targetEnv` through
// checkProfileOrExtension<Profile>.
446 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
447 return checkProfileOrExtension<Profile>(op, targetEnv,
448 interface.getProfiles());
// Body fragment (presumably TosaProfileCompliance::checkExtension; the
// signature is not visible). When the op implements QueryExtensionInterface,
// the extension sets it reports are checked against `targetEnv` through
// checkProfileOrExtension<Extension>.
456 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
457 return checkProfileOrExtension<Extension>(op, targetEnv,
458 interface.getExtensions());
// Body fragment (presumably TosaProfileCompliance::checkInvalid; the
// signature is not visible). It looks up the op's profile and extension
// definitions and computes `hasEntry`, which is true when either lookup
// succeeded with a non-empty result. The remaining visible lines build a
// diagnostic: the op's actual type list, then a "did you mean" suggestion
// taken from the compliance entry with the most position-wise matches.
// NOTE(review): the condition guarding the diagnostic, the declarations of
// `depot`, `opName`, `message` and `maxMatches`, and the return statements
// are not visible here.
465 const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
466 const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
470 const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
471 (succeeded(maybeExtDef) && !maybeExtDef->empty());
474 llvm::raw_string_ostream os(message);
475 os <<
"illegal: operation operand/result data types did not align with any "
476 "profile or extension, got (";
// Print all but the last entry with a trailing comma, then the last entry
// with the closing parenthesis.
// NOTE(review): `current.back()` requires a non-empty vector; confirm that
// getInfo() cannot return an empty list on this path.
479 SmallVector<TypeInfo> current = depot.
getInfo();
480 for (
const auto &typeInfo : llvm::drop_end(current))
481 os << stringifyTypeInfo(typeInfo) <<
",";
482 os << stringifyTypeInfo(current.back()) <<
")";
// Scan every type-info set registered for this op and keep the one with the
// highest number of matching positions.
// NOTE(review): the lambda takes the map by value (`auto map`) and indexes
// it with operator[], so each call copies the whole map; confirm that this
// is intended.
488 SmallVector<TypeInfo> bestTypeInfo;
489 const auto searchBestMatch = [&](
auto map) {
490 for (
const auto &complianceInfos : map[opName]) {
491 for (
const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
492 const int matches = llvm::count_if(
493 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
494 return isSameTypeInfo(std::get<0>(zipType),
495 std::get<1>(zipType));
497 if (matches > maxMatches) {
498 maxMatches = matches;
499 bestTypeInfo = typeInfos;
504 searchBestMatch(getProfileComplianceMap<Profile>());
505 searchBestMatch(getProfileComplianceMap<Extension>());
// NOTE(review): `bestTypeInfo.back()` requires that a best match was found;
// confirm that `bestTypeInfo` cannot still be empty here.
507 os <<
", did you mean (";
508 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
509 os << stringifyTypeInfo(typeInfo) <<
",";
510 os << stringifyTypeInfo(bestTypeInfo.back()) <<
")? ";
511 os <<
"Otherwise, please refer to the 'supported data types' for '"
512 << opName <<
"' in the specification.";
// Fragment of the templated matcher (presumably findMatchedProfile; the
// signature line is not visible). It compares the op's type-info list
// (`present`) position by position against every set in each compliance
// entry's operandTypeInfoSet. On the first set that matches, it stores that
// entry's condition in `condition` and returns the entry's mode list.
// Asserts: `compInfo` is non-empty, and each expected set has the same length
// as `present`.
// NOTE(review): the handling of an empty `present`, the mismatch branch and
// the no-match return are not visible here.
522 template <
typename T>
526 assert(compInfo.size() != 0 &&
527 "profile-based compliance information is empty");
531 SmallVector<TypeInfo> present = depot.
getInfo();
532 if (present.size() == 0)
// NOTE(review): `sets` and `expected` are taken by value, so each iteration
// copies the vectors; const references would likely suffice.
535 for (
size_t i = 0; i < compInfo.size(); i++) {
536 SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
537 for (SmallVector<TypeInfo> expected : sets) {
538 assert(present.size() == expected.size() &&
539 "the entries for profile-based compliance do not match between "
540 "the generated metadata and the type definition retrieved from "
543 bool is_found =
true;
546 for (
size_t j = 0;
j < expected.size();
j++) {
547 if (!isSameTypeInfo(present[
j], expected[
j])) {
554 if (is_found ==
true) {
555 condition = compInfo[i].condition;
556 return compInfo[i].mode;
// Fragment of the stringifier for a flat list of profiles or extensions
// (presumably stringifyProfile(ArrayRef<T>); the signature is not visible).
// Each element is converted with tosa::stringifyProfile when T is Profile and
// with tosa::stringifyExtension otherwise, and appended to `debugStrings`.
565 template <
typename T>
568 SmallVector<StringRef> debugStrings;
569 for (
const auto &profile : profiles) {
570 if constexpr (std::is_same_v<T, Profile>)
571 debugStrings.push_back(tosa::stringifyProfile(profile));
573 debugStrings.push_back(tosa::stringifyExtension(profile));
// Fragment of the stringifier for a list of profile/extension sets. Each
// inner list is stringified with stringifyProfile<T> and the results are
// appended to a single flat `debugStrings` list, so set boundaries are not
// preserved in the output.
578 template <
typename T>
580 const SmallVector<ArrayRef<T>> &profileSet) {
581 SmallVector<StringRef> debugStrings;
583 for (
const auto &profiles : profileSet) {
584 auto tempStrings = stringifyProfile<T>(profiles);
585 llvm::append_range(debugStrings, tempStrings);
// Body fragment (presumably stringifyTypeInfo; the signature is not visible).
// It dispatches on `typeInfo.typeID`. Integer types render as "i" followed by
// the decimal bit width. The f16, f32, bf16, f8E4M3FN and f8E5M2 branches each
// return a value that is not visible in this view. Any other TypeID reaches
// llvm_unreachable("unknown type").
593 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
594 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
595 }
else if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
597 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
599 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
601 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
603 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
606 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.