22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
29 #define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
39 class UpdateVCEPass final
40 :
public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
41 void runOnOperation()
override;
56 for (
const auto &ors : candidates) {
57 if (std::optional<spirv::Extension> chosen = targetEnv.
allows(ors)) {
58 deducedExtensions.insert(*chosen);
61 for (spirv::Extension ext : ors)
62 extStrings.push_back(spirv::stringifyExtension(ext));
65 << op->
getName() <<
"' requires at least one extension in ["
66 << llvm::join(extStrings,
", ")
67 <<
"] but none allowed in target environment";
84 for (
const auto &ors : candidates) {
85 if (std::optional<spirv::Capability> chosen = targetEnv.
allows(ors)) {
86 deducedCapabilities.insert(*chosen);
89 for (spirv::Capability cap : ors)
90 capStrings.push_back(spirv::stringifyCapability(cap));
93 << op->
getName() <<
"' requires at least one capability in ["
94 << llvm::join(capStrings,
", ")
95 <<
"] but none allowed in target environment";
101 void UpdateVCEPass::runOnOperation() {
102 spirv::ModuleOp module = getOperation();
106 module.emitError(
"missing 'spirv.target_env' attribute");
107 return signalPassFailure();
111 spirv::Version allowedVersion = targetAttr.
getVersion();
113 spirv::Version deducedVersion = spirv::Version::V_1_0;
121 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
122 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
124 deducedVersion =
std::max(deducedVersion, *minVersion);
125 if (deducedVersion > allowedVersion) {
127 << op->
getName() <<
"' requires min version "
128 << spirv::stringifyVersion(deducedVersion)
129 <<
" but target environment allows up to "
130 << spirv::stringifyVersion(allowedVersion);
136 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
138 op, targetEnv, extensions.getExtensions(), deducedExtensions)))
142 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
144 op, targetEnv, capabilities.getCapabilities(),
145 deducedCapabilities)))
154 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
155 valueTypes.push_back(globalVar.getType());
160 for (
Type valueType : valueTypes) {
161 typeExtensions.clear();
162 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
164 op, targetEnv, typeExtensions, deducedExtensions)))
167 typeCapabilities.clear();
168 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
170 op, targetEnv, typeCapabilities, deducedCapabilities)))
178 return signalPassFailure();
184 deducedVersion, deducedCapabilities.getArrayRef(),
185 deducedExtensions.getArrayRef(), &
getContext());
186 module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult checkAndUpdateExtensionRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, SetVector< spirv::Extension > &deducedExtensions)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv a...
static LogicalResult checkAndUpdateCapabilityRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, SetVector< spirv::Capability > &deducedCapabilities)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv a...
Operation is the basic unit of execution within MLIR.
operand_type_iterator operand_type_end()
result_type_iterator result_type_end()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_iterator operand_type_begin()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
An attribute that specifies the target version, allowed extensions and capabilities,...
Version getVersion() const
Returns the target version.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
bool allows(Capability) const
Returns true if the given capability is allowed.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
Include the generated interface declarations.