21#include "llvm/ADT/StringExtras.h"
26#define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
27#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
36class UpdateVCEPass final
38 void runOnOperation()
override;
53 for (
const auto &ors : candidates) {
54 if (std::optional<spirv::Extension> chosen = targetEnv.
allows(ors)) {
55 deducedExtensions.insert(*chosen);
58 for (spirv::Extension ext : ors)
59 extStrings.push_back(spirv::stringifyExtension(ext));
62 << op->
getName() <<
"' requires at least one extension in ["
63 << llvm::join(extStrings,
", ")
64 <<
"] but none allowed in target environment";
81 for (
const auto &ors : candidates) {
82 if (std::optional<spirv::Capability> chosen = targetEnv.
allows(ors)) {
83 deducedCapabilities.insert(*chosen);
86 for (spirv::Capability cap : ors)
87 capStrings.push_back(spirv::stringifyCapability(cap));
90 << op->
getName() <<
"' requires at least one capability in ["
91 << llvm::join(capStrings,
", ")
92 <<
"] but none allowed in target environment";
100 for (spirv::Capability cap : caps)
101 tmp.insert_range(getRecursiveImpliedCapabilities(cap));
102 caps.insert_range(std::move(tmp));
105void UpdateVCEPass::runOnOperation() {
106 spirv::ModuleOp module = getOperation();
110 module.emitError("missing 'spirv.target_env' attribute");
111 return signalPassFailure();
114 spirv::TargetEnv targetEnv(targetAttr);
115 spirv::Version allowedVersion = targetAttr.
getVersion();
117 spirv::Version deducedVersion = spirv::Version::V_1_0;
123 WalkResult walkResult =
module.walk([&](Operation *op) -> WalkResult {
125 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
126 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
128 deducedVersion = std::max(deducedVersion, *minVersion);
129 if (deducedVersion > allowedVersion) {
130 return op->emitError(
"'")
131 << op->getName() <<
"' requires min version "
132 << spirv::stringifyVersion(deducedVersion)
133 <<
" but target environment allows up to "
134 << spirv::stringifyVersion(allowedVersion);
140 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
142 op, targetEnv, extensions.getExtensions(), deducedExtensions)))
146 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
148 op, targetEnv, capabilities.getCapabilities(),
149 deducedCapabilities)))
152 SmallVector<Type, 4> valueTypes;
153 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
154 valueTypes.append(op->result_type_begin(), op->result_type_end());
159 auto requireLinkage = [&](spirv::LinkageType linkageType) -> LogicalResult {
160 if (
auto caps = spirv::getCapabilities(linkageType)) {
161 SmallVector<ArrayRef<spirv::Capability>, 1> capCandidates = {*caps};
163 op, targetEnv, capCandidates, deducedCapabilities)))
166 if (
auto exts = spirv::getExtensions(linkageType)) {
167 SmallVector<ArrayRef<spirv::Extension>, 1> extCandidates = {*exts};
169 op, targetEnv, extCandidates, deducedExtensions)))
177 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) {
178 valueTypes.push_back(globalVar.getType());
183 if (globalVar.getBinding() || globalVar.getDescriptorSet()) {
184 spirv::Capability shader = spirv::Capability::Shader;
185 SmallVector<ArrayRef<spirv::Capability>, 1> caps = {shader};
187 deducedCapabilities)))
191 if (
auto linkage = globalVar.getLinkageAttributes())
192 if (
failed(requireLinkage(linkage->getLinkageType().getValue())))
196 if (
auto funcOp = dyn_cast<spirv::FuncOp>(op))
197 if (
auto linkage = funcOp.getLinkageAttributes())
198 if (
failed(requireLinkage(linkage->getLinkageType().getValue())))
202 if (
auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
203 llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
204 llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
208 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
209 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
210 for (Type valueType : valueTypes) {
211 typeExtensions.clear();
212 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
214 op, targetEnv, typeExtensions, deducedExtensions)))
217 typeCapabilities.clear();
218 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
220 op, targetEnv, typeCapabilities, deducedCapabilities)))
228 return signalPassFailure();
233 for (spirv::Capability cap : deducedCapabilities) {
234 if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
235 deducedVersion = std::max(deducedVersion, *minVersion);
236 if (deducedVersion > allowedVersion) {
237 module.emitError("Capability '")
238 << spirv::stringifyCapability(cap) << "' requires min version "
239 << spirv::stringifyVersion(deducedVersion)
240 << " but target environment allows up to "
241 << spirv::stringifyVersion(allowedVersion);
242 return signalPassFailure();
251 deducedVersion, deducedCapabilities.getArrayRef(),
252 deducedExtensions.getArrayRef(), &
getContext());
253 module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
static LogicalResult checkAndUpdateExtensionRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, SetVector< spirv::Extension > &deducedExtensions)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv a...
static void addAllImpliedCapabilities(SetVector< spirv::Capability > &caps)
static LogicalResult checkAndUpdateCapabilityRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, SetVector< spirv::Capability > &deducedCapabilities)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv a...
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type follow the ((Capability::A OR Capability::B) AND (Cap...
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Version getVersion() const
Returns the target version.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
bool allows(Capability) const
Returns true if the given capability is allowed.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector