MLIR 22.0.0git
TosaProfileCompliance.cpp
Go to the documentation of this file.
1//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/ADT/StringExtras.h"
11
12using namespace mlir;
13using namespace mlir::tosa;
14
16 const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
17 const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
18 const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
19 const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
20 const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
21 const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
22 const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
23 const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
24 const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
25 const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
26 const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
27 const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
28
29 // micro-scaling formats
30 const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
31 const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
32 const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
33 const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34 const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
35
36// The profile-based compliance content below is auto-generated by a script
37// in https://git.mlplatform.org/tosa/specification.git
39 // End of auto-generated metadata
40}
41
42template <>
46
47template <>
50 return extensionComplianceMap;
51}
52
53// Base populating function
54LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
55 ValueRange results) {
56 for (const auto &operand : operands)
57 addValue(operand);
58 for (const auto &result : results)
60 return success();
61}
62
63template <>
64LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
65 addValue(op.getInput1().front());
66 addValue(op.getOutput());
67 return success();
68}
69
70template <>
71LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
72 addValue(op.getInput());
73 addValue(op.getInputZp());
74 addValue(op.getOutputZp());
75 addType(op.getAccType());
76 addValue(op.getOutput());
77 return success();
78}
79
80template <typename T>
81LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
82 addValue(op.getInput());
83 addValue(op.getWeight());
84 addValue(op.getBias());
85 addValue(op.getInputZp());
86 addValue(op.getWeightZp());
87 addType(op.getAccType());
88 addValue(op.getOutput());
89 return success();
90}
91
92template <>
93LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
94 return populateProfileInfoConv(op);
95}
96
97template <>
98LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
99 return populateProfileInfoConv(op);
100}
101
102template <>
103LogicalResult
104ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
105 return populateProfileInfoConv(op);
106}
107
108template <>
109LogicalResult
110ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
111 return populateProfileInfoConv(op);
112}
113
114template <>
115LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
116 addValue(op.getInput1());
117 addValue(op.getPadConst());
118 addValue(op.getOutput());
119 return success();
120}
121
122template <typename T>
123LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
124 addValue(op.getInput1());
125 addValue(op.getOutput());
126 return success();
127}
128
129template <>
130LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
131 return populateProfileInfoDataLayout(op);
132}
133
134template <>
135LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
136 return populateProfileInfoDataLayout(op);
137}
138
139template <>
140LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
141 return populateProfileInfoDataLayout(op);
142}
143
144template <>
145LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
146 return populateProfileInfoDataLayout(op);
147}
148
149template <>
150LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
151 addValue(op.getValues());
152 addValue(op.getIndices());
153 addValue(op.getOutput());
154 return success();
155}
156
157template <>
158LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
159 addValue(op.getValuesIn());
160 addValue(op.getIndices());
161 addValue(op.getInput());
162 addValue(op.getValuesOut());
163 return success();
164}
165
166template <>
167LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
168 addValue(op.getInput1());
169 addValue(op.getInput2());
170 addValue(op.getOutput());
171 return success();
172}
173
174template <>
175LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
176 addValue(op.getInput());
177 addValue(op.getOutput());
178 return success();
179}
180
181template <>
182LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
183 addValue(op.getOnTrue());
184 addValue(op.getOnFalse());
185 addValue(op.getOutput());
186 return success();
187}
188
189template <>
190LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
191 addValue(op.getInput());
192 addValue(op.getInputZp());
193 addValue(op.getOutputZp());
194 addValue(op.getOutput());
195 return success();
196}
197
198template <>
199LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
200 addValue(op.getA());
201 addValue(op.getB());
202 addValue(op.getAZp());
203 addValue(op.getBZp());
204 addValue(op.getOutput());
205 return success();
206}
207
208template <>
209LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
210 addType(op.getType());
211 return success();
212}
213
214template <>
215LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
216 addValue(op.getInput1());
217 return success();
218}
219
220LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
221// This helper function only populates the info for the customised operands.
222#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
223 if (isa<tosa::tosaOp##Op>(op)) { \
224 return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \
225 }
226
227#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \
228 if (isa<tosa::tosaOp##Op>(op)) \
229 return success();
230
231// This helper function populates the info for all operands.
232#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
233 if (isa<tosa::tosaOp##Op>(op)) { \
234 return populateProfileInfo(op->getOperands(), op->getResults()); \
235 }
236
237 // Skip irrelevant operands when they are independent and not tied to any
238 // specific profile/extension.
240 POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
243 POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
258 POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
259
260 // For the most of tosa operators, all operands are profile/extension related
261 // and hence are all considered in this profile-based compilance check.
262 POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
266 POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
267 POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
279 POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
284 POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
285 POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
308 POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
311 POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
315 POPULATE_PROFILE_INFO_COMMON(VariableRead)
316
317 // Type Invariant Extension, a capability extension that is independent
318 // of the data type, meaning any compatible type can be used. No type
319 // constraint for those operations.
322 POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
323 POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
329
330 return failure();
331}
332
333//===----------------------------------------------------------------------===//
334// Tosa Profile And Extension Compliance Checker
335//===----------------------------------------------------------------------===//
336
337template <typename T>
338FailureOr<OpComplianceInfo<T>>
339TosaProfileCompliance::getOperatorDefinition(Operation *op) {
340 const std::string opName = op->getName().getStringRef().str();
341 const auto complianceMap = getProfileComplianceMap<T>();
342 const auto it = complianceMap.find(opName);
343 if (it == complianceMap.end())
344 return {};
345
346 return findMatchedEntry<T>(op, it->second);
347}
348
349template <typename T>
351 Operation *op, const tosa::TargetEnv &targetEnv,
352 const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
353
354 // None of profile requirement is set in the specification.
355 if (specRequiredModeSet.size() == 0)
356 return success();
357
358 const auto maybeOpDefinition = getOperatorDefinition<T>(op);
359 if (failed(maybeOpDefinition)) {
360 // Operators such as control-flow and shape ops do not have an operand type
361 // restriction. When the profile compliance information of operation is not
362 // found, confirm if the target have enabled the profile required from the
363 // specification.
364 int modeCount = 0;
365 for (const auto &cands : specRequiredModeSet) {
366 if (targetEnv.allowsAnyOf(cands))
367 return success();
368 modeCount += cands.size();
369 }
370
371 op->emitOpError() << "illegal: requires"
372 << (modeCount > 1 ? " any of " : " ") << "["
373 << llvm::join(stringifyProfile<T>(specRequiredModeSet),
374 ", ")
375 << "] but not enabled in target\n";
376
377 return failure();
378 }
379
380 // Find the required profiles or extensions according to the operand type
381 // combination.
382 const auto opDefinition = maybeOpDefinition.value();
383 const SmallVector<T> opRequiredMode = opDefinition.mode;
384 const CheckCondition condition = opDefinition.condition;
385
386 if (opRequiredMode.size() == 0) {
387 // No matched restriction found.
388 return success();
389 }
390
391 if (condition == CheckCondition::allOf &&
392 !targetEnv.allowsAllOf(opRequiredMode)) {
393 op->emitOpError() << "illegal: requires"
394 << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
395 << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
396 << "] but not enabled in target\n";
397 return failure();
398 }
399
400 if (condition == CheckCondition::anyOf &&
401 !targetEnv.allowsAnyOf(opRequiredMode)) {
402 op->emitOpError() << "illegal: requires"
403 << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
404 << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
405 << "] but not enabled in target\n";
406 return failure();
407 }
408
409 // Each extension can contain a list of profiles that it works with, usually
410 // have the same data type.
411 if constexpr (std::is_same_v<T, Extension>) {
412 for (const auto &mode : opRequiredMode) {
414 if (!targetEnv.allowsAnyOf(coProfs)) {
415 op->emitOpError() << "illegal: requires ["
416 << llvm::join(stringifyProfile<Profile>(coProfs),
417 ", ")
418 << "] to work with but not enabled in target\n";
419 return failure();
420 }
421 }
422 }
423
424 // Ensure the profile inference match the profile knowledge of the
425 // specification.
426 for (const auto &cands : specRequiredModeSet) {
427 for (const auto &mode : opRequiredMode) {
428 if (!llvm::is_contained(cands, mode)) {
429 op->emitOpError() << "illegal: requires ["
430 << llvm::join(stringifyProfile<T>(opRequiredMode),
431 ", ")
432 << "] but not included in the profile compliance ["
433 << llvm::join(
434 stringifyProfile<T>(specRequiredModeSet), ", ")
435 << "]\n";
436 return failure();
437 }
438 }
439 }
440
441 // Ensure the matched op compliance version does not exceed the target
442 // specification version.
443 const VersionedTypeInfo versionedTypeInfo =
444 opDefinition.operandTypeInfoSet[0];
445 const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
446 const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
447 if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
448 op->emitOpError() << "illegal: the target specification version ("
449 << stringifyVersion(targetVersion)
450 << ") is not backwards compatible with the op compliance "
451 "specification version ("
452 << stringifyVersion(complianceVersion) << ")\n";
453 return failure();
454 }
455
456 return success();
457}
458
459LogicalResult
461 const tosa::TargetEnv &targetEnv) {
462 if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op))
463 return checkProfileOrExtension<Profile>(op, targetEnv,
464 interface.getProfiles());
465
466 return success();
467}
468
469LogicalResult
471 const tosa::TargetEnv &targetEnv) {
472 if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op))
473 return checkProfileOrExtension<Extension>(op, targetEnv,
474 interface.getExtensions());
475
476 return success();
477}
478
480 const auto maybeProfDef = getOperatorDefinition<Profile>(op);
481 const auto maybeExtDef = getOperatorDefinition<Extension>(op);
482 if (failed(maybeProfDef) && failed(maybeExtDef))
483 return success();
484
485 const bool hasEntry =
486 (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
487 (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
488 if (!hasEntry) {
489 std::string message;
490 llvm::raw_string_ostream os(message);
491 os << "illegal: operation operand/result data types did not align with any "
492 "profile or extension, got (";
493
494 ProfileInfoDepot depot(op);
495 SmallVector<TypeInfo> current = depot.getInfo();
496 for (const auto &typeInfo : llvm::drop_end(current))
497 os << stringifyTypeInfo(typeInfo) << ",";
498 os << stringifyTypeInfo(current.back()) << ")";
499
500 // avoid polluting the error message output by outputting only
501 // the best match
502 const std::string opName = op->getName().getStringRef().str();
503 int maxMatches = -1;
504 SmallVector<TypeInfo> bestTypeInfo;
505 const auto searchBestMatch = [&](auto map) {
506 for (const auto &complianceInfos : map[opName]) {
507 for (const auto &versionedTypeInfos :
508 complianceInfos.operandTypeInfoSet) {
509 const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
510 const int matches = llvm::count_if(
511 llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
512 return isSameTypeInfo(std::get<0>(zipType),
513 std::get<1>(zipType));
514 });
515 if (matches > maxMatches) {
516 maxMatches = matches;
517 bestTypeInfo = typeInfos;
518 }
519 }
520 }
521 };
522 searchBestMatch(getProfileComplianceMap<Profile>());
523 searchBestMatch(getProfileComplianceMap<Extension>());
524
525 os << ", did you mean (";
526 for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
527 os << stringifyTypeInfo(typeInfo) << ",";
528 os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
529 os << "Otherwise, please refer to the 'supported data types' for '"
530 << opName << "' in the specification.";
531 op->emitOpError(message);
532 return failure();
533 }
534
535 return success();
536}
537
538// Find the profile or extension requirements whose type signature matches the
539// types of the operation's profile-relevant operands and results.
540template <typename T>
543 assert(compInfo.size() != 0 &&
544 "profile-based compliance information is empty");
545
546 // Populate the type of profile/extension relevant operands.
547 ProfileInfoDepot depot(op);
548 SmallVector<TypeInfo> present = depot.getInfo();
549 if (present.size() == 0)
550 return {};
551
552 for (size_t i = 0; i < compInfo.size(); i++) {
553 SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
554 for (const auto &set : sets) {
555 SmallVector<TypeInfo> expected = set.first;
556 assert(present.size() == expected.size() &&
557 "the entries for profile-based compliance do not match between "
558 "the generated metadata and the type definition retrieved from "
559 " the operation");
560
561 bool isFound = true;
562 // Compare the type signature between the given operation and the
563 // compliance metadata.
564 for (size_t j = 0; j < expected.size(); j++) {
565 if (!isSameTypeInfo(present[j], expected[j])) {
566 // Verify the next mode set from the list.
567 isFound = false;
568 break;
569 }
570 }
571
572 if (isFound == true) {
573 SmallVector<VersionedTypeInfo> typeInfoSet{set};
574 OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
575 compInfo[i].condition};
576 return info;
577 }
578 }
579 }
580
581 return {};
582}
583
584// Debug utilities.
585template <typename T>
588 SmallVector<StringRef> debugStrings;
589 for (const auto &profile : profiles) {
590 if constexpr (std::is_same_v<T, Profile>)
591 debugStrings.push_back(tosa::stringifyProfile(profile));
592 else
593 debugStrings.push_back(tosa::stringifyExtension(profile));
594 }
595 return debugStrings;
596}
597
598template <typename T>
600 const SmallVector<ArrayRef<T>> &profileSet) {
601 SmallVector<StringRef> debugStrings;
602
603 for (const auto &profiles : profileSet) {
604 auto tempStrings = stringifyProfile<T>(profiles);
605 llvm::append_range(debugStrings, tempStrings);
606 }
607
608 return debugStrings;
609}
610
613 if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
614 return {"i" + llvm::utostr(typeInfo.bitWidth)};
615 }
616 if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
617 return {"f16"};
618 } else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
619 return {"f32"};
620 } else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
621 return {"bf16"};
622 } else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
623 return {"fp8e4m3"};
624 } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
625 return {"fp8e5m2"};
626 } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
627 return {"fp6e2m3"};
628 } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
629 return {"fp6e3m2"};
630 } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
631 return {"fp4e2m1"};
632 } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
633 return {"fp8e8m0"};
634 } else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
635 return {"mxint8"};
636 }
637 llvm_unreachable("unknown type");
638}
return success()
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Profile > > > OperationProfileComplianceMap
std::pair< SmallVector< TypeInfo >, SpecificationVersion > VersionedTypeInfo
std::unordered_map< std::string, SmallVector< OpComplianceInfo< Extension > > > OperationExtensionComplianceMap
@ Gather
SmallVector< TypeInfo > getInfo()
bool isSameTypeInfo(TypeInfo a, TypeInfo b)
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv)
SmallVector< Profile > getCooperativeProfiles(Extension ext)
OpComplianceInfo< T > findMatchedEntry(Operation *op, SmallVector< OpComplianceInfo< T > > compInfo)
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv)
LogicalResult checkInvalid(Operation *op)
SmallVector< StringRef > stringifyProfile(ArrayRef< T > profiles)
LogicalResult checkProfileOrExtension(Operation *op, const tosa::TargetEnv &targetEnv, const SmallVector< ArrayRef< T > > &specDefinedProfileSet)
static llvm::SmallString< 7 > stringifyTypeInfo(const TypeInfo &typeInfo)
std::unordered_map< std::string, SmallVector< OpComplianceInfo< T > > > getProfileComplianceMap()
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents the capability enabled in the target implementation such as profile,...
Definition TargetEnv.h:97
bool allowsAllOf(ArrayRef< Profile > profs) const
Definition TargetEnv.h:130
bool allowsAnyOf(ArrayRef< Profile > profs) const
Definition TargetEnv.h:126
TosaSpecificationVersion getSpecVersion() const
Definition TargetEnv.h:110
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
Definition TargetEnv.h:55
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:64
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Definition TargetEnv.cpp:15
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
mlir::TypeID typeID
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.