10 #include "llvm/ADT/StringExtras.h"
// Reference TypeInfo values, each pairing an MLIR type's TypeID with a bit
// width. Every integer width (including the 1-bit boolean) shares
// mlir::IntegerType's TypeID and is distinguished only by bitWidth; each
// floating-point flavour has its own TypeID.
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// Return statement of the compliance-map accessor for Profile entries
// (presumably the getProfileComplianceMap<Profile> specialization; its
// signature is not part of this listing).
36 return profileComplianceMap;
// Return statement of the matching accessor for Extension entries.
42 return extensionComplianceMap;
/// Generic population path: iterates over `operands` and records each one in
/// the depot. POPULATE_PROFILE_INFO_COMMON calls this with op->getOperands()
/// and op->getResult(0), so the op's first result is presumably recorded too
/// (the loop body and result handling are not visible in this listing).
46 LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
48 for (
auto operand : operands)
/// tosa.concat: records only the first input plus the output. Presumably all
/// inputs share one element type, so one representative is enough — confirm.
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56 addValue(op.getInput1().front());
57 addValue(op.getOutput());
/// tosa.avg_pool2d: records, in this order, the input, the input zero-point,
/// the output zero-point, the accumulator type (a Type rather than a Value,
/// hence addType) and the output.
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63 addValue(op.getInput());
64 addValue(op.getInputZp());
65 addValue(op.getOutputZp());
66 addType(op.getAccType());
67 addValue(op.getOutput());
/// Shared population for the convolution family; the Conv2D, Conv3D,
/// TransposeConv2D and DepthwiseConv2D overloads forward here. Records, in
/// order: input, weight, bias, input zero-point, weight zero-point,
/// accumulator type and output.
/// NOTE(review): findMatchedEntry compares the recorded list against the
/// compliance entries position by position, so this order presumably has to
/// match the order used by the generated compliance metadata.
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
// Each convolution-style op forwards to the shared populateProfileInfoConv
// template, which records the same operand sequence for all of them.
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85 return populateProfileInfoConv(op);
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90 return populateProfileInfoConv(op);
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96 return populateProfileInfoConv(op);
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102 return populateProfileInfoConv(op);
/// tosa.pad: records the input, the pad constant and the output.
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107 addValue(op.getInput1());
108 addValue(op.getPadConst());
109 addValue(op.getOutput());
/// Shared population for data-layout ops: only input1 and the output are
/// recorded; any other operands of these ops are not part of the type
/// signature checked here. Reshape, Slice, Tile and Transpose forward to it.
113 template <
typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122 return populateProfileInfoDataLayout(op);
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127 return populateProfileInfoDataLayout(op);
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132 return populateProfileInfoDataLayout(op);
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137 return populateProfileInfoDataLayout(op);
/// tosa.gather: records values, indices and output, in that order.
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142 addValue(op.getValues());
143 addValue(op.getIndices());
144 addValue(op.getOutput());
/// tosa.scatter: records values_in, indices, input and values_out, in that
/// order.
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150 addValue(op.getValuesIn());
151 addValue(op.getIndices());
152 addValue(op.getInput());
153 addValue(op.getValuesOut());
/// tosa.mul: records both inputs and the output; no other operand is recorded
/// in the visible lines.
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159 addValue(op.getInput1());
160 addValue(op.getInput2());
161 addValue(op.getOutput());
/// tosa.resize: only the input and the output are recorded.
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167 addValue(op.getInput());
168 addValue(op.getOutput());
/// tosa.fft2d: records real and imaginary inputs, then real and imaginary
/// outputs.
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174 addValue(op.getInputReal());
175 addValue(op.getInputImag());
176 addValue(op.getOutputReal());
177 addValue(op.getOutputImag());
/// tosa.rfft2d: records the real input, then real and imaginary outputs.
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183 addValue(op.getInputReal());
184 addValue(op.getOutputReal());
185 addValue(op.getOutputImag());
/// tosa.select: records the on-true and on-false operands and the output; the
/// condition operand is not recorded among the visible lines.
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191 addValue(op.getOnTrue());
192 addValue(op.getOnFalse());
193 addValue(op.getOutput());
/// tosa.rescale: records input, input zero-point, output zero-point and
/// output, in that order.
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199 addValue(op.getInput());
200 addValue(op.getInputZp());
201 addValue(op.getOutputZp());
202 addValue(op.getOutput());
/// tosa.matmul: the visible lines record the A and B zero-points and the
/// output; the lines before them are not part of this listing (presumably
/// they record A and B themselves — confirm).
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
210 addValue(op.getAZp());
211 addValue(op.getBZp());
212 addValue(op.getOutput());
/// tosa.variable: records the variable's type via addType.
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218 addType(op.getType());
/// tosa.variable_write: records the written value (input1).
223 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
224 addValue(op.getInput1());
/// Fills the depot for `op` by dispatching on its concrete TOSA op type.
/// Three helper macros drive the dispatch list (most of which is not part of
/// this listing):
///  - POPULATE_PROFILE_INFO_CUSTOM forwards to the op-specific
///    populateProfileInfo overload;
///  - POPULATE_PROFILE_INFO_SKIP matches ops handled without collecting
///    type info (its full body is not visible here);
///  - POPULATE_PROFILE_INFO_COMMON uses the generic path with all operands
///    and the first result.
/// NOTE(review): "populatation" is a misspelling of "population"; renaming
/// requires updating the declaration and every caller together.
228 LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
230 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
231 if (isa<tosa::tosaOp##Op>(op)) { \
232 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
235 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
236 if (isa<tosa::tosaOp##Op>(op)) \
240 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
241 if (isa<tosa::tosaOp##Op>(op)) { \
242 return populateProfileInfo(op->getOperands(), op->getResult(0)); \
/// Looks up the compliance entries of kind T (Profile or Extension) for this
/// op's name. It then returns, via findMatchedEntry, the entry whose types
/// match those recorded in the depot. When the op has no entry, the branch
/// returns on a line not shown here (presumably failure()).
/// NOTE(review): getProfileComplianceMap<T>() returns the map by value, so the
/// whole table is copied on every call. findMatchedEntry also takes its
/// vector by value. Returning a const reference would avoid both copies.
337 template <
typename T>
338 FailureOr<OpComplianceInfo<T>>
339 TosaProfileCompliance::getOperatorDefinition(
Operation *op) {
341 const auto complianceMap = getProfileComplianceMap<T>();
342 const auto it = complianceMap.find(opName);
343 if (it == complianceMap.end())
346 return findMatchedEntry<T>(op, it->second);
/// Checks that the modes (profiles or extensions, selected by T) needed by
/// `op` are enabled in the target. Two sources of requirements are used:
///  - `specRequiredModeSet`, the candidate mode lists reported by the op's
///    query interface;
///  - the `mode` list of the op's matching compliance-table entry.
/// Mismatches are reported as diagnostics on the op. The return statements
/// are not part of this listing.
349 template <
typename T>
352 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
// The specification requires no mode, so there is nothing to check.
355 if (specRequiredModeSet.size() == 0)
358 const auto maybeOpDefinition = getOperatorDefinition<T>(op);
359 if (
failed(maybeOpDefinition)) {
// No compliance entry matched the op's types. Report the modes the
// specification would accept, phrased as "any of" when more than one
// candidate exists.
365 for (
const auto &cands : specRequiredModeSet) {
368 modeCount += cands.size();
372 << (modeCount > 1 ?
" any of " :
" ") <<
"["
373 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
375 <<
"] but not enabled in target\n";
// NOTE(review): these two lines copy the entry and its mode list;
// `const auto &` bindings would avoid the copies.
382 const auto opDefinition = maybeOpDefinition.value();
383 const SmallVector<T> opRequiredMode = opDefinition.mode;
// The matched entry's mode list decides what must be enabled. The "all of"
// and "any of" diagnostics below are emitted under conditions on lines not
// shown here (presumably TargetEnv::allowsAllOf / allowsAnyOf).
386 if (opRequiredMode.size() == 0) {
394 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
395 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
396 <<
"] but not enabled in target\n";
403 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
404 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
405 <<
"] but not enabled in target\n";
// Extensions only: each required extension may also need cooperative
// profiles (getCooperativeProfiles). A diagnostic lists those the target
// has not enabled.
411 if constexpr (std::is_same_v<T, Extension>) {
412 for (
const auto &mode : opRequiredMode) {
413 SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
416 << llvm::join(stringifyProfile<Profile>(coProfs),
418 <<
"] to work with but not enabled in target\n";
// Every mode the entry requires must also appear in each candidate list
// from the specification; otherwise report the mismatch.
426 for (
const auto &cands : specRequiredModeSet) {
427 for (
const auto &mode : opRequiredMode) {
428 if (!llvm::is_contained(cands, mode)) {
430 << llvm::join(stringifyProfile<T>(opRequiredMode),
432 <<
"] but not included in the profile compliance ["
434 stringifyProfile<T>(specRequiredModeSet),
", ")
// The compliance version presumably comes from the entry's first versioned
// type set. findMatchedEntry keeps only the set that matched.
444 opDefinition.operandTypeInfoSet[0];
// Reject targets whose specification version cannot run this op's
// compliance version.
447 if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
448 op->
emitOpError() <<
"illegal: the target specification version ("
450 <<
") is not backwards compatible with the op compliance "
451 "specification version ("
// checkProfile: an op implementing QueryProfileInterface is checked against
// the profiles it reports.
462 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
463 return checkProfileOrExtension<Profile>(op, targetEnv,
464 interface.getProfiles());
// checkExtension: same check for extensions, via QueryExtensionInterface.
472 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
473 return checkProfileOrExtension<Extension>(op, targetEnv,
474 interface.getExtensions());
// checkInvalid (fragment): builds an "illegal data types" diagnostic for an op
// whose recorded operand/result types match no profile or extension entry. It
// also suggests the closest entry from the compliance tables.
480 const auto maybeProfDef = getOperatorDefinition<Profile>(op);
481 const auto maybeExtDef = getOperatorDefinition<Extension>(op);
// The op counts as covered when either lookup succeeded with a non-empty
// mode list. The branch using this flag is not visible here.
485 const bool hasEntry =
486 (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
487 (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
490 llvm::raw_string_ostream os(message);
491 os <<
"illegal: operation operand/result data types did not align with any "
492 "profile or extension, got (";
// Print the types recorded for this op as a comma-separated tuple.
// NOTE(review): current.back() is undefined behaviour if the depot recorded
// nothing; guard against an empty list.
495 SmallVector<TypeInfo> current = depot.
getInfo();
496 for (
const auto &typeInfo : llvm::drop_end(current))
497 os << stringifyTypeInfo(typeInfo) <<
",";
498 os << stringifyTypeInfo(current.back()) <<
")";
// Search for the compliance entry that agrees with `current` in the most
// positions; ties keep the first entry found.
// NOTE(review): problems in the lambda below:
//  - `auto map` copies the whole table on each call;
//  - `map[opName]` default-inserts an empty entry when the op is absent;
//  - llvm::zip_equal asserts that both ranges have the same length.
505 SmallVector<TypeInfo> bestTypeInfo;
505 const auto searchBestMatch = [&](
auto map) {
506 for (
const auto &complianceInfos : map[opName]) {
507 for (
const auto &versionedTypeInfos :
508 complianceInfos.operandTypeInfoSet) {
509 const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
510 const int matches = llvm::count_if(
511 llvm::zip_equal(current, typeInfos), [&](
const auto zipType) {
512 return isSameTypeInfo(std::get<0>(zipType),
513 std::get<1>(zipType));
515 if (matches > maxMatches) {
516 maxMatches = matches;
517 bestTypeInfo = typeInfos;
522 searchBestMatch(getProfileComplianceMap<Profile>());
523 searchBestMatch(getProfileComplianceMap<Extension>());
// NOTE(review): bestTypeInfo.back() is undefined behaviour if no entry ever
// beat maxMatches (its initial value is not visible here).
525 os <<
", did you mean (";
526 for (
const auto &typeInfo : llvm::drop_end(bestTypeInfo))
527 os << stringifyTypeInfo(typeInfo) <<
",";
528 os << stringifyTypeInfo(bestTypeInfo.back()) <<
")? ";
529 os <<
"Otherwise, please refer to the 'supported data types' for '"
530 << opName <<
"' in the specification.";
/// Scans the compliance entries for one versioned type set whose types equal,
/// position by position, the types recorded in the depot. On a match it
/// builds an OpComplianceInfo that keeps only that set, alongside the
/// entry's condition. The return statements, including the early exit when
/// nothing was recorded, are on lines not shown here.
/// NOTE(review): the `sets` and `expected` locals copy vectors on every
/// iteration; `const auto &` bindings would avoid that.
540 template <
typename T>
543 assert(compInfo.size() != 0 &&
544 "profile-based compliance information is empty");
548 SmallVector<TypeInfo> present = depot.
getInfo();
549 if (present.size() == 0)
552 for (
size_t i = 0; i < compInfo.size(); i++) {
553 SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
554 for (
const auto &set : sets) {
555 SmallVector<TypeInfo> expected = set.first;
// The depot and the generated metadata must list the same number of types.
556 assert(present.size() == expected.size() &&
557 "the entries for profile-based compliance do not match between "
558 "the generated metadata and the type definition retrieved from "
// Position-wise comparison of recorded and expected types.
564 for (
size_t j = 0;
j < expected.size();
j++) {
565 if (!isSameTypeInfo(present[
j], expected[
j])) {
// Matched: keep only the matching versioned type set.
572 if (isFound ==
true) {
573 SmallVector<VersionedTypeInfo> typeInfoSet{set};
575 compInfo[i].condition};
/// Converts a list of Profile or Extension values (selected by T) to their
/// string names, using tosa::stringifyProfile or tosa::stringifyExtension.
585 template <
typename T>
588 SmallVector<StringRef> debugStrings;
589 for (
const auto &profile : profiles) {
590 if constexpr (std::is_same_v<T, Profile>)
591 debugStrings.push_back(tosa::stringifyProfile(profile));
593 debugStrings.push_back(tosa::stringifyExtension(profile));
/// Overload for a set of mode lists: stringifies each list and concatenates
/// all names into one flat list.
598 template <
typename T>
600 const SmallVector<ArrayRef<T>> &profileSet) {
601 SmallVector<StringRef> debugStrings;
603 for (
const auto &profiles : profileSet) {
604 auto tempStrings = stringifyProfile<T>(profiles);
605 llvm::append_range(debugStrings, tempStrings);
// stringifyTypeInfo: renders a TypeInfo as a short type name.
//  - Integer TypeIDs become "i<bitWidth>", so the boolean entry prints as "i1".
//  - Each float TypeID has its own branch; the returned strings are on lines
//    not shown in this listing.
//  - Any other TypeID reaches llvm_unreachable.
613 if (typeInfo.
typeID == mlir::IntegerType::getTypeID()) {
614 return {
"i" + llvm::utostr(typeInfo.
bitWidth)};
616 if (typeInfo.
typeID == mlir::Float16Type::getTypeID()) {
618 }
else if (typeInfo.
typeID == mlir::Float32Type::getTypeID()) {
620 }
else if (typeInfo.
typeID == mlir::BFloat16Type::getTypeID()) {
622 }
else if (typeInfo.
typeID == mlir::Float8E4M3FNType::getTypeID()) {
624 }
else if (typeInfo.
typeID == mlir::Float8E5M2Type::getTypeID()) {
627 llvm_unreachable(
"unknown type");
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
OpComplianceInfo< T > findMatchedEntry(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class represents the capability enabled in the target implementation such as profile,...
SpecificationVersion getSpecVersion() const
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
NestedPattern If(const NestedPattern &child)
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Include the generated interface declarations.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.