MLIR  22.0.0git
UpdateVCEPass.cpp
Go to the documentation of this file.
//===- UpdateVCEPass.cpp - Deduce SPIR-V VCE requirements -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to deduce minimal version/extension/capability
// requirements for a spirv::ModuleOp.
//
//===----------------------------------------------------------------------===//
13 
15 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Visitors.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include <optional>
23 
24 namespace mlir {
25 namespace spirv {
26 #define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
27 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
28 } // namespace spirv
29 } // namespace mlir
30 
31 using namespace mlir;
32 
33 namespace {
34 /// Pass to deduce minimal version/extension/capability requirements for a
35 /// spirv::ModuleOp.
36 class UpdateVCEPass final
37  : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
38  void runOnOperation() override;
39 };
40 } // namespace
41 
42 /// Checks that `candidates` extension requirements are possible to be satisfied
43 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
44 /// errors attaching to the given `op` on failures.
45 ///
46 /// `candidates` is a vector of vector for extension requirements following
47 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
48 /// convention.
50  Operation *op, const spirv::TargetEnv &targetEnv,
52  SetVector<spirv::Extension> &deducedExtensions) {
53  for (const auto &ors : candidates) {
54  if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
55  deducedExtensions.insert(*chosen);
56  } else {
57  SmallVector<StringRef, 4> extStrings;
58  for (spirv::Extension ext : ors)
59  extStrings.push_back(spirv::stringifyExtension(ext));
60 
61  return op->emitError("'")
62  << op->getName() << "' requires at least one extension in ["
63  << llvm::join(extStrings, ", ")
64  << "] but none allowed in target environment";
65  }
66  }
67  return success();
68 }
69 
70 /// Checks that `candidates`capability requirements are possible to be satisfied
71 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
72 /// errors attaching to the given `op` on failures.
73 ///
74 /// `candidates` is a vector of vector for capability requirements following
75 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
76 /// convention.
78  Operation *op, const spirv::TargetEnv &targetEnv,
80  SetVector<spirv::Capability> &deducedCapabilities) {
81  for (const auto &ors : candidates) {
82  if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
83  deducedCapabilities.insert(*chosen);
84  } else {
85  SmallVector<StringRef, 4> capStrings;
86  for (spirv::Capability cap : ors)
87  capStrings.push_back(spirv::stringifyCapability(cap));
88 
89  return op->emitError("'")
90  << op->getName() << "' requires at least one capability in ["
91  << llvm::join(capStrings, ", ")
92  << "] but none allowed in target environment";
93  }
94  }
95  return success();
96 }
97 
100  for (spirv::Capability cap : caps)
101  tmp.insert_range(getRecursiveImpliedCapabilities(cap));
102  caps.insert_range(std::move(tmp));
103 }
104 
105 void UpdateVCEPass::runOnOperation() {
106  spirv::ModuleOp module = getOperation();
107 
108  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
109  if (!targetAttr) {
110  module.emitError("missing 'spirv.target_env' attribute");
111  return signalPassFailure();
112  }
113 
114  spirv::TargetEnv targetEnv(targetAttr);
115  spirv::Version allowedVersion = targetAttr.getVersion();
116 
117  spirv::Version deducedVersion = spirv::Version::V_1_0;
118  SetVector<spirv::Extension> deducedExtensions;
119  SetVector<spirv::Capability> deducedCapabilities;
120 
121  // Walk each SPIR-V op to deduce the minimal version/extension/capability
122  // requirements.
123  WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
124  // Op min version requirements
125  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
126  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
127  if (minVersion) {
128  deducedVersion = std::max(deducedVersion, *minVersion);
129  if (deducedVersion > allowedVersion) {
130  return op->emitError("'")
131  << op->getName() << "' requires min version "
132  << spirv::stringifyVersion(deducedVersion)
133  << " but target environment allows up to "
134  << spirv::stringifyVersion(allowedVersion);
135  }
136  }
137  }
138 
139  // Op extension requirements
140  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
142  op, targetEnv, extensions.getExtensions(), deducedExtensions)))
143  return WalkResult::interrupt();
144 
145  // Op capability requirements
146  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
148  op, targetEnv, capabilities.getCapabilities(),
149  deducedCapabilities)))
150  return WalkResult::interrupt();
151 
152  SmallVector<Type, 4> valueTypes;
153  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
154  valueTypes.append(op->result_type_begin(), op->result_type_end());
155 
156  // Special treatment for global variables, whose type requirements are
157  // conveyed by type attributes.
158  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
159  valueTypes.push_back(globalVar.getType());
160 
161  // If the op is FunctionLike make sure to process input and result types.
162  if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
163  llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
164  llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
165  }
166 
167  // Requirements from values' types
168  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
169  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
170  for (Type valueType : valueTypes) {
171  typeExtensions.clear();
172  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
174  op, targetEnv, typeExtensions, deducedExtensions)))
175  return WalkResult::interrupt();
176 
177  typeCapabilities.clear();
178  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
180  op, targetEnv, typeCapabilities, deducedCapabilities)))
181  return WalkResult::interrupt();
182  }
183 
184  return WalkResult::advance();
185  });
186 
187  if (walkResult.wasInterrupted())
188  return signalPassFailure();
189 
190  addAllImpliedCapabilities(deducedCapabilities);
191 
192  // Update min version requirement for capabilities after deducing them.
193  for (spirv::Capability cap : deducedCapabilities) {
194  if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
195  deducedVersion = std::max(deducedVersion, *minVersion);
196  if (deducedVersion > allowedVersion) {
197  module.emitError("Capability '")
198  << spirv::stringifyCapability(cap) << "' requires min version "
199  << spirv::stringifyVersion(deducedVersion)
200  << " but target environment allows up to "
201  << spirv::stringifyVersion(allowedVersion);
202  return signalPassFailure();
203  }
204  }
205  }
206 
207  // TODO: verify that the deduced version is consistent with
208  // SPIR-V ops' maximal version requirements.
209 
210  auto triple = spirv::VerCapExtAttr::get(
211  deducedVersion, deducedCapabilities.getArrayRef(),
212  deducedExtensions.getArrayRef(), &getContext());
213  module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
214 }
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult checkAndUpdateExtensionRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, SetVector< spirv::Extension > &deducedExtensions)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv a...
static void addAllImpliedCapabilities(SetVector< spirv::Capability > &caps)
static LogicalResult checkAndUpdateCapabilityRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, SetVector< spirv::Capability > &deducedCapabilities)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv a...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_iterator operand_type_end()
Definition: Operation.h:396
result_type_iterator result_type_end()
Definition: Operation.h:427
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
result_type_iterator result_type_begin()
Definition: Operation.h:426
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_iterator operand_type_begin()
Definition: Operation.h:395
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
An attribute that specifies the target version, allowed extensions and capabilities,...
Version getVersion() const
Returns the target version.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
bool allows(Capability) const
Returns true if the given capability is allowed.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SmallVector< Capability, 0 > getRecursiveImpliedCapabilities(Capability cap)
Returns the recursively implied capabilities for the given capability.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
Include the generated interface declarations.