10#include "llvm/ADT/StringExtras.h"
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
81LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
93LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
94 return populateProfileInfoConv(op);
98LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
99 return populateProfileInfoConv(op);
104ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
105 return populateProfileInfoConv(op);
110ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
111 return populateProfileInfoConv(op);
115LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
123LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
130LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
131 return populateProfileInfoDataLayout(op);
135LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
136 return populateProfileInfoDataLayout(op);
140LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
141 return populateProfileInfoDataLayout(op);
145LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
146 return populateProfileInfoDataLayout(op);
150LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
158LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
167LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
175LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
182LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
190LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
209LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
215LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
220LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
222#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
223 if (isa<tosa::tosaOp##Op>(op)) { \
224 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
227#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
228 if (isa<tosa::tosaOp##Op>(op)) \
232#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
233 if (isa<tosa::tosaOp##Op>(op)) { \
234 return populateProfileInfo(op->getOperands(), op->getResults()); \
333FailureOr<OpComplianceInfo<T>>
334TosaProfileCompliance::getOperatorDefinition(
Operation *op) {
337 const auto it = complianceMap.find(opName);
338 if (it == complianceMap.end())
350 if (specRequiredModeSet.size() == 0)
353 const auto maybeOpDefinition = getOperatorDefinition<T>(op);
354 if (failed(maybeOpDefinition)) {
360 for (
const auto &cands : specRequiredModeSet) {
363 modeCount += cands.size();
367 << (modeCount > 1 ?
" any of " :
" ") <<
"["
370 <<
"] but not enabled in target\n";
377 const auto opDefinition = maybeOpDefinition.value();
381 if (opRequiredMode.size() == 0) {
389 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
391 <<
"] but not enabled in target\n";
398 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
400 <<
"] but not enabled in target\n";
406 if constexpr (std::is_same_v<T, Extension>) {
407 for (
const auto &mode : opRequiredMode) {
413 <<
"] to work with but not enabled in target\n";
421 for (
const auto &cands : specRequiredModeSet) {
422 for (
const auto &mode : opRequiredMode) {
423 if (!llvm::is_contained(cands, mode)) {
427 <<
"] but not included in the profile compliance ["
439 opDefinition.operandTypeInfoSet[0];
443 op->
emitOpError() <<
"illegal: the target specification version ("
445 <<
") is not backwards compatible with the op compliance "
446 "specification version ("
457 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
459 interface.getProfiles());
467 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
469 interface.getExtensions());
475 const auto maybeProfDef = getOperatorDefinition<Profile>(op);
476 const auto maybeExtDef = getOperatorDefinition<Extension>(op);
477 if (failed(maybeProfDef) && failed(maybeExtDef))
480 const bool hasEntry =
481 (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
482 (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
485 llvm::raw_string_ostream os(message);
486 os <<
"illegal: operation operand/result data types did not align with any "
487 "profile or extension, got (";
491 for (
const auto &typeInfo : llvm::drop_end(current))
500 const auto searchBestMatch = [&](
auto map) {
501 for (
const auto &complianceInfos : map[opName]) {
502 for (
const auto &versionedTypeInfos :
503 complianceInfos.operandTypeInfoSet) {
505 const int matches = llvm::count_if(
506 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
508 std::get<1>(zipType));
510 if (matches > maxMatches) {
511 maxMatches = matches;
512 bestTypeInfo = typeInfos;
520 os <<
", did you mean (";
521 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
524 os <<
"Otherwise, please refer to the 'supported data types' for '"
525 << opName <<
"' in the specification.";
538 assert(compInfo.size() != 0 &&
539 "profile-based compliance information is empty");
544 if (present.size() == 0)
547 for (
size_t i = 0; i < compInfo.size(); i++) {
549 for (
const auto &set : sets) {
551 assert(present.size() == expected.size() &&
552 "the entries for profile-based compliance do not match between "
553 "the generated metadata and the type definition retrieved from "
559 for (
size_t j = 0;
j < expected.size();
j++) {
567 if (isFound ==
true) {
570 compInfo[i].condition};
584 for (
const auto &profile : profiles) {
585 if constexpr (std::is_same_v<T, Profile>)
586 debugStrings.push_back(tosa::stringifyProfile(profile));
588 debugStrings.push_back(tosa::stringifyExtension(profile));
598 for (
const auto &profiles : profileSet) {
600 llvm::append_range(debugStrings, tempStrings);
608 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
609 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
611 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
613 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
615 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
617 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
619 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
621 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
623 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
625 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
627 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
629 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
632 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< Profile > getCooperativeProfiles(Extension ext)
OpComplianceInfo< T > findMatchedEntry(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.