10#include "llvm/ADT/StringExtras.h"
// Named TypeInfo constants. Each one is brace-initialized from an MLIR type's
// TypeID plus an integer; the integer is presumably the bit width (TypeInfo
// values are read through `typeID` and `bitWidth` members further down in
// this file) -- TODO confirm against the TypeInfo definition.

// Integer entries: all share IntegerType's TypeID and differ only in the
// integer value (1, 4, 8, 16, 32, 48, 64).
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
// Floating-point entries: each float format has its own TypeID, so bf16T and
// fp16T remain distinguishable even though both carry the value 16.
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
// 8-bit float formats (E4M3FN and E5M2).
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// 6-bit and 4-bit float formats, followed by the 8-bit E8M0FNU format.
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
// The only entry whose type comes from the tosa namespace rather than the
// builtin mlir types.
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
81LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
93LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
94 return populateProfileInfoConv(op);
98LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
99 return populateProfileInfoConv(op);
104ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
105 return populateProfileInfoConv(op);
110ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
111 return populateProfileInfoConv(op);
116ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
127LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
135LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
142LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
143 return populateProfileInfoDataLayout(op);
147LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
148 return populateProfileInfoDataLayout(op);
152LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
153 return populateProfileInfoDataLayout(op);
157LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
158 return populateProfileInfoDataLayout(op);
162LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
170LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
179LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
187LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
194LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
202LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
211LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
221LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
227LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
233LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
238LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
240#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
241 if (isa<tosa::tosaOp##Op>(op)) { \
242 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
245#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
246 if (isa<tosa::tosaOp##Op>(op)) \
250#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
251 if (isa<tosa::tosaOp##Op>(op)) { \
252 return populateProfileInfo(op->getOperands(), op->getResults()); \
367FailureOr<SmallVector<OpComplianceInfo<T>>>
368TosaProfileCompliance::getOperatorMatchedEntries(
Operation *op) {
371 const auto it = complianceMap.find(opName);
372 if (it == complianceMap.end())
384 if (specRequiredModeSet.size() == 0)
387 const auto maybeOpEntries = getOperatorMatchedEntries<T>(op);
388 if (failed(maybeOpEntries)) {
394 for (
const auto &cands : specRequiredModeSet) {
397 modeCount += cands.size();
401 << (modeCount > 1 ?
" any of " :
" ") <<
"["
404 <<
"] but not enabled in target\n";
409 const auto opEntries = maybeOpEntries.value();
410 if (opEntries.size() == 0) {
426 const auto isVersionCompatible =
429 info.operandTypeInfoSet.front().second};
433 for (
const auto &info : opEntries) {
436 assert(llvm::all_of(info.mode,
437 [&specRequiredModeSet](
const T &mode) {
438 return llvm::is_contained(specRequiredModeSet.front(),
441 "the profile/extension requirement of the operator should be "
442 "included in the profile compliance information");
444 if (isModeAllowed(info) && isVersionCompatible(info))
451 llvm::raw_string_ostream os(message);
454 const size_t numOpEntries = opEntries.size();
455 for (
const auto &[
index, info] : llvm::enumerate(opEntries)) {
456 bool mismatchedVersion =
false;
457 if (!isVersionCompatible(info)) {
458 mismatchedVersion =
true;
459 os <<
"requires specification version compatible with "
464 if (!isModeAllowed(info)) {
465 if (mismatchedVersion)
470 <<
"] profiles/extensions ";
473 if (
index != numOpEntries - 1)
476 os <<
"to be specified in the target environment";
484 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
486 interface.getProfiles());
494 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
496 interface.getExtensions());
502 const auto maybeProfEntries = getOperatorMatchedEntries<Profile>(op);
503 const auto maybeExtEntries = getOperatorMatchedEntries<Extension>(op);
504 if (failed(maybeProfEntries) && failed(maybeExtEntries))
507 const bool hasEntry =
508 (succeeded(maybeProfEntries) && !maybeProfEntries.value().empty()) ||
509 (succeeded(maybeExtEntries) && !maybeExtEntries.value().empty());
513 llvm::raw_string_ostream os(message);
514 os <<
"illegal: operation operand/result data types did not align with any "
515 "profile or extension, got (";
519 for (
const auto &typeInfo : llvm::drop_end(current))
528 const auto searchBestMatch = [&](
auto map) {
529 for (
const auto &complianceInfos : map[opName]) {
530 for (
const auto &versionedTypeInfos :
531 complianceInfos.operandTypeInfoSet) {
533 const int matches = llvm::count_if(
534 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
536 std::get<1>(zipType));
538 if (matches > maxMatches) {
539 maxMatches = matches;
540 bestTypeInfo = typeInfos;
548 os <<
", did you mean (";
549 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
552 os <<
"Otherwise, please refer to the 'supported data types' for '"
553 << opName <<
"' in the specification.";
566 assert(compInfo.size() != 0 &&
567 "profile-based compliance information is empty");
572 if (present.size() == 0)
576 for (
size_t i = 0; i < compInfo.size(); i++) {
578 for (
const auto &set : sets) {
580 assert(present.size() == expected.size() &&
581 "the entries for profile-based compliance do not match between "
582 "the generated metadata and the type definition retrieved from "
588 for (
size_t j = 0;
j < expected.size();
j++) {
596 if (isFound ==
true) {
599 compInfo[i].condition};
600 matchedInfos.push_back(info);
613 for (
const auto &profile : profiles) {
614 if constexpr (std::is_same_v<T, Profile>)
615 debugStrings.push_back(tosa::stringifyProfile(profile));
617 debugStrings.push_back(tosa::stringifyExtension(profile));
627 for (
const auto &profiles : profileSet) {
629 llvm::append_range(debugStrings, tempStrings);
637 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
638 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
640 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
642 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
644 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
646 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
648 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
650 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
652 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
654 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
656 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
658 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
661 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
SmallVector< OpComplianceInfo< T > > findMatchedEntries(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.