MLIR 23.0.0git
TosaProfileCompliance.cpp
Go to the documentation of this file.
1//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/ADT/StringExtras.h"
11
12using namespace mlir;
13using namespace mlir::tosa;
14
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
28
29 // micro-scaling formats
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
35
36// The profile-based compliance content below is auto-generated by a script
37// in https://github.com/arm/tosa-specification
39 // End of auto-generated metadata
40}
41
42template <>
46
47template <>
50 return extensionComplianceMap;
51}
52
53// Base populating function
54LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
55 ValueRange results) {
56 for (const auto &operand : operands)
57 addValue(operand);
58 for (const auto &result : results)
60 return success();
61}
62
63template <>
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
65 addValue(op.getInput1().front());
66 addValue(op.getOutput());
67 return success();
68}
69
70template <>
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
72 addValue(op.getInput());
73 addValue(op.getInputZp());
74 addValue(op.getOutputZp());
75 addType(op.getAccType());
76 addValue(op.getOutput());
77 return success();
78}
79
80template <typename T>
81LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
82 addValue(op.getInput());
83 addValue(op.getWeight());
84 addValue(op.getBias());
85 addValue(op.getInputZp());
86 addValue(op.getWeightZp());
87 addType(op.getAccType());
88 addValue(op.getOutput());
89 return success();
90}
91
92template <>
93LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
94 return populateProfileInfoConv(op);
95}
96
97template <>
98LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
99 return populateProfileInfoConv(op);
100}
101
102template <>
103LogicalResult
104ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
105 return populateProfileInfoConv(op);
106}
107
108template <>
109LogicalResult
110ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
111 return populateProfileInfoConv(op);
112}
113
114template <>
115LogicalResult
116ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
117 addValue(op.getInputData());
118 addValue(op.getInputScale());
119 addValue(op.getWeightData());
120 addValue(op.getWeightScale());
121 addValue(op.getBias());
122 addValue(op.getOutput());
123 return success();
124}
125
126template <>
127LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
128 addValue(op.getInput1());
129 addValue(op.getPadConst());
130 addValue(op.getOutput());
131 return success();
132}
133
134template <typename T>
135LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
136 addValue(op.getInput1());
137 addValue(op.getOutput());
138 return success();
139}
140
141template <>
142LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
143 return populateProfileInfoDataLayout(op);
144}
145
146template <>
147LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
148 return populateProfileInfoDataLayout(op);
149}
150
151template <>
152LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
153 return populateProfileInfoDataLayout(op);
154}
155
156template <>
157LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
158 return populateProfileInfoDataLayout(op);
159}
160
161template <>
162LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
163 addValue(op.getValues());
164 addValue(op.getIndices());
165 addValue(op.getOutput());
166 return success();
167}
168
169template <>
170LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
171 addValue(op.getValuesIn());
172 addValue(op.getIndices());
173 addValue(op.getInput());
174 addValue(op.getValuesOut());
175 return success();
176}
177
178template <>
179LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
180 addValue(op.getInput1());
181 addValue(op.getInput2());
182 addValue(op.getOutput());
183 return success();
184}
185
186template <>
187LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
188 addValue(op.getInput());
189 addValue(op.getOutput());
190 return success();
191}
192
193template <>
194LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
195 addValue(op.getOnTrue());
196 addValue(op.getOnFalse());
197 addValue(op.getOutput());
198 return success();
199}
200
201template <>
202LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
203 addValue(op.getInput());
204 addValue(op.getInputZp());
205 addValue(op.getOutputZp());
206 addValue(op.getOutput());
207 return success();
208}
209
210template <>
211LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
212 addValue(op.getA());
213 addValue(op.getB());
214 addValue(op.getAZp());
215 addValue(op.getBZp());
216 addValue(op.getOutput());
217 return success();
218}
219
220template <>
221LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
222 addType(op.getType());
223 return success();
224}
225
226template <>
227LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
228 addValue(op.getInput1());
229 return success();
230}
231
232template <>
233LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
234 addValue(op.getInput1());
235 return success();
236}
237
238LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
239// This helper function only populates the info for the customised operands.
240#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
241 if (isa<tosa::tosaOp##Op>(op)) { \
242 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
243 }
244
245#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
246 if (isa<tosa::tosaOp##Op>(op)) \
247 return success();
248
249// This helper function populates the info for all operands.
250#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
251 if (isa<tosa::tosaOp##Op>(op)) { \
252 return populateProfileInfo(op->getOperands(), op->getResults()); \
253 }
254
255 // Skip irrelevant operands when they are independent and not tied to any
256 // specific profile/extension.
258 POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
260 POPULATE_PROFILE_INFO_CUSTOM(Conv2DBlockScaled)
262 POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
277 POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
279
280 // For most TOSA operators, all operands are profile/extension related
281 // and hence are all considered in this profile-based compliance check.
282 POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
286 POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
287 POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
299 POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
304 POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
305 POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
328 POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
331 POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
335 POPULATE_PROFILE_INFO_COMMON(VariableRead)
336
337 // Type Invariant Extension, a capability extension that is independent
338 // of the data type, meaning any compatible type can be used. No type
339 // constraint for those operations.
341 POPULATE_PROFILE_INFO_SKIP(AssertEqualShape)
342 POPULATE_PROFILE_INFO_SKIP(ConcatShape)
344 POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
345 POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
347 POPULATE_PROFILE_INFO_SKIP(Log2CeilShape)
348 POPULATE_PROFILE_INFO_SKIP(Log2FloorShape)
358
359 return failure();
360}
361
362//===----------------------------------------------------------------------===//
363// Tosa Profile And Extension Compliance Checker
364//===----------------------------------------------------------------------===//
365
366template <typename T>
367FailureOr<SmallVector<OpComplianceInfo<T>>>
368TosaProfileCompliance::getOperatorMatchedEntries(Operation *op) {
369 const std::string opName = op->getName().getStringRef().str();
370 const auto complianceMap = getProfileComplianceMap<T>();
371 const auto it = complianceMap.find(opName);
372 if (it == complianceMap.end())
373 return {};
374
375 return findMatchedEntries<T>(op, it->second);
376}
377
378template <typename T>
380 Operation *op, const tosa::TargetEnv &targetEnv,
381 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
382
383 // No profile/extension requirement is set in the specification.
384 if (specRequiredModeSet.size() == 0)
385 return success();
386
387 const auto maybeOpEntries = getOperatorMatchedEntries<T>(op);
388 if (failed(maybeOpEntries)) {
389 // Operators such as control-flow and shape ops do not have an operand type
390 // restriction. When the profile compliance information of operation is not
391 // found, confirm if the target have enabled the profile required from the
392 // specification.
393 int modeCount = 0;
394 for (const auto &cands : specRequiredModeSet) {
395 if (targetEnv.allowsAnyOf(cands))
396 return success();
397 modeCount += cands.size();
398 }
399
400 op->emitOpError() << "illegal: requires"
401 << (modeCount > 1 ? " any of " : " ") << "["
402 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
403 ", ")
404 << "] but not enabled in target\n";
405
406 return failure();
407 }
408
409 const auto opEntries = maybeOpEntries.value();
410 if (opEntries.size() == 0) {
411 // No matched restriction found.
412 return success();
413 }
414
415 // Check the profile/extension requirement according to the current target
416 // profiles/extensions.
417 const auto isModeAllowed = [&](const OpComplianceInfo<T> &info) -> bool {
418 if (info.condition == CheckCondition::allOf)
419 return targetEnv.allowsAllOf(info.mode);
420 return targetEnv.allowsAnyOf(info.mode);
421 };
422
423 // Check the matched op compliance version does not exceed the target
424 // specification version.
425 const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
426 const auto isVersionCompatible =
427 [&targetVersion](const OpComplianceInfo<T> &info) -> bool {
428 const TosaSpecificationVersion complianceVersion{
429 info.operandTypeInfoSet.front().second};
430 return targetVersion.isBackwardsCompatibleWith(complianceVersion);
431 };
432
433 for (const auto &info : opEntries) {
434 // Ensure the profile compliance is compatible with the profile knowledge of
435 // the op definition.
436 assert(llvm::all_of(info.mode,
437 [&specRequiredModeSet](const T &mode) {
438 return llvm::is_contained(specRequiredModeSet.front(),
439 mode);
440 }) &&
441 "the profile/extension requirement of the operator should be "
442 "included in the profile compliance information");
443
444 if (isModeAllowed(info) && isVersionCompatible(info))
445 return success();
446 }
447
448 // No valid entry was found, now emit appropriate error message and return
449 // failure
450 std::string message;
451 llvm::raw_string_ostream os(message);
452
453 os << "illegal: ";
454 const size_t numOpEntries = opEntries.size();
455 for (const auto &[index, info] : llvm::enumerate(opEntries)) {
456 bool mismatchedVersion = false;
457 if (!isVersionCompatible(info)) {
458 mismatchedVersion = true;
459 os << "requires specification version compatible with "
460 << stringifyVersion(info.operandTypeInfoSet.front().second) << " (got "
461 << stringifyVersion(targetVersion) << ") ";
462 }
463
464 if (!isModeAllowed(info)) {
465 if (mismatchedVersion)
466 os << "and ";
467 os << "requires "
468 << (info.condition == CheckCondition::allOf ? "all of " : "any of ")
469 << "[" << llvm::join(stringifyProfile<T>(info.mode), ", ")
470 << "] profiles/extensions ";
471 }
472
473 if (index != numOpEntries - 1)
474 os << "OR ";
475 }
476 os << "to be specified in the target environment";
477
478 return op->emitOpError(message);
479}
480
481LogicalResult
483 const tosa::TargetEnv &targetEnv) {
484 if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
485 return checkProfileOrExtension<Profile>(op, targetEnv,
486 interface.getProfiles());
487
488 return success();
489}
490
491LogicalResult
493 const tosa::TargetEnv &targetEnv) {
494 if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
495 return checkProfileOrExtension<Extension>(op, targetEnv,
496 interface.getExtensions());
497
498 return success();
499}
500
502 const auto maybeProfEntries = getOperatorMatchedEntries<Profile>(op);
503 const auto maybeExtEntries = getOperatorMatchedEntries<Extension>(op);
504 if (failed(maybeProfEntries) && failed(maybeExtEntries))
505 return success();
506
507 const bool hasEntry =
508 (succeeded(maybeProfEntries) && !maybeProfEntries.value().empty()) ||
509 (succeeded(maybeExtEntries) && !maybeExtEntries.value().empty());
510
511 if (!hasEntry) {
512 std::string message;
513 llvm::raw_string_ostream os(message);
514 os << "illegal: operation operand/result data types did not align with any "
515 "profile or extension, got (";
516
517 ProfileInfoDepot depot(op);
518 SmallVector<TypeInfo> current = depot.getInfo();
519 for (const auto &typeInfo : llvm::drop_end(current))
520 os << stringifyTypeInfo(typeInfo) << ",";
521 os << stringifyTypeInfo(current.back()) << ")";
522
523 // avoid polluting the error message output by outputting only
524 // the best match
525 const std::string opName = op->getName().getStringRef().str();
526 int maxMatches = -1;
527 SmallVector<TypeInfo> bestTypeInfo;
528 const auto searchBestMatch = [&](auto map) {
529 for (const auto &complianceInfos : map[opName]) {
530 for (const auto &versionedTypeInfos :
531 complianceInfos.operandTypeInfoSet) {
532 const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
533 const int matches = llvm::count_if(
534 llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
535 return isSameTypeInfo(std::get<0>(zipType),
536 std::get<1>(zipType));
537 });
538 if (matches > maxMatches) {
539 maxMatches = matches;
540 bestTypeInfo = typeInfos;
541 }
542 }
543 }
544 };
545 searchBestMatch(getProfileComplianceMap<Profile>());
546 searchBestMatch(getProfileComplianceMap<Extension>());
547
548 os << ", did you mean (";
549 for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
550 os << stringifyTypeInfo(typeInfo) << ",";
551 os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
552 os << "Otherwise, please refer to the 'supported data types' for '"
553 << opName << "' in the specification.";
554 op->emitOpError(message);
555 return failure();
556 }
557
558 return success();
559}
560
561// Find the profiles or extensions requirement according to the signature of
562// type of the operand list.
563template <typename T>
566 assert(compInfo.size() != 0 &&
567 "profile-based compliance information is empty");
568
569 // Populate the type of profile/extension relevant operands.
570 ProfileInfoDepot depot(op);
571 SmallVector<TypeInfo> present = depot.getInfo();
572 if (present.size() == 0)
573 return {};
574
576 for (size_t i = 0; i < compInfo.size(); i++) {
577 SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
578 for (const auto &set : sets) {
579 SmallVector<TypeInfo> expected = set.first;
580 assert(present.size() == expected.size() &&
581 "the entries for profile-based compliance do not match between "
582 "the generated metadata and the type definition retrieved from "
583 " the operation");
584
585 bool isFound = true;
586 // Compare the type signature between the given operation and the
587 // compliance metadata.
588 for (size_t j = 0; j < expected.size(); j++) {
589 if (!isSameTypeInfo(present[j], expected[j])) {
590 // Verify the next mode set from the list.
591 isFound = false;
592 break;
593 }
594 }
595
596 if (isFound == true) {
597 SmallVector<VersionedTypeInfo> typeInfoSet{set};
598 OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
599 compInfo[i].condition};
600 matchedInfos.push_back(info);
601 }
602 }
603 }
604
605 return matchedInfos;
606}
607
608// Debug utilities.
609template <typename T>
612 SmallVector<StringRef> debugStrings;
613 for (const auto &profile : profiles) {
614 if constexpr (std::is_same_v<T, Profile>)
615 debugStrings.push_back(tosa::stringifyProfile(profile));
616 else
617 debugStrings.push_back(tosa::stringifyExtension(profile));
618 }
619 return debugStrings;
620}
621
622template <typename T>
624 const SmallVector<ArrayRef<T>> &profileSet) {
625 SmallVector<StringRef> debugStrings;
626
627 for (const auto &profiles : profileSet) {
628 auto tempStrings = stringifyProfile<T>(profiles);
629 llvm::append_range(debugStrings, tempStrings);
630 }
631
632 return debugStrings;
633}
634
637 if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
638 return {"i" + llvm::utostr(typeInfo.bitWidth)};
639 }
640 if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
641 return {"f16"};
642 } else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
643 return {"f32"};
644 } else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
645 return {"bf16"};
646 } else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
647 return {"fp8e4m3"};
648 } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
649 return {"fp8e5m2"};
650 } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
651 return {"fp6e2m3"};
652 } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
653 return {"fp6e3m2"};
654 } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
655 return {"fp4e2m1"};
656 } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
657 return {"fp8e8m0"};
658 } else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
659 return {"mxint8"};
660 }
661 llvm_unreachable("unknown type");
662}
return success()
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
@ Gather
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
SmallVector< OpComplianceInfo< T > > findMatchedEntries(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:100
bool allowsAllOf(ArrayRef< Profile > profs) const
Definition TargetEnv.h:133
bool allowsAnyOf(ArrayRef< Profile > profs) const
Definition TargetEnv.h:129
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:113
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
Definition TargetEnv.h:58
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:67
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Definition TargetEnv.cpp:15
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
mlir::TypeID typeID
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.