10 #include "llvm/ADT/StringExtras.h"
// Shorthand TypeInfo descriptors used by the compliance tables. Each one is
// initialized from an MLIR type's TypeID and a bit width (e.g. the integer
// entries differ only in width: 1, 4, 8, 16, 32, 48).
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
// The two 8-bit float flavours share a width and are told apart by TypeID only.
25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
36 return profileComplianceMap;
42 return extensionComplianceMap;
// Generic overload: takes the operation's operands as a ValueRange and walks
// over each of them.
// NOTE(review): the remaining parameter(s) and the loop body are not visible
// in this excerpt; the dispatch macro calls this overload with
// (op->getOperands(), op->getResult(0)).
46 LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
48 for (
auto operand : operands)
// tosa.concat: registers the first element of the input1 list and the output.
// Only the front entry of input1 is recorded here.
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56 addValue(op.getInput1().front());
57 addValue(op.getOutput());
// tosa.avg_pool2d: registers, in this order, the input, the input and output
// zero points, the accumulator type and the output.
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63 addValue(op.getInput());
64 addValue(op.getInputZp());
65 addValue(op.getOutputZp());
// The accumulator is an attribute-provided type, hence addType not addValue.
66 addType(op.getAccType());
67 addValue(op.getOutput());
// Shared helper for the convolution-style ops (T is the concrete op class).
// Registers, in this order: input, weight, bias, input zero point, weight
// zero point, accumulator type and output.
// NOTE(review): the `template <typename T>` header line is not visible in this
// excerpt.
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
// tosa.conv2d: forwards to the shared convolution helper.
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85 return populateProfileInfoConv(op);
// tosa.conv3d: forwards to the shared convolution helper.
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90 return populateProfileInfoConv(op);
// tosa.transpose_conv2d: forwards to the shared convolution helper.
// NOTE(review): the return-type line preceding this one is not visible here.
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96 return populateProfileInfoConv(op);
// tosa.depthwise_conv2d: forwards to the shared convolution helper.
// NOTE(review): the return-type line preceding this one is not visible here.
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102 return populateProfileInfoConv(op);
// tosa.pad: registers input1, the pad constant and the output.
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107 addValue(op.getInput1());
108 addValue(op.getPadConst());
109 addValue(op.getOutput());
// Shared helper for the data-layout ops (reshape, slice, tile, transpose all
// call it): registers only input1 and the output of the op.
113 template <
typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
// tosa.reshape: forwards to the data-layout helper.
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122 return populateProfileInfoDataLayout(op);
// tosa.slice: forwards to the data-layout helper.
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127 return populateProfileInfoDataLayout(op);
// tosa.tile: forwards to the data-layout helper.
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132 return populateProfileInfoDataLayout(op);
// tosa.transpose: forwards to the data-layout helper.
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137 return populateProfileInfoDataLayout(op);
// tosa.gather: registers values, indices and the output.
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142 addValue(op.getValues());
143 addValue(op.getIndices());
144 addValue(op.getOutput());
// tosa.scatter: registers values_in, indices, input and values_out.
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150 addValue(op.getValuesIn());
151 addValue(op.getIndices());
152 addValue(op.getInput());
153 addValue(op.getValuesOut());
// tosa.mul: registers input1, input2 and the output.
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159 addValue(op.getInput1());
160 addValue(op.getInput2());
161 addValue(op.getOutput());
// tosa.resize: registers only the input and the output.
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167 addValue(op.getInput());
168 addValue(op.getOutput());
// tosa.fft2d: registers the real/imaginary inputs followed by the
// real/imaginary outputs.
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174 addValue(op.getInputReal());
175 addValue(op.getInputImag());
176 addValue(op.getOutputReal());
177 addValue(op.getOutputImag());
// tosa.rfft2d: registers the real input followed by the real/imaginary
// outputs.
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183 addValue(op.getInputReal());
184 addValue(op.getOutputReal());
185 addValue(op.getOutputImag());
// tosa.select: registers input2, input3 and the output. No other operand is
// registered by the visible lines.
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191 addValue(op.getInput2());
192 addValue(op.getInput3());
193 addValue(op.getOutput());
// tosa.rescale: registers the input, the input and output zero points, and
// the output.
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199 addValue(op.getInput());
200 addValue(op.getInputZp());
201 addValue(op.getOutputZp());
202 addValue(op.getOutput());
// tosa.matmul: the visible lines register the A and B zero points and the
// output.
// NOTE(review): two source lines between the signature and the first visible
// addValue call are missing from this excerpt.
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
210 addValue(op.getAZp());
211 addValue(op.getBZp());
212 addValue(op.getOutput());
// tosa.variable: registers the type returned by op.getType() (a type, not an
// SSA value, hence addType).
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218 addType(op.getType());
// tosa.variable_write: registers only input1.
223 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
224 addValue(op.getInput1());
// tosa.cond_if: registers only the condition value.
229 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
230 addValue(op.getCondition());
// tosa.while_loop overload.
// NOTE(review): the body of this overload is not visible in this excerpt.
235 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
// Selects how profile information is collected for `op`, based on its concrete
// TOSA op class. Three local macros describe the strategies:
//   * POPULATE_PROFILE_INFO_CUSTOM: if `op` is a tosa::<Name>Op, cast it and
//     call the op-specific populateProfileInfo overload.
//   * POPULATE_PROFILE_INFO_SKIP: matches tosa::<Name>Op; the action taken is
//     not visible in this excerpt.
//   * POPULATE_PROFILE_INFO_COMMON: if `op` is a tosa::<Name>Op, call the
//     generic overload with all operands and the first result.
// NOTE(review): "populatationDispatch" is misspelled, but it is the name
// callers use, so it must be kept unless every use is renamed together.
242 LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
244 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
245 if (isa<tosa::tosaOp##Op>(op)) { \
246 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
249 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
250 if (isa<tosa::tosaOp##Op>(op)) \
254 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
255 if (isa<tosa::tosaOp##Op>(op)) { \
256 return populateProfileInfo(op->getOperands(), op->getResult(0)); \
// Looks up the compliance entry for the operation's name in the compliance map
// for T and, when an entry exists, returns whatever findMatchedProfile<T>
// computes from that entry (which also receives `condition`).
// NOTE(review): how `opName` is derived and what is returned when the lookup
// fails are on lines not visible in this excerpt.
351 template <
typename T>
352 FailureOr<SmallVector<T>>
353 TosaProfileCompliance::getOperatorDefinition(
Operation *op,
// NOTE(review): `const auto` holds the map by value; getProfileComplianceMap
// appears to return by value as well, so the whole table is copied on every
// call. Confirm, and consider returning and binding a const reference.
356 const auto complianceMap = getProfileComplianceMap<T>();
357 const auto it = complianceMap.find(opName);
358 if (it == complianceMap.end())
361 return findMatchedProfile<T>(op, it->second, condition);
// Checks the modes (profiles or extensions, depending on T) an operation
// requires against `specRequiredModeSet`, emitting diagnostics that end in
// "but not enabled in target" or "but not included in the profile compliance".
// NOTE(review): the signature lines naming this function are missing from this
// excerpt; the name comes from the call sites. Most branch conditions and
// return statements are also missing, so the comments below describe only what
// the visible lines show.
364 template <
typename T>
367 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
// Guard for an empty required-mode set (the action taken is not visible).
370 if (specRequiredModeSet.size() == 0)
// Resolve the modes required by this op from the compliance tables.
374 const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
375 if (failed(maybeOpRequiredMode)) {
// Count every mode across all candidate lists; the count only decides
// whether the diagnostic says " any of ".
381 for (
const auto &cands : specRequiredModeSet) {
384 mode_count += cands.size();
388 << (mode_count > 1 ?
" any of " :
" ") <<
"["
389 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
391 <<
"] but not enabled in target\n";
398 const auto opRequiredMode = maybeOpRequiredMode.value();
399 if (opRequiredMode.size() == 0) {
// Diagnostic used when all of the required modes must be enabled.
407 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
408 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
409 <<
"] but not enabled in target\n";
// Diagnostic used when any one of the required modes would suffice.
416 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
417 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
418 <<
"] but not enabled in target\n";
// Extension-only step: each required extension has cooperative profiles that
// are looked up and reported alongside it.
424 if constexpr (std::is_same_v<T, Extension>) {
425 for (
const auto &mode : opRequiredMode) {
426 SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
429 << llvm::join(stringifyProfile<Profile>(coProfs),
431 <<
"] to work with but not enabled in target\n";
// Cross-check the table-derived modes against the spec-declared candidates.
// NOTE(review): the containment test runs once per candidate list, so a mode
// present in one list but absent from another still reaches the error branch.
// Confirm whether "contained in any candidate list" was intended.
439 for (
const auto &cands : specRequiredModeSet) {
440 for (
const auto &mode : opRequiredMode) {
441 if (!llvm::is_contained(cands, mode)) {
443 << llvm::join(stringifyProfile<T>(opRequiredMode),
445 <<
"] but not included in the profile compliance ["
447 stringifyProfile<T>(specRequiredModeSet),
", ")
// If the op implements tosa::QueryProfileInterface, validate the profiles it
// declares against the target environment.
// NOTE(review): the enclosing signature and the fallback for ops without the
// interface are not visible in this excerpt.
460 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
461 return checkProfileOrExtension<Profile>(op, targetEnv,
462 interface.getProfiles());
// If the op implements tosa::QueryExtensionInterface, validate the extensions
// it declares against the target environment.
// NOTE(review): the enclosing signature and the fallback for ops without the
// interface are not visible in this excerpt.
470 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
471 return checkExtension<Extension>(op, targetEnv,
472 interface.getExtensions());
// Builds the "illegal: operation operand/result data types did not align with
// any profile or extension" diagnostic, including a "did you mean (...)" hint
// taken from the closest matching type-info set.
// NOTE(review): the enclosing signature is not visible in this excerpt, and
// several interior lines (declarations of `message`, `depot`, `maxMatches`,
// `opName`, the final emit/return) are missing.
479 const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
480 const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
// Only report when both lookups succeeded yet neither produced any mode.
482 if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
483 !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
485 llvm::raw_string_ostream os(message);
486 os <<
"illegal: operation operand/result data types did not align with any "
487 "profile or extension, got (";
// Print the type infos actually present on the op, comma separated.
// NOTE(review): `current.back()` requires `current` to be non-empty; no guard
// is visible here. Confirm one exists on the missing lines.
490 SmallVector<TypeInfo> current = depot.
getInfo();
491 for (
const auto &typeInfo : llvm::drop_end(current))
492 os << stringifyTypeInfo(typeInfo) <<
",";
493 os << stringifyTypeInfo(current.back()) <<
")";
// Search both compliance maps for the type-info set sharing the most
// position-wise matches with `current`.
// NOTE(review): the lambda takes `map` by value, so each call copies the whole
// table (and `map[opName]` may default-insert into that copy).
499 SmallVector<TypeInfo> bestTypeInfo;
500 const auto searchBestMatch = [&](
auto map) {
501 for (
const auto &complianceInfos : map[opName]) {
502 for (
const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
503 const int matches = llvm::count_if(
504 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
505 return isSameTypeInfo(std::get<0>(zipType),
506 std::get<1>(zipType));
508 if (matches > maxMatches) {
509 maxMatches = matches;
510 bestTypeInfo = typeInfos;
515 searchBestMatch(getProfileComplianceMap<Profile>());
516 searchBestMatch(getProfileComplianceMap<Extension>());
// NOTE(review): `bestTypeInfo.back()` is invalid if no candidate ever beat the
// initial `maxMatches`; the initial value is on a line not visible here.
// Confirm that `bestTypeInfo` cannot be empty at this point.
518 os <<
", did you mean (";
519 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
520 os << stringifyTypeInfo(typeInfo) <<
",";
521 os << stringifyTypeInfo(bestTypeInfo.back()) <<
")? ";
522 os <<
"Otherwise, please refer to the 'supported data types' for '"
523 << opName <<
"' in the specification.";
// Scans the compliance entries in `compInfo` for a type-info set that matches,
// element by element, the type infos collected from the op. On the first
// match it stores that entry's condition into `condition` and returns the
// entry's mode list.
// NOTE(review): the signature lines are missing from this excerpt (the name
// comes from the call site), as are the results for the "no present info" and
// "nothing matched" paths.
533 template <
typename T>
537 assert(compInfo.size() != 0 &&
538 "profile-based compliance information is empty");
542 SmallVector<TypeInfo> present = depot.
getInfo();
543 if (present.size() == 0)
546 for (
size_t i = 0; i < compInfo.size(); i++) {
// NOTE(review): `sets` and `expected` are taken by value, copying every
// type-info set on each iteration; const references would avoid that.
547 SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
548 for (SmallVector<TypeInfo> expected : sets) {
549 assert(present.size() == expected.size() &&
550 "the entries for profile-based compliance do not match between "
551 "the generated metadata and the type definition retrieved from "
// Assume a match until some position disagrees.
554 bool is_found =
true;
557 for (
size_t j = 0;
j < expected.size();
j++) {
558 if (!isSameTypeInfo(present[
j], expected[
j])) {
565 if (is_found ==
true) {
566 condition = compInfo[i].condition;
567 return compInfo[i].mode;
// Converts each element of `profiles` to its string name: via
// tosa::stringifyProfile when T is Profile, otherwise via
// tosa::stringifyExtension.
// NOTE(review): the signature line and the final return are not visible in
// this excerpt; the name comes from the call sites.
576 template <
typename T>
579 SmallVector<StringRef> debugStrings;
580 for (
const auto &profile : profiles) {
581 if constexpr (std::is_same_v<T, Profile>)
582 debugStrings.push_back(tosa::stringifyProfile(profile));
584 debugStrings.push_back(tosa::stringifyExtension(profile));
// Overload for a set of mode lists: stringifies every inner list with
// stringifyProfile<T> and appends the results into one flat list, preserving
// order.
// NOTE(review): the line naming this function and the final return are not
// visible in this excerpt.
589 template <
typename T>
591 const SmallVector<ArrayRef<T>> &profileSet) {
592 SmallVector<StringRef> debugStrings;
594 for (
const auto &profiles : profileSet) {
595 auto tempStrings = stringifyProfile<T>(profiles);
596 llvm::append_range(debugStrings, tempStrings);
// Maps a TypeInfo to a short printable name by comparing its typeID against
// the known MLIR types. Integer types print as "i" followed by the bit width;
// any typeID not handled by the chain reaches llvm_unreachable.
// NOTE(review): the signature line and the strings returned by the
// floating-point branches are not visible in this excerpt.
604 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
605 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
606 }
else if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
608 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
610 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
612 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
614 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
617 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.