10#include "llvm/ADT/StringExtras.h"
// Shorthand TypeInfo constants. Each one pairs an MLIR type's TypeID with its
// bit width.
// NOTE(review): presumably these are referenced by the generated
// profile/extension compliance tables. Confirm against the generated metadata.

// Integer widths all share mlir::IntegerType's TypeID. They are distinguished
// only by bitWidth, so stringifyTypeInfo renders them as "i<bitWidth>".
// boolT is therefore a 1-bit IntegerType ("i1").
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
// Standard-width floating-point types. bf16 and fp16 are both 16 bits wide,
// so they differ only by TypeID.
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
// 8-bit float formats. Note that the fp8e4m3T name maps to the
// Float8E4M3FN MLIR type.
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// Sub-byte float formats (6- and 4-bit), the 8-bit E8M0 format, and the
// TOSA-dialect mxint8 type.
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
81LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
93LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
94 return populateProfileInfoConv(op);
98LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
99 return populateProfileInfoConv(op);
104ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
105 return populateProfileInfoConv(op);
110ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
111 return populateProfileInfoConv(op);
116ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
127LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
135LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
142LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
143 return populateProfileInfoDataLayout(op);
147LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
148 return populateProfileInfoDataLayout(op);
152LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
153 return populateProfileInfoDataLayout(op);
157LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
158 return populateProfileInfoDataLayout(op);
162LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
170LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
179LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
187LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
194LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
202LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
211LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
221LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
227LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
233LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
238LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
240#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
241 if (isa<tosa::tosaOp##Op>(op)) { \
242 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
245#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
246 if (isa<tosa::tosaOp##Op>(op)) \
250#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
251 if (isa<tosa::tosaOp##Op>(op)) { \
252 return populateProfileInfo(op->getOperands(), op->getResults()); \
367FailureOr<OpComplianceInfo<T>>
368TosaProfileCompliance::getOperatorDefinition(
Operation *op) {
371 const auto it = complianceMap.find(opName);
372 if (it == complianceMap.end())
384 if (specRequiredModeSet.size() == 0)
387 const auto maybeOpDefinition = getOperatorDefinition<T>(op);
388 if (failed(maybeOpDefinition)) {
394 for (
const auto &cands : specRequiredModeSet) {
397 modeCount += cands.size();
401 << (modeCount > 1 ?
" any of " :
" ") <<
"["
404 <<
"] but not enabled in target\n";
411 const auto opDefinition = maybeOpDefinition.value();
415 if (opRequiredMode.size() == 0) {
423 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
425 <<
"] but not enabled in target\n";
432 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
434 <<
"] but not enabled in target\n";
440 if constexpr (std::is_same_v<T, Extension>) {
441 for (
const auto &mode : opRequiredMode) {
447 <<
"] to work with but not enabled in target\n";
455 for (
const auto &cands : specRequiredModeSet) {
456 for (
const auto &mode : opRequiredMode) {
457 if (!llvm::is_contained(cands, mode)) {
461 <<
"] but not included in the profile compliance ["
473 opDefinition.operandTypeInfoSet[0];
477 op->
emitOpError() <<
"illegal: the target specification version ("
479 <<
") is not backwards compatible with the op compliance "
480 "specification version ("
491 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
493 interface.getProfiles());
501 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
503 interface.getExtensions());
509 const auto maybeProfDef = getOperatorDefinition<Profile>(op);
510 const auto maybeExtDef = getOperatorDefinition<Extension>(op);
511 if (failed(maybeProfDef) && failed(maybeExtDef))
514 const bool hasEntry =
515 (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
516 (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
519 llvm::raw_string_ostream os(message);
520 os <<
"illegal: operation operand/result data types did not align with any "
521 "profile or extension, got (";
525 for (
const auto &typeInfo : llvm::drop_end(current))
534 const auto searchBestMatch = [&](
auto map) {
535 for (
const auto &complianceInfos : map[opName]) {
536 for (
const auto &versionedTypeInfos :
537 complianceInfos.operandTypeInfoSet) {
539 const int matches = llvm::count_if(
540 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
542 std::get<1>(zipType));
544 if (matches > maxMatches) {
545 maxMatches = matches;
546 bestTypeInfo = typeInfos;
554 os <<
", did you mean (";
555 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
558 os <<
"Otherwise, please refer to the 'supported data types' for '"
559 << opName <<
"' in the specification.";
572 assert(compInfo.size() != 0 &&
573 "profile-based compliance information is empty");
578 if (present.size() == 0)
581 for (
size_t i = 0; i < compInfo.size(); i++) {
583 for (
const auto &set : sets) {
585 assert(present.size() == expected.size() &&
586 "the entries for profile-based compliance do not match between "
587 "the generated metadata and the type definition retrieved from "
593 for (
size_t j = 0;
j < expected.size();
j++) {
601 if (isFound ==
true) {
604 compInfo[i].condition};
618 for (
const auto &profile : profiles) {
619 if constexpr (std::is_same_v<T, Profile>)
620 debugStrings.push_back(tosa::stringifyProfile(profile));
622 debugStrings.push_back(tosa::stringifyExtension(profile));
632 for (
const auto &profiles : profileSet) {
634 llvm::append_range(debugStrings, tempStrings);
642 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
643 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
645 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
647 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
649 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
651 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
653 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
655 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
657 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
659 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
661 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
663 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
666 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< Profile > getCooperativeProfiles(Extension ext)
OpComplianceInfo< T > findMatchedEntry(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.