MLIR 22.0.0git
UpdateVCEPass.cpp
Go to the documentation of this file.
//===- UpdateVCEPass.cpp - Deduce minimal SPIR-V VCE requirements ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to deduce minimal version/extension/capability
10// requirements for a spirv::ModuleOp.
11//
12//===----------------------------------------------------------------------===//
13
15
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Visitors.h"
21#include "llvm/ADT/StringExtras.h"
22#include <optional>
23
24namespace mlir {
25namespace spirv {
26#define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
27#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
28} // namespace spirv
29} // namespace mlir
30
31using namespace mlir;
32
33namespace {
34/// Pass to deduce minimal version/extension/capability requirements for a
35/// spirv::ModuleOp.
36class UpdateVCEPass final
37 : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
38 void runOnOperation() override;
39};
40} // namespace
41
42/// Checks that `candidates` extension requirements are possible to be satisfied
43/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
44/// errors attaching to the given `op` on failures.
45///
46/// `candidates` is a vector of vector for extension requirements following
47/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
48/// convention.
50 Operation *op, const spirv::TargetEnv &targetEnv,
52 SetVector<spirv::Extension> &deducedExtensions) {
53 for (const auto &ors : candidates) {
54 if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
55 deducedExtensions.insert(*chosen);
56 } else {
58 for (spirv::Extension ext : ors)
59 extStrings.push_back(spirv::stringifyExtension(ext));
60
61 return op->emitError("'")
62 << op->getName() << "' requires at least one extension in ["
63 << llvm::join(extStrings, ", ")
64 << "] but none allowed in target environment";
65 }
66 }
67 return success();
68}
69
70/// Checks that `candidates`capability requirements are possible to be satisfied
71/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
72/// errors attaching to the given `op` on failures.
73///
74/// `candidates` is a vector of vector for capability requirements following
75/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
76/// convention.
78 Operation *op, const spirv::TargetEnv &targetEnv,
80 SetVector<spirv::Capability> &deducedCapabilities) {
81 for (const auto &ors : candidates) {
82 if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
83 deducedCapabilities.insert(*chosen);
84 } else {
86 for (spirv::Capability cap : ors)
87 capStrings.push_back(spirv::stringifyCapability(cap));
88
89 return op->emitError("'")
90 << op->getName() << "' requires at least one capability in ["
91 << llvm::join(capStrings, ", ")
92 << "] but none allowed in target environment";
93 }
94 }
95 return success();
96}
97
100 for (spirv::Capability cap : caps)
101 tmp.insert_range(getRecursiveImpliedCapabilities(cap));
102 caps.insert_range(std::move(tmp));
103}
104
105void UpdateVCEPass::runOnOperation() {
106 spirv::ModuleOp module = getOperation();
107
108 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
109 if (!targetAttr) {
110 module.emitError("missing 'spirv.target_env' attribute");
111 return signalPassFailure();
112 }
113
114 spirv::TargetEnv targetEnv(targetAttr);
115 spirv::Version allowedVersion = targetAttr.getVersion();
116
117 spirv::Version deducedVersion = spirv::Version::V_1_0;
118 SetVector<spirv::Extension> deducedExtensions;
119 SetVector<spirv::Capability> deducedCapabilities;
120
121 // Walk each SPIR-V op to deduce the minimal version/extension/capability
122 // requirements.
123 WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
124 // Op min version requirements
125 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
126 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
127 if (minVersion) {
128 deducedVersion = std::max(deducedVersion, *minVersion);
129 if (deducedVersion > allowedVersion) {
130 return op->emitError("'")
131 << op->getName() << "' requires min version "
132 << spirv::stringifyVersion(deducedVersion)
133 << " but target environment allows up to "
134 << spirv::stringifyVersion(allowedVersion);
135 }
136 }
137 }
138
139 // Op extension requirements
140 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
142 op, targetEnv, extensions.getExtensions(), deducedExtensions)))
143 return WalkResult::interrupt();
144
145 // Op capability requirements
146 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
148 op, targetEnv, capabilities.getCapabilities(),
149 deducedCapabilities)))
150 return WalkResult::interrupt();
151
152 SmallVector<Type, 4> valueTypes;
153 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
154 valueTypes.append(op->result_type_begin(), op->result_type_end());
155
156 // Special treatment for global variables, whose type requirements are
157 // conveyed by type attributes.
158 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
159 valueTypes.push_back(globalVar.getType());
160
161 // If the op is FunctionLike make sure to process input and result types.
162 if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
163 llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
164 llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
165 }
166
167 // Requirements from values' types
168 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
169 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
170 for (Type valueType : valueTypes) {
171 typeExtensions.clear();
172 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
174 op, targetEnv, typeExtensions, deducedExtensions)))
175 return WalkResult::interrupt();
176
177 typeCapabilities.clear();
178 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
180 op, targetEnv, typeCapabilities, deducedCapabilities)))
181 return WalkResult::interrupt();
182 }
183
184 return WalkResult::advance();
185 });
186
187 if (walkResult.wasInterrupted())
188 return signalPassFailure();
189
190 addAllImpliedCapabilities(deducedCapabilities);
191
192 // Update min version requirement for capabilities after deducing them.
193 for (spirv::Capability cap : deducedCapabilities) {
194 if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
195 deducedVersion = std::max(deducedVersion, *minVersion);
196 if (deducedVersion > allowedVersion) {
197 module.emitError("Capability '")
198 << spirv::stringifyCapability(cap) << "' requires min version "
199 << spirv::stringifyVersion(deducedVersion)
200 << " but target environment allows up to "
201 << spirv::stringifyVersion(allowedVersion);
202 return signalPassFailure();
203 }
204 }
205 }
206
207 // TODO: verify that the deduced version is consistent with
208 // SPIR-V ops' maximal version requirements.
209
210 auto triple = spirv::VerCapExtAttr::get(
211 deducedVersion, deducedCapabilities.getArrayRef(),
212 deducedExtensions.getArrayRef(), &getContext());
213 module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
214}
return success()
b getContext())
static LogicalResult checkAndUpdateExtensionRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, SetVector< spirv::Extension > &deducedExtensions)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv a...
static void addAllImpliedCapabilities(SetVector< spirv::Capability > &caps)
static LogicalResult checkAndUpdateCapabilityRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, SetVector< spirv::Capability > &deducedCapabilities)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv a...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type follow the ((Capability::A OR Capability::B) AND (Cap...
Definition SPIRVTypes.h:65
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
Version getVersion() const
Returns the target version.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
bool allows(Capability) const
Returns true if the given capability is allowed.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131