MLIR 23.0.0git
TargetEnv.cpp
Go to the documentation of this file.
1//===-------------- TargetEnv.cpp - TOSA Target utilities -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/Support/FormatVariadic.h"
11
12namespace mlir {
13namespace tosa {
14
16 return llvm::formatv("{0}.{1}{2}", version.getMajor(), version.getMinor(),
17 version.isDraft() ? ".draft" : "");
18}
19
20TosaSpecificationVersion getMinVersion(const Profile &profile) {
21 switch (profile) {
22 case Profile::pro_int:
23 case Profile::pro_fp:
24 return TosaSpecificationVersion(1, 0);
25 case Profile::none:
26 return TosaSpecificationVersion(0, 0);
27 }
28 llvm_unreachable("Unknown TOSA profile");
29}
30
31TosaSpecificationVersion getMinVersion(const Extension &extension) {
32 switch (extension) {
33 case Extension::int16:
34 case Extension::int4:
35 case Extension::bf16:
36 case Extension::fp8e4m3:
37 case Extension::fp8e5m2:
38 case Extension::fft:
39 case Extension::variable:
40 case Extension::controlflow:
41 case Extension::doubleround:
42 case Extension::inexactround:
43 case Extension::dynamic:
44 return TosaSpecificationVersion(1, 0);
45 case Extension::mxfp:
46 case Extension::int64:
47 case Extension::mxfp_conv:
48 case Extension::shape:
49 return TosaSpecificationVersion(1, 1, true);
50 case Extension::none:
51 return TosaSpecificationVersion(0, 0);
52 }
53 llvm_unreachable("Unknown TOSA extension");
54}
55
57 switch (ext) {
58 case Extension::int16:
59 case Extension::int4:
60 case Extension::doubleround:
61 case Extension::inexactround:
62 return {Profile::pro_int};
63 case Extension::bf16:
64 case Extension::fp8e4m3:
65 case Extension::fp8e5m2:
66 case Extension::fft:
67 case Extension::mxfp:
68 case Extension::mxfp_conv:
69 return {Profile::pro_fp};
70 case Extension::variable:
71 case Extension::controlflow:
72 case Extension::dynamic:
73 case Extension::int64:
74 case Extension::shape:
75 return {Profile::pro_fp, Profile::pro_int};
76 case Extension::none:
77 return {};
78 };
79 llvm_unreachable("bad Extension type");
80}
81
83 switch (level) {
84 case Level::eightK:
85 case Level::none:
86 return TosaSpecificationVersion(1, 0);
87 }
88 llvm_unreachable("Unknown TOSA level");
89}
90
91FailureOr<TargetEnv>
92TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
93 Location targetEnvAttrLoc) {
94 if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
95 return failure();
96
97 return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
98 targetAttr.getProfiles(), targetAttr.getExtensions());
99}
100
101LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
102 Location targetAttrLoc) {
103 TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
104
105 const auto isCompatibleWithTargetVersion =
106 [&](const auto &targetEnum, Location targetAttrLoc,
107 StringRef enumName) -> LogicalResult {
108 const TosaSpecificationVersion minRequiredVersion =
109 getMinVersion(targetEnum);
110 if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
111 return emitError(targetAttrLoc, enumName)
112 << " '" << stringifyEnum(targetEnum)
113 << "' is not compatible with the target version "
114 << stringifyVersion(targetVersion)
115 << ", minimum required version is "
116 << stringifyVersion(minRequiredVersion);
117 return success();
118 };
119
120 const auto isExtensionCooperativeWithProfile =
121 [&](Extension ext) -> LogicalResult {
122 const auto cooperativeProfiles = getCooperativeProfiles(ext);
123
124 const ArrayRef<Profile> targetProfiles = targetAttr.getProfiles();
125 if (!llvm::any_of(cooperativeProfiles,
126 [&targetProfiles](const auto &profile) {
127 return llvm::is_contained(targetProfiles, profile);
128 }))
129 return emitError(targetAttrLoc)
130 << "use of extension '" << stringifyEnum(ext)
131 << "' requires any of profiles: [" << cooperativeProfiles
132 << "] to be enabled in the target";
133
134 return success();
135 };
136
137 for (const auto &profile : targetAttr.getProfiles())
138 if (failed(
139 isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
140 return failure();
141 for (const auto &extension : targetAttr.getExtensions()) {
142 if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
143 "extension")))
144 return failure();
145 if (failed(isExtensionCooperativeWithProfile(extension)))
146 return failure();
147 }
148 if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
149 "level")))
150 return failure();
151
152 return success();
153}
154
155TargetEnvAttr lookupTargetEnv(Operation *op) {
156 while (op) {
158 if (!op)
159 break;
160
161 if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
162 return attr;
163
164 op = op->getParentOp();
165 }
166
167 return {};
168}
169
170TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
171 return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
172 {Profile::pro_int, Profile::pro_fp}, {});
173}
174
176 if (auto attr = lookupTargetEnv(op))
177 return attr;
178
179 return getDefaultTargetEnv(op->getContext());
180}
181
182} // namespace tosa
183} // namespace mlir
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static FailureOr< TargetEnv > createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc)
Definition TargetEnv.cpp:92
static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr, Location targetAttrLoc)
A thin wrapper around the SpecificationVersion enum to represent and provide utilities around the TOS...
Definition TargetEnv.h:58
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const
Definition TargetEnv.h:67
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
Definition TargetEnv.cpp:15
SmallVector< Profile, 2 > getCooperativeProfiles(Extension ext)
Definition TargetEnv.cpp:56
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
TosaSpecificationVersion getMinVersion(const Profile &profile)
Definition TargetEnv.cpp:20
TargetEnvAttr lookupTargetEnv(Operation *op)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.