10#include "llvm/ADT/StringExtras.h"
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
82ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
93ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
100LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
112LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
113 return populateProfileInfoConv(op);
117LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
118 return populateProfileInfoConv(op);
123ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
124 return populateProfileInfoConv(op);
129ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
130 return populateProfileInfoConv(op);
135ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
146LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
154LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
161LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
162 return populateProfileInfoDataLayout(op);
166LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
167 return populateProfileInfoDataLayout(op);
171LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
172 return populateProfileInfoDataLayout(op);
176LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
177 return populateProfileInfoDataLayout(op);
181LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
189LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RowGatherOp op) {
198ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
199 for (
Value value : op.getValues())
209LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
218LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
226LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
233LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
241LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
250LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
260LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
266LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
272LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
277LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
279#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
280 if (isa<tosa::tosaOp##Op>(op)) { \
281 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
284#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
285 if (isa<tosa::tosaOp##Op>(op)) \
289#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
290 if (isa<tosa::tosaOp##Op>(op)) { \
291 return populateProfileInfo(op->getOperands(), op->getResults()); \
410FailureOr<SmallVector<OpComplianceInfo<T>>>
411TosaProfileCompliance::getOperatorMatchedEntries(
Operation *op) {
414 const auto it = complianceMap.find(opName);
415 if (it == complianceMap.end())
427 if (specRequiredModeSet.size() == 0)
430 const auto maybeOpEntries = getOperatorMatchedEntries<T>(op);
431 if (failed(maybeOpEntries)) {
437 for (
const auto &cands : specRequiredModeSet) {
440 modeCount += cands.size();
444 << (modeCount > 1 ?
" any of " :
" ") <<
"["
447 <<
"] but not enabled in target\n";
452 const auto opEntries = maybeOpEntries.value();
453 if (opEntries.size() == 0) {
469 const auto isVersionCompatible =
472 info.operandTypeInfoSet.front().second};
476 for (
const auto &info : opEntries) {
479 assert(llvm::all_of(info.mode,
480 [&specRequiredModeSet](
const T &mode) {
481 return llvm::is_contained(specRequiredModeSet.front(),
484 "the profile/extension requirement of the operator should be "
485 "included in the profile compliance information");
487 if (isModeAllowed(info) && isVersionCompatible(info))
494 llvm::raw_string_ostream os(message);
497 const size_t numOpEntries = opEntries.size();
498 for (
const auto &[
index, info] : llvm::enumerate(opEntries)) {
499 bool mismatchedVersion =
false;
500 if (!isVersionCompatible(info)) {
501 mismatchedVersion =
true;
502 os <<
"requires specification version compatible with "
507 if (!isModeAllowed(info)) {
508 if (mismatchedVersion)
513 <<
"] profiles/extensions ";
516 if (
index != numOpEntries - 1)
519 os <<
"to be specified in the target environment";
527 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
529 interface.getProfiles());
537 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
539 interface.getExtensions());
545 const auto maybeProfEntries = getOperatorMatchedEntries<Profile>(op);
546 const auto maybeExtEntries = getOperatorMatchedEntries<Extension>(op);
547 if (failed(maybeProfEntries) && failed(maybeExtEntries))
550 const bool hasEntry =
551 (succeeded(maybeProfEntries) && !maybeProfEntries.value().empty()) ||
552 (succeeded(maybeExtEntries) && !maybeExtEntries.value().empty());
556 llvm::raw_string_ostream os(message);
557 os <<
"illegal: operation operand/result data types did not align with any "
558 "profile or extension, got (";
562 for (
const auto &typeInfo : llvm::drop_end(current))
571 const auto searchBestMatch = [&](
auto map) {
572 for (
const auto &complianceInfos : map[opName]) {
573 for (
const auto &versionedTypeInfos :
574 complianceInfos.operandTypeInfoSet) {
576 if (current.size() != typeInfos.size())
578 const int matches = llvm::count_if(
579 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
581 std::get<1>(zipType));
583 if (matches > maxMatches) {
584 maxMatches = matches;
585 bestTypeInfo = typeInfos;
593 os <<
", did you mean (";
594 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
597 os <<
"Otherwise, please refer to the 'supported data types' for '"
598 << opName <<
"' in the specification.";
611 assert(compInfo.size() != 0 &&
612 "profile-based compliance information is empty");
617 if (present.size() == 0)
621 for (
size_t i = 0; i < compInfo.size(); i++) {
623 for (
const auto &set : sets) {
628 if (present.size() != expected.size())
634 for (
size_t j = 0;
j < expected.size();
j++) {
642 if (isFound ==
true) {
645 compInfo[i].condition};
646 matchedInfos.push_back(info);
659 for (
const auto &profile : profiles) {
660 if constexpr (std::is_same_v<T, Profile>)
661 debugStrings.push_back(tosa::stringifyProfile(profile));
663 debugStrings.push_back(tosa::stringifyExtension(profile));
673 for (
const auto &profiles : profileSet) {
675 llvm::append_range(debugStrings, tempStrings);
683 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
684 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
686 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
688 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
690 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
692 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
694 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
696 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
698 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
700 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
702 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
704 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
707 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
SmallVector< OpComplianceInfo< T > > findMatchedEntries(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.