MLIR  19.0.0git
UpdateVCEPass.cpp
Go to the documentation of this file.
1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include <optional>
26 
27 namespace mlir {
28 namespace spirv {
29 #define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31 } // namespace spirv
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 /// Pass to deduce minimal version/extension/capability requirements for a
38 /// spirv::ModuleOp.
39 class UpdateVCEPass final
40  : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
41  void runOnOperation() override;
42 };
43 } // namespace
44 
45 /// Checks that `candidates` extension requirements are possible to be satisfied
46 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
47 /// errors attaching to the given `op` on failures.
48 ///
49 /// `candidates` is a vector of vector for extension requirements following
50 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
51 /// convention.
53  Operation *op, const spirv::TargetEnv &targetEnv,
55  SetVector<spirv::Extension> &deducedExtensions) {
56  for (const auto &ors : candidates) {
57  if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
58  deducedExtensions.insert(*chosen);
59  } else {
60  SmallVector<StringRef, 4> extStrings;
61  for (spirv::Extension ext : ors)
62  extStrings.push_back(spirv::stringifyExtension(ext));
63 
64  return op->emitError("'")
65  << op->getName() << "' requires at least one extension in ["
66  << llvm::join(extStrings, ", ")
67  << "] but none allowed in target environment";
68  }
69  }
70  return success();
71 }
72 
73 /// Checks that `candidates`capability requirements are possible to be satisfied
74 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
75 /// errors attaching to the given `op` on failures.
76 ///
77 /// `candidates` is a vector of vector for capability requirements following
78 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
79 /// convention.
81  Operation *op, const spirv::TargetEnv &targetEnv,
83  SetVector<spirv::Capability> &deducedCapabilities) {
84  for (const auto &ors : candidates) {
85  if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
86  deducedCapabilities.insert(*chosen);
87  } else {
88  SmallVector<StringRef, 4> capStrings;
89  for (spirv::Capability cap : ors)
90  capStrings.push_back(spirv::stringifyCapability(cap));
91 
92  return op->emitError("'")
93  << op->getName() << "' requires at least one capability in ["
94  << llvm::join(capStrings, ", ")
95  << "] but none allowed in target environment";
96  }
97  }
98  return success();
99 }
100 
101 void UpdateVCEPass::runOnOperation() {
102  spirv::ModuleOp module = getOperation();
103 
104  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
105  if (!targetAttr) {
106  module.emitError("missing 'spirv.target_env' attribute");
107  return signalPassFailure();
108  }
109 
110  spirv::TargetEnv targetEnv(targetAttr);
111  spirv::Version allowedVersion = targetAttr.getVersion();
112 
113  spirv::Version deducedVersion = spirv::Version::V_1_0;
114  SetVector<spirv::Extension> deducedExtensions;
115  SetVector<spirv::Capability> deducedCapabilities;
116 
117  // Walk each SPIR-V op to deduce the minimal version/extension/capability
118  // requirements.
119  WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
120  // Op min version requirements
121  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
122  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
123  if (minVersion) {
124  deducedVersion = std::max(deducedVersion, *minVersion);
125  if (deducedVersion > allowedVersion) {
126  return op->emitError("'")
127  << op->getName() << "' requires min version "
128  << spirv::stringifyVersion(deducedVersion)
129  << " but target environment allows up to "
130  << spirv::stringifyVersion(allowedVersion);
131  }
132  }
133  }
134 
135  // Op extension requirements
136  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
138  op, targetEnv, extensions.getExtensions(), deducedExtensions)))
139  return WalkResult::interrupt();
140 
141  // Op capability requirements
142  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
144  op, targetEnv, capabilities.getCapabilities(),
145  deducedCapabilities)))
146  return WalkResult::interrupt();
147 
148  SmallVector<Type, 4> valueTypes;
149  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
150  valueTypes.append(op->result_type_begin(), op->result_type_end());
151 
152  // Special treatment for global variables, whose type requirements are
153  // conveyed by type attributes.
154  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
155  valueTypes.push_back(globalVar.getType());
156 
157  // Requirements from values' types
158  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
159  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
160  for (Type valueType : valueTypes) {
161  typeExtensions.clear();
162  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
164  op, targetEnv, typeExtensions, deducedExtensions)))
165  return WalkResult::interrupt();
166 
167  typeCapabilities.clear();
168  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
170  op, targetEnv, typeCapabilities, deducedCapabilities)))
171  return WalkResult::interrupt();
172  }
173 
174  return WalkResult::advance();
175  });
176 
177  if (walkResult.wasInterrupted())
178  return signalPassFailure();
179 
180  // TODO: verify that the deduced version is consistent with
181  // SPIR-V ops' maximal version requirements.
182 
183  auto triple = spirv::VerCapExtAttr::get(
184  deducedVersion, deducedCapabilities.getArrayRef(),
185  deducedExtensions.getArrayRef(), &getContext());
186  module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
187 }
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult checkAndUpdateExtensionRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, SetVector< spirv::Extension > &deducedExtensions)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv a...
static LogicalResult checkAndUpdateCapabilityRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, SetVector< spirv::Capability > &deducedCapabilities)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv a...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_iterator operand_type_end()
Definition: Operation.h:391
result_type_iterator result_type_end()
Definition: Operation.h:422
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
result_type_iterator result_type_begin()
Definition: Operation.h:421
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_iterator operand_type_begin()
Definition: Operation.h:390
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
An attribute that specifies the target version, allowed extensions and capabilities,...
Version getVersion() const
Returns the target version.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
bool allows(Capability) const
Returns true if the given capability is allowed.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26