10#include "llvm/ADT/StringExtras.h"
// Named TypeInfo constants for the element types referenced by this file.
// Each entry is brace-initialised with the MLIR TypeID of a builtin (or TOSA)
// type followed by an integer. Elsewhere in this file TypeInfo is read through
// its `typeID` and `bitWidth` members, so the integer is presumably the bit
// width -- TODO confirm the member order against the TypeInfo declaration.

// Integer entries: all of them carry the same mlir::IntegerType TypeID, so
// only the second value tells them apart.
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
// Floating-point entries: each format has its own TypeID, so entries that
// share the same second value (e.g. bf16T / fp16T) still differ by TypeID.
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// Sub-byte and 8-bit float formats, plus the TOSA dialect's own mxint8 type.
// NOTE(review): these look like the element/scale types of block-scaled (MX)
// formats -- confirm against the TOSA specification before relying on that.
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
44 return profileComplianceMap;
50 return extensionComplianceMap;
54LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
56 for (
const auto &operand : operands)
58 for (
const auto &
result : results)
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
81LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
93LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
94 return populateProfileInfoConv(op);
98LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
99 return populateProfileInfoConv(op);
104ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
105 return populateProfileInfoConv(op);
110ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
111 return populateProfileInfoConv(op);
115LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
123LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
130LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
131 return populateProfileInfoDataLayout(op);
135LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
136 return populateProfileInfoDataLayout(op);
140LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
141 return populateProfileInfoDataLayout(op);
145LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
146 return populateProfileInfoDataLayout(op);
150LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
158LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
167LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
175LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
182LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
190LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
209LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
215LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
220LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
222#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
223 if (isa<tosa::tosaOp##Op>(op)) { \
224 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
227#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
228 if (isa<tosa::tosaOp##Op>(op)) \
232#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
233 if (isa<tosa::tosaOp##Op>(op)) { \
234 return populateProfileInfo(op->getOperands(), op->getResults()); \
338FailureOr<OpComplianceInfo<T>>
339TosaProfileCompliance::getOperatorDefinition(
Operation *op) {
342 const auto it = complianceMap.find(opName);
343 if (it == complianceMap.end())
355 if (specRequiredModeSet.size() == 0)
358 const auto maybeOpDefinition = getOperatorDefinition<T>(op);
359 if (failed(maybeOpDefinition)) {
365 for (
const auto &cands : specRequiredModeSet) {
368 modeCount += cands.size();
372 << (modeCount > 1 ?
" any of " :
" ") <<
"["
375 <<
"] but not enabled in target\n";
382 const auto opDefinition = maybeOpDefinition.value();
386 if (opRequiredMode.size() == 0) {
394 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
396 <<
"] but not enabled in target\n";
403 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
405 <<
"] but not enabled in target\n";
411 if constexpr (std::is_same_v<T, Extension>) {
412 for (
const auto &mode : opRequiredMode) {
418 <<
"] to work with but not enabled in target\n";
426 for (
const auto &cands : specRequiredModeSet) {
427 for (
const auto &mode : opRequiredMode) {
428 if (!llvm::is_contained(cands, mode)) {
432 <<
"] but not included in the profile compliance ["
444 opDefinition.operandTypeInfoSet[0];
448 op->
emitOpError() <<
"illegal: the target specification version ("
450 <<
") is not backwards compatible with the op compliance "
451 "specification version ("
462 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
464 interface.getProfiles());
472 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
474 interface.getExtensions());
480 const auto maybeProfDef = getOperatorDefinition<Profile>(op);
481 const auto maybeExtDef = getOperatorDefinition<Extension>(op);
482 if (failed(maybeProfDef) && failed(maybeExtDef))
485 const bool hasEntry =
486 (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
487 (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
490 llvm::raw_string_ostream os(message);
491 os <<
"illegal: operation operand/result data types did not align with any "
492 "profile or extension, got (";
496 for (
const auto &typeInfo : llvm::drop_end(current))
505 const auto searchBestMatch = [&](
auto map) {
506 for (
const auto &complianceInfos : map[opName]) {
507 for (
const auto &versionedTypeInfos :
508 complianceInfos.operandTypeInfoSet) {
510 const int matches = llvm::count_if(
511 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
513 std::get<1>(zipType));
515 if (matches > maxMatches) {
516 maxMatches = matches;
517 bestTypeInfo = typeInfos;
525 os <<
", did you mean (";
526 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
529 os <<
"Otherwise, please refer to the 'supported data types' for '"
530 << opName <<
"' in the specification.";
543 assert(compInfo.size() != 0 &&
544 "profile-based compliance information is empty");
549 if (present.size() == 0)
552 for (
size_t i = 0; i < compInfo.size(); i++) {
554 for (
const auto &set : sets) {
556 assert(present.size() == expected.size() &&
557 "the entries for profile-based compliance do not match between "
558 "the generated metadata and the type definition retrieved from "
564 for (
size_t j = 0;
j < expected.size();
j++) {
572 if (isFound ==
true) {
575 compInfo[i].condition};
589 for (
const auto &profile : profiles) {
590 if constexpr (std::is_same_v<T, Profile>)
591 debugStrings.push_back(tosa::stringifyProfile(profile));
593 debugStrings.push_back(tosa::stringifyExtension(profile));
603 for (
const auto &profiles : profileSet) {
605 llvm::append_range(debugStrings, tempStrings);
613 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
614 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
616 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
618 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
620 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
622 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
624 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
626 }
else if (typeInfo.
typeID == mlir::Float6E2M3FNType::getTypeID()) {
628 }
else if (typeInfo.
typeID == mlir::Float6E3M2FNType::getTypeID()) {
630 }
else if (typeInfo.
typeID == mlir::Float4E2M1FNType::getTypeID()) {
632 }
else if (typeInfo.
typeID == mlir::Float8E8M0FNUType::getTypeID()) {
634 }
else if (typeInfo.
typeID == tosa::mxint8Type::getTypeID()) {
637 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< Profile > getCooperativeProfiles(Extension ext)
OpComplianceInfo< T > findMatchedEntry(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents the capability enabled in the target implementation such as profile,...
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
TosaSpecificationVersion getSpecVersion() const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.