MLIR  14.0.0git
UpdateVCEPass.cpp
Go to the documentation of this file.
1 //===- UpdateVCEPass.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Pass to deduce minimal version/extension/capability requirements for a
30 /// spirv::ModuleOp.
31 class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
32  void runOnOperation() override;
33 };
34 } // namespace
35 
36 /// Checks that `candidates` extension requirements are possible to be satisfied
37 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
38 /// errors attaching to the given `op` on failures.
39 ///
40 /// `candidates` is a vector of vector for extension requirements following
41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42 /// convention.
44  Operation *op, const spirv::TargetEnv &targetEnv,
46  SetVector<spirv::Extension> &deducedExtensions) {
47  for (const auto &ors : candidates) {
48  if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
49  deducedExtensions.insert(*chosen);
50  } else {
51  SmallVector<StringRef, 4> extStrings;
52  for (spirv::Extension ext : ors)
53  extStrings.push_back(spirv::stringifyExtension(ext));
54 
55  return op->emitError("'")
56  << op->getName() << "' requires at least one extension in ["
57  << llvm::join(extStrings, ", ")
58  << "] but none allowed in target environment";
59  }
60  }
61  return success();
62 }
63 
64 /// Checks that `candidates`capability requirements are possible to be satisfied
65 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
66 /// errors attaching to the given `op` on failures.
67 ///
68 /// `candidates` is a vector of vector for capability requirements following
69 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70 /// convention.
72  Operation *op, const spirv::TargetEnv &targetEnv,
74  SetVector<spirv::Capability> &deducedCapabilities) {
75  for (const auto &ors : candidates) {
76  if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
77  deducedCapabilities.insert(*chosen);
78  } else {
79  SmallVector<StringRef, 4> capStrings;
80  for (spirv::Capability cap : ors)
81  capStrings.push_back(spirv::stringifyCapability(cap));
82 
83  return op->emitError("'")
84  << op->getName() << "' requires at least one capability in ["
85  << llvm::join(capStrings, ", ")
86  << "] but none allowed in target environment";
87  }
88  }
89  return success();
90 }
91 
92 void UpdateVCEPass::runOnOperation() {
93  spirv::ModuleOp module = getOperation();
94 
95  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
96  if (!targetAttr) {
97  module.emitError("missing 'spv.target_env' attribute");
98  return signalPassFailure();
99  }
100 
101  spirv::TargetEnv targetEnv(targetAttr);
102  spirv::Version allowedVersion = targetAttr.getVersion();
103 
104  spirv::Version deducedVersion = spirv::Version::V_1_0;
105  SetVector<spirv::Extension> deducedExtensions;
106  SetVector<spirv::Capability> deducedCapabilities;
107 
108  // Walk each SPIR-V op to deduce the minimal version/extension/capability
109  // requirements.
110  WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
111  // Op min version requirements
112  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
113  Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
114  if (minVersion) {
115  deducedVersion = std::max(deducedVersion, *minVersion);
116  if (deducedVersion > allowedVersion) {
117  return op->emitError("'")
118  << op->getName() << "' requires min version "
119  << spirv::stringifyVersion(deducedVersion)
120  << " but target environment allows up to "
121  << spirv::stringifyVersion(allowedVersion);
122  }
123  }
124  }
125 
126  // Op extension requirements
127  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
129  op, targetEnv, extensions.getExtensions(), deducedExtensions)))
130  return WalkResult::interrupt();
131 
132  // Op capability requirements
133  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
135  op, targetEnv, capabilities.getCapabilities(),
136  deducedCapabilities)))
137  return WalkResult::interrupt();
138 
139  SmallVector<Type, 4> valueTypes;
140  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
141  valueTypes.append(op->result_type_begin(), op->result_type_end());
142 
143  // Special treatment for global variables, whose type requirements are
144  // conveyed by type attributes.
145  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
146  valueTypes.push_back(globalVar.type());
147 
148  // Requirements from values' types
149  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
150  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
151  for (Type valueType : valueTypes) {
152  typeExtensions.clear();
153  valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
155  op, targetEnv, typeExtensions, deducedExtensions)))
156  return WalkResult::interrupt();
157 
158  typeCapabilities.clear();
159  valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
161  op, targetEnv, typeCapabilities, deducedCapabilities)))
162  return WalkResult::interrupt();
163  }
164 
165  return WalkResult::advance();
166  });
167 
168  if (walkResult.wasInterrupted())
169  return signalPassFailure();
170 
171  // TODO: verify that the deduced version is consistent with
172  // SPIR-V ops' maximal version requirements.
173 
174  auto triple = spirv::VerCapExtAttr::get(
175  deducedVersion, deducedCapabilities.getArrayRef(),
176  deducedExtensions.getArrayRef(), &getContext());
177  module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
178 }
179 
180 std::unique_ptr<OperationPass<spirv::ModuleOp>>
182  return std::make_unique<UpdateVCEPass>();
183 }
Include the generated interface declarations.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool allows(Capability) const
Returns true if the given capability is allowed.
static LogicalResult checkAndUpdateCapabilityRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, SetVector< spirv::Capability > &deducedCapabilities)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv a...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
result_type_iterator result_type_end()
Definition: Operation.h:296
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
operand_type_iterator operand_type_begin()
Definition: Operation.h:264
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:28
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
operand_type_iterator operand_type_end()
Definition: Operation.h:265
result_type_iterator result_type_begin()
Definition: Operation.h:295
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
static LogicalResult checkAndUpdateExtensionRequirements(Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, SetVector< spirv::Extension > &deducedExtensions)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv a...
Version getVersion() const
Returns the target version.
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUpdateVersionCapabilityExtensionPass()
Creates an operation pass that deduces and attaches the minimal version/ capabilities/extensions requ...
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)