10#include "llvm/ADT/StringExtras.h"
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
82ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
93ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
100LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
112LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
113 return populateProfileInfoConv(op);
117LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
118 return populateProfileInfoConv(op);
123ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
124 return populateProfileInfoConv(op);
129ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
130 return populateProfileInfoConv(op);
135ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
146LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
154LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
161LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
162 return populateProfileInfoDataLayout(op);
166LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
167 return populateProfileInfoDataLayout(op);
171LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
172 return populateProfileInfoDataLayout(op);
176LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
177 return populateProfileInfoDataLayout(op);
181LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
189LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
198LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
206LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
213LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
221LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
230LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
240LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
246LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
252LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
257LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
259#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
260 if (isa<tosa::tosaOp##Op>(op)) { \
261 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
264#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
265 if (isa<tosa::tosaOp##Op>(op)) \
269#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
270 if (isa<tosa::tosaOp##Op>(op)) { \
271 return populateProfileInfo(op->getOperands(), op->getResults()); \
388FailureOr<SmallVector<OpComplianceInfo<T>>>
389TosaProfileCompliance::getOperatorMatchedEntries(
Operation *op) {
392 const auto it = complianceMap.find(opName);
393 if (it == complianceMap.end())
405 if (specRequiredModeSet.size() == 0)
408 const auto maybeOpEntries = getOperatorMatchedEntries<T>(op);
409 if (failed(maybeOpEntries)) {
415 for (
const auto &cands : specRequiredModeSet) {
418 modeCount += cands.size();
422 << (modeCount > 1 ?
" any of " :
" ") <<
"["
425 <<
"] but not enabled in target\n";
430 const auto opEntries = maybeOpEntries.value();
431 if (opEntries.size() == 0) {
447 const auto isVersionCompatible =
450 info.operandTypeInfoSet.front().second};
454 for (
const auto &info : opEntries) {
457 assert(llvm::all_of(info.mode,
458 [&specRequiredModeSet](
const T &mode) {
459 return llvm::is_contained(specRequiredModeSet.front(),
462 "the profile/extension requirement of the operator should be "
463 "included in the profile compliance information");
465 if (isModeAllowed(info) && isVersionCompatible(info))
472 llvm::raw_string_ostream os(message);
475 const size_t numOpEntries = opEntries.size();
476 for (
const auto &[
index, info] : llvm::enumerate(opEntries)) {
477 bool mismatchedVersion =
false;
478 if (!isVersionCompatible(info)) {
479 mismatchedVersion =
true;
480 os <<
"requires specification version compatible with "
485 if (!isModeAllowed(info)) {
486 if (mismatchedVersion)
491 <<
"] profiles/extensions ";
494 if (
index != numOpEntries - 1)
497 os <<
"to be specified in the target environment";
505 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
507 interface.getProfiles());
515 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
517 interface.getExtensions());
523 const auto maybeProfEntries = getOperatorMatchedEntries<Profile>(op);
524 const auto maybeExtEntries = getOperatorMatchedEntries<Extension>(op);
525 if (failed(maybeProfEntries) && failed(maybeExtEntries))
528 const bool hasEntry =
529 (succeeded(maybeProfEntries) && !maybeProfEntries.value().empty()) ||
530 (succeeded(maybeExtEntries) && !maybeExtEntries.value().empty());
534 llvm::raw_string_ostream os(message);
535 os <<
"illegal: operation operand/result data types did not align with any "
536 "profile or extension, got (";
540 for (
const auto &typeInfo : llvm::drop_end(current))
549 const auto searchBestMatch = [&](
auto map) {
550 for (
const auto &complianceInfos : map[opName]) {
551 for (
const auto &versionedTypeInfos :
552 complianceInfos.operandTypeInfoSet) {
554 const int matches = llvm::count_if(
555 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
557 std::get<1>(zipType));
559 if (matches > maxMatches) {
560 maxMatches = matches;
561 bestTypeInfo = typeInfos;
569 os <<
", did you mean (";
570 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
573 os <<
"Otherwise, please refer to the 'supported data types' for '"
574 << opName <<
"' in the specification.";
587 assert(compInfo.size() != 0 &&
588 "profile-based compliance information is empty");
593 if (present.size() == 0)
597 for (
size_t i = 0; i < compInfo.size(); i++) {
599 for (
const auto &set : sets) {
601 assert(present.size() == expected.size() &&
602 "the entries for profile-based compliance do not match between "
603 "the generated metadata and the type definition retrieved from "
609 for (
size_t j = 0;
j < expected.size();
j++) {
617 if (isFound ==
true) {
620 compInfo[i].condition};
621 matchedInfos.push_back(info);
634 for (
const auto &profile : profiles) {
635 if constexpr (std::is_same_v<T, Profile>)
636 debugStrings.push_back(tosa::stringifyProfile(profile));
638 debugStrings.push_back(tosa::stringifyExtension(profile));
648 for (
const auto &profiles : profileSet) {
650 llvm::append_range(debugStrings, tempStrings);
658 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
659 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
661 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
663 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
665 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
667 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
669 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
671 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
673 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
675 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
677 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
679 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
682 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
SmallVector< OpComplianceInfo< T > > findMatchedEntries(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.