10 #include "llvm/ADT/StringExtras.h"
// Shorthand TypeInfo descriptors, one per element type referenced below.
// Each entry pairs the TypeID of an MLIR builtin type with an integer.
// NOTE(review): the integer looks like the element bit width (it matches the
// width spelled in each name, e.g. i48T -> 48, fp8e4m3T/fp8e5m2T -> 8) --
// confirm against the TypeInfo definition.
// All integer widths share mlir::IntegerType's TypeID, so the integer field is
// the only thing that tells boolT, i4T, i8T, i16T, i32T and i48T apart.
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
// Floating-point descriptors: each format has its own TypeID, so bf16T and
// fp16T stay distinct even though both carry 16 (likewise the two fp8 kinds).
22 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
23 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
24 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
25 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
26 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
36 return profileComplianceMap;
42 return extensionComplianceMap;
46 LogicalResult ProfileInfoDepot::populateProfileInfo(
ValueRange operands,
48 for (
auto operand : operands)
55 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
56 addValue(op.getInput1().front());
57 addValue(op.getOutput());
62 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
63 addValue(op.getInput());
64 addValue(op.getInputZp());
65 addValue(op.getOutputZp());
66 addType(op.getAccType());
67 addValue(op.getOutput());
72 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
73 addValue(op.getInput());
74 addValue(op.getWeight());
75 addValue(op.getBias());
76 addValue(op.getInputZp());
77 addValue(op.getWeightZp());
78 addType(op.getAccType());
79 addValue(op.getOutput());
84 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
85 return populateProfileInfoConv(op);
89 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
90 return populateProfileInfoConv(op);
95 ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
96 return populateProfileInfoConv(op);
101 ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
102 return populateProfileInfoConv(op);
106 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
107 addValue(op.getInput1());
108 addValue(op.getPadConst());
109 addValue(op.getOutput());
113 template <
typename T>
114 LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
115 addValue(op.getInput1());
116 addValue(op.getOutput());
121 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
122 return populateProfileInfoDataLayout(op);
126 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
127 return populateProfileInfoDataLayout(op);
131 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
132 return populateProfileInfoDataLayout(op);
136 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
137 return populateProfileInfoDataLayout(op);
141 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
142 addValue(op.getValues());
143 addValue(op.getIndices());
144 addValue(op.getOutput());
149 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
150 addValue(op.getValuesIn());
151 addValue(op.getIndices());
152 addValue(op.getInput());
153 addValue(op.getValuesOut());
158 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
159 addValue(op.getInput1());
160 addValue(op.getInput2());
161 addValue(op.getOutput());
166 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
167 addValue(op.getInput());
168 addValue(op.getOutput());
173 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
174 addValue(op.getInputReal());
175 addValue(op.getInputImag());
176 addValue(op.getOutputReal());
177 addValue(op.getOutputImag());
182 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
183 addValue(op.getInputReal());
184 addValue(op.getOutputReal());
185 addValue(op.getOutputImag());
190 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191 addValue(op.getInput2());
192 addValue(op.getInput3());
193 addValue(op.getOutput());
198 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
199 addValue(op.getInput());
200 addValue(op.getInputZp());
201 addValue(op.getOutputZp());
202 addValue(op.getOutput());
207 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
210 addValue(op.getAZp());
211 addValue(op.getBZp());
212 addValue(op.getOutput());
217 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
222 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
230 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
231 addValue(op.getInput1());
236 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
237 addValue(op.getCondition());
242 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
249 LogicalResult ProfileInfoDepot::populatationDispatch(
Operation *op) {
251 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
252 if (isa<tosa::tosaOp##Op>(op)) { \
253 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
256 #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
257 if (isa<tosa::tosaOp##Op>(op)) \
261 #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
262 if (isa<tosa::tosaOp##Op>(op)) { \
263 return populateProfileInfo(op->getOperands(), op->getResult(0)); \
358 template <
typename T>
359 FailureOr<SmallVector<T>>
360 TosaProfileCompliance::getOperatorDefinition(
Operation *op,
363 const auto complianceMap = getProfileComplianceMap<T>();
364 const auto it = complianceMap.find(opName);
365 if (it == complianceMap.end())
368 return findMatchedProfile<T>(op, it->second, condition);
371 template <
typename T>
374 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
377 if (specRequiredModeSet.size() == 0)
381 const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
382 if (failed(maybeOpRequiredMode)) {
388 for (
const auto &cands : specRequiredModeSet) {
391 mode_count += cands.size();
395 << (mode_count > 1 ?
" any of " :
" ") <<
"["
396 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
398 <<
"] but not enabled in target\n";
405 const auto opRequiredMode = maybeOpRequiredMode.value();
406 if (opRequiredMode.size() == 0) {
414 << (opRequiredMode.size() > 1 ?
" all of " :
" ") <<
"["
415 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
416 <<
"] but not enabled in target\n";
423 << (opRequiredMode.size() > 1 ?
" any of " :
" ") <<
"["
424 << llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
425 <<
"] but not enabled in target\n";
431 if constexpr (std::is_same_v<T, Extension>) {
432 for (
const auto &mode : opRequiredMode) {
433 SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
436 << llvm::join(stringifyProfile<Profile>(coProfs),
438 <<
"] to work with but not enabled in target\n";
446 for (
const auto &cands : specRequiredModeSet) {
447 for (
size_t i = 0; i < opRequiredMode.size(); i++) {
448 if (std::find(cands.begin(), cands.end(), opRequiredMode[i]) ==
451 << llvm::join(stringifyProfile<T>(opRequiredMode),
453 <<
"] but not included in the profile compliance ["
455 stringifyProfile<T>(specRequiredModeSet),
", ")
468 if (
auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
469 return checkProfileOrExtension<Profile>(op, targetEnv,
470 interface.getProfiles());
478 if (
auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
479 return checkProfileOrExtension<Extension>(op, targetEnv,
480 interface.getExtensions());
487 const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
488 const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
489 if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
490 !maybeProfDef.value().size() && !maybeExtDef.value().size())
498 template <
typename T>
502 assert(compInfo.size() != 0 &&
503 "profile-based compliance information is empty");
507 SmallVector<TypeInfo> present = depot.
getInfo();
508 if (present.size() == 0)
511 for (
size_t i = 0; i < compInfo.size(); i++) {
512 SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
513 for (SmallVector<TypeInfo> expected : sets) {
514 assert(present.size() == expected.size() &&
515 "the entries for profile-based compliance do not match between "
516 "the generated metadata and the type definition retrieved from "
519 bool is_found =
true;
522 for (
size_t j = 0;
j < expected.size();
j++) {
523 if (!isSameTypeInfo(present[
j], expected[
j])) {
530 if (is_found ==
true) {
531 condition = compInfo[i].condition;
532 return compInfo[i].mode;
541 template <
typename T>
544 SmallVector<StringRef> debugStrings;
545 for (
const auto &profile : profiles) {
546 if constexpr (std::is_same_v<T, Profile>)
547 debugStrings.push_back(tosa::stringifyProfile(profile));
549 debugStrings.push_back(tosa::stringifyExtension(profile));
554 template <
typename T>
556 const SmallVector<ArrayRef<T>> &profileSet) {
557 SmallVector<StringRef> debugStrings;
559 for (
const auto &profiles : profileSet) {
560 auto tempStrings = stringifyProfile<T>(profiles);
561 llvm::append_range(debugStrings, tempStrings);
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > >> OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > >> OperationExtensionComplianceMap
SmallVector< TypeInfo > getInfo()
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< T > findMatchedProfile(Operation *op, SmallVector< OpComplianceInfo< T >> compInfo, CheckCondition &condition)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T >> &specDefinedProfileSet)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator over the underlying Values.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
This class represents the capability enabled in the target implementation such as profile, extension, and level.
bool allowsAllOf(ArrayRef< Profile > profs) const
bool allowsAnyOf(ArrayRef< Profile > profs) const
NestedPattern If(const NestedPattern &child)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.