MLIR  18.0.0git
TargetAndABI.cpp
Go to the documentation of this file.
1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/SymbolTable.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // TargetEnv
22 //===----------------------------------------------------------------------===//
23 
25  : targetAttr(targetAttr) {
26  for (spirv::Extension ext : targetAttr.getExtensions())
27  givenExtensions.insert(ext);
28 
29  // Add extensions implied by the current version.
30  for (spirv::Extension ext :
32  givenExtensions.insert(ext);
33 
34  for (spirv::Capability cap : targetAttr.getCapabilities()) {
35  givenCapabilities.insert(cap);
36 
37  // Add capabilities implied by the current capability.
38  for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
39  givenCapabilities.insert(c);
40  }
41 }
42 
43 spirv::Version spirv::TargetEnv::getVersion() const {
44  return targetAttr.getVersion();
45 }
46 
47 bool spirv::TargetEnv::allows(spirv::Capability capability) const {
48  return givenCapabilities.count(capability);
49 }
50 
51 std::optional<spirv::Capability>
53  const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
54  return givenCapabilities.count(cap);
55  });
56  if (chosen != caps.end())
57  return *chosen;
58  return std::nullopt;
59 }
60 
61 bool spirv::TargetEnv::allows(spirv::Extension extension) const {
62  return givenExtensions.count(extension);
63 }
64 
65 std::optional<spirv::Extension>
67  const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
68  return givenExtensions.count(ext);
69  });
70  if (chosen != exts.end())
71  return *chosen;
72  return std::nullopt;
73 }
74 
75 spirv::Vendor spirv::TargetEnv::getVendorID() const {
76  return targetAttr.getVendorID();
77 }
78 
79 spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
80  return targetAttr.getDeviceType();
81 }
82 
84  return targetAttr.getDeviceID();
85 }
86 
87 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
88  return targetAttr.getResourceLimits();
89 }
90 
92  return targetAttr.getContext();
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Utility functions
97 //===----------------------------------------------------------------------===//
98 
100  return "spirv.interface_var_abi";
101 }
102 
104 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
105  std::optional<spirv::StorageClass> storageClass,
106  MLIRContext *context) {
107  return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
108  context);
109 }
110 
112  for (spirv::Capability cap : targetAttr.getCapabilities()) {
113  if (cap == spirv::Capability::Kernel)
114  return false;
115  if (cap == spirv::Capability::Shader)
116  return true;
117  }
118  return false;
119 }
120 
121 StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
122 
123 spirv::EntryPointABIAttr
125  ArrayRef<int32_t> workgroupSize,
126  std::optional<int> subgroupSize) {
127  DenseI32ArrayAttr workgroupSizeAttr;
128  if (!workgroupSize.empty()) {
129  assert(workgroupSize.size() == 3);
130  workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
131  }
132  return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr,
133  subgroupSize);
134 }
135 
136 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
137  while (op && !isa<FunctionOpInterface>(op))
138  op = op->getParentOp();
139  if (!op)
140  return {};
141 
142  if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
144  return attr;
145 
146  return {};
147 }
148 
150  if (auto entryPoint = spirv::lookupEntryPointABI(op))
151  return entryPoint.getWorkgroupSize();
152 
153  return {};
154 }
155 
156 spirv::ResourceLimitsAttr
158  // All the fields have default values. Here we just provide a nicer way to
159  // construct a default resource limit attribute.
160  Builder b(context);
162  context,
163  /*max_compute_shared_memory_size=*/16384,
164  /*max_compute_workgroup_invocations=*/128,
165  /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
166  /*subgroup_size=*/32,
167  /*min_subgroup_size=*/std::nullopt,
168  /*max_subgroup_size=*/std::nullopt,
169  /*cooperative_matrix_properties_khr=*/ArrayAttr{},
170  /*cooperative_matrix_properties_nv=*/ArrayAttr{});
171 }
172 
173 StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
174 
176  auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
177  {spirv::Capability::Shader},
178  ArrayRef<Extension>(), context);
180  triple, spirv::getDefaultResourceLimits(context),
181  spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
182  spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
183 }
184 
186  while (op) {
188  if (!op)
189  break;
190 
191  if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
193  return attr;
194 
195  op = op->getParentOp();
196  }
197 
198  return {};
199 }
200 
203  return attr;
204 
205  return getDefaultTargetEnv(op->getContext());
206 }
207 
208 spirv::AddressingModel
210  bool use64bitAddress) {
211  for (spirv::Capability cap : targetAttr.getCapabilities()) {
212  if (cap == Capability::Kernel)
213  return use64bitAddress ? spirv::AddressingModel::Physical64
214  : spirv::AddressingModel::Physical32;
215  // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
216  // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
217  // and PhysicalStorageBuffer64EXT
218  if (cap == Capability::PhysicalStorageBufferAddresses)
219  return spirv::AddressingModel::PhysicalStorageBuffer64;
220  }
221  // Logical addressing doesn't need any capabilities so return it as default.
222  return spirv::AddressingModel::Logical;
223 }
224 
227  for (spirv::Capability cap : targetAttr.getCapabilities()) {
228  if (cap == spirv::Capability::Kernel)
229  return spirv::ExecutionModel::Kernel;
230  if (cap == spirv::Capability::Shader)
231  return spirv::ExecutionModel::GLCompute;
232  }
233  return failure();
234 }
235 
238  for (spirv::Capability cap : targetAttr.getCapabilities()) {
239  if (cap == spirv::Capability::Kernel)
241  if (cap == spirv::Capability::Shader)
242  return spirv::MemoryModel::GLSL450;
243  }
244  return failure();
245 }
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:50
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:283
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:528
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
An attribute that specifies the information regarding the interface variable: descriptor set, binding, and storage class.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
Version getVersion() const
Returns the target version.
VerCapExtAttr::cap_range getCapabilities()
Returns the target capabilities.
VerCapExtAttr::ext_range getExtensions()
Returns the target extensions.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
DeviceType getDeviceType() const
Returns the device type.
Version getVersion() const
Vendor getVendorID() const
Returns the vendor ID.
bool allows(Capability) const
Returns true if the given capability is allowed.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
TargetEnv(TargetEnvAttr targetAttr)
MLIRContext * getContext() const
Returns the MLIRContext.
uint32_t getDeviceID() const
Returns the device ID.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
SmallVector< Capability, 0 > getRecursiveImpliedCapabilities(Capability cap)
Returns the recursively implied capabilities for the given capability.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interface variables (spirv.interface_var_abi).
EntryPointABIAttr getEntryPointABIAttr(MLIRContext *context, ArrayRef< int32_t > workgroupSize={}, std::optional< int > subgroupSize={})
Gets the EntryPointABIAttr given its fields.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op, or returns the default target environment (as returned by getDefaultTargetEnv()) if none is provided.
ArrayRef< Extension > getImpliedExtensions(Version version)
Returns the implied extensions for the given version.
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the given op. Returns a null attribute if not found.
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the Vulkan specification.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.