10#include "llvm/ADT/StringExtras.h"
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
82ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
93ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
100LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
112LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
113 return populateProfileInfoConv(op);
117LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
118 return populateProfileInfoConv(op);
123ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
124 return populateProfileInfoConv(op);
129ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
130 return populateProfileInfoConv(op);
135ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
146LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
154LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
161LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
162 return populateProfileInfoDataLayout(op);
166LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
167 return populateProfileInfoDataLayout(op);
171LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
172 return populateProfileInfoDataLayout(op);
176LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
177 return populateProfileInfoDataLayout(op);
181LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
190ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
191 for (
Value value : op.getValues())
201LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
210LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
218LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
225LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
233LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
242LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
252LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
258LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
264LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
269LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
271#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
272 if (isa<tosa::tosaOp##Op>(op)) { \
273 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
276#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
277 if (isa<tosa::tosaOp##Op>(op)) \
281#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
282 if (isa<tosa::tosaOp##Op>(op)) { \
283 return populateProfileInfo(op->getOperands(), op->getResults()); \
401FailureOr<SmallVector<OpComplianceInfo<T>>>
402TosaProfileCompliance::getOperatorMatchedEntries(
Operation *op) {
405 const auto it = complianceMap.find(opName);
406 if (it == complianceMap.end())
418 if (specRequiredModeSet.size() == 0)
421 const auto maybeOpEntries = getOperatorMatchedEntries<T>(op);
422 if (failed(maybeOpEntries)) {
428 for (
const auto &cands : specRequiredModeSet) {
431 modeCount += cands.size();
435 << (modeCount > 1 ?
" any of " :
" ") <<
"["
438 <<
"] but not enabled in target\n";
443 const auto opEntries = maybeOpEntries.value();
444 if (opEntries.size() == 0) {
460 const auto isVersionCompatible =
463 info.operandTypeInfoSet.front().second};
467 for (
const auto &info : opEntries) {
470 assert(llvm::all_of(info.mode,
471 [&specRequiredModeSet](
const T &mode) {
472 return llvm::is_contained(specRequiredModeSet.front(),
475 "the profile/extension requirement of the operator should be "
476 "included in the profile compliance information");
478 if (isModeAllowed(info) && isVersionCompatible(info))
485 llvm::raw_string_ostream os(message);
488 const size_t numOpEntries = opEntries.size();
489 for (
const auto &[
index, info] : llvm::enumerate(opEntries)) {
490 bool mismatchedVersion =
false;
491 if (!isVersionCompatible(info)) {
492 mismatchedVersion =
true;
493 os <<
"requires specification version compatible with "
498 if (!isModeAllowed(info)) {
499 if (mismatchedVersion)
504 <<
"] profiles/extensions ";
507 if (
index != numOpEntries - 1)
510 os <<
"to be specified in the target environment";
518 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
520 interface.getProfiles());
528 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
530 interface.getExtensions());
536 const auto maybeProfEntries = getOperatorMatchedEntries<Profile>(op);
537 const auto maybeExtEntries = getOperatorMatchedEntries<Extension>(op);
538 if (failed(maybeProfEntries) && failed(maybeExtEntries))
541 const bool hasEntry =
542 (succeeded(maybeProfEntries) && !maybeProfEntries.value().empty()) ||
543 (succeeded(maybeExtEntries) && !maybeExtEntries.value().empty());
547 llvm::raw_string_ostream os(message);
548 os <<
"illegal: operation operand/result data types did not align with any "
549 "profile or extension, got (";
553 for (
const auto &typeInfo : llvm::drop_end(current))
562 const auto searchBestMatch = [&](
auto map) {
563 for (
const auto &complianceInfos : map[opName]) {
564 for (
const auto &versionedTypeInfos :
565 complianceInfos.operandTypeInfoSet) {
567 if (current.size() != typeInfos.size())
569 const int matches = llvm::count_if(
570 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
572 std::get<1>(zipType));
574 if (matches > maxMatches) {
575 maxMatches = matches;
576 bestTypeInfo = typeInfos;
584 os <<
", did you mean (";
585 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
588 os <<
"Otherwise, please refer to the 'supported data types' for '"
589 << opName <<
"' in the specification.";
602 assert(compInfo.size() != 0 &&
603 "profile-based compliance information is empty");
608 if (present.size() == 0)
612 for (
size_t i = 0; i < compInfo.size(); i++) {
614 for (
const auto &set : sets) {
619 if (present.size() != expected.size())
625 for (
size_t j = 0;
j < expected.size();
j++) {
633 if (isFound ==
true) {
636 compInfo[i].condition};
637 matchedInfos.push_back(info);
650 for (
const auto &profile : profiles) {
651 if constexpr (std::is_same_v<T, Profile>)
652 debugStrings.push_back(tosa::stringifyProfile(profile));
654 debugStrings.push_back(tosa::stringifyExtension(profile));
664 for (
const auto &profiles : profileSet) {
666 llvm::append_range(debugStrings, tempStrings);
674 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
675 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
677 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
679 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
681 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
683 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
685 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
687 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
689 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
691 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
693 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
695 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
698 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
SmallVector< OpComplianceInfo< T > > findMatchedEntries(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.