MLIR  21.0.0git
TargetAndABI.cpp
Go to the documentation of this file.
1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/SymbolTable.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // TargetEnv
22 //===----------------------------------------------------------------------===//
23 
25  : targetAttr(targetAttr) {
26  givenExtensions.insert_range(targetAttr.getExtensions());
27 
28  // Add extensions implied by the current version.
29  givenExtensions.insert_range(
31 
32  for (spirv::Capability cap : targetAttr.getCapabilities()) {
33  givenCapabilities.insert(cap);
34 
35  // Add capabilities implied by the current capability.
36  givenCapabilities.insert_range(spirv::getRecursiveImpliedCapabilities(cap));
37  }
38 }
39 
40 spirv::Version spirv::TargetEnv::getVersion() const {
41  return targetAttr.getVersion();
42 }
43 
44 bool spirv::TargetEnv::allows(spirv::Capability capability) const {
45  return givenCapabilities.count(capability);
46 }
47 
48 std::optional<spirv::Capability>
50  const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
51  return givenCapabilities.count(cap);
52  });
53  if (chosen != caps.end())
54  return *chosen;
55  return std::nullopt;
56 }
57 
58 bool spirv::TargetEnv::allows(spirv::Extension extension) const {
59  return givenExtensions.count(extension);
60 }
61 
62 std::optional<spirv::Extension>
64  const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
65  return givenExtensions.count(ext);
66  });
67  if (chosen != exts.end())
68  return *chosen;
69  return std::nullopt;
70 }
71 
72 spirv::Vendor spirv::TargetEnv::getVendorID() const {
73  return targetAttr.getVendorID();
74 }
75 
76 spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
77  return targetAttr.getDeviceType();
78 }
79 
81  return targetAttr.getDeviceID();
82 }
83 
84 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
85  return targetAttr.getResourceLimits();
86 }
87 
89  return targetAttr.getContext();
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // Utility functions
94 //===----------------------------------------------------------------------===//
95 
97  return "spirv.interface_var_abi";
98 }
99 
101 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
102  std::optional<spirv::StorageClass> storageClass,
103  MLIRContext *context) {
104  return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
105  context);
106 }
107 
109  for (spirv::Capability cap : targetAttr.getCapabilities()) {
110  if (cap == spirv::Capability::Kernel)
111  return false;
112  if (cap == spirv::Capability::Shader)
113  return true;
114  }
115  return false;
116 }
117 
118 StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
119 
120 spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
121  MLIRContext *context, ArrayRef<int32_t> workgroupSize,
122  std::optional<int> subgroupSize, std::optional<int> targetWidth) {
123  DenseI32ArrayAttr workgroupSizeAttr;
124  if (!workgroupSize.empty()) {
125  assert(workgroupSize.size() == 3);
126  workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
127  }
128  return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
129  targetWidth);
130 }
131 
132 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
133  while (op && !isa<FunctionOpInterface>(op))
134  op = op->getParentOp();
135  if (!op)
136  return {};
137 
138  if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
140  return attr;
141 
142  return {};
143 }
144 
146  if (auto entryPoint = spirv::lookupEntryPointABI(op))
147  return entryPoint.getWorkgroupSize();
148 
149  return {};
150 }
151 
152 spirv::ResourceLimitsAttr
154  // All the fields have default values. Here we just provide a nicer way to
155  // construct a default resource limit attribute.
156  Builder b(context);
158  context,
159  /*max_compute_shared_memory_size=*/16384,
160  /*max_compute_workgroup_invocations=*/128,
161  /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
162  /*subgroup_size=*/32,
163  /*min_subgroup_size=*/std::nullopt,
164  /*max_subgroup_size=*/std::nullopt,
165  /*cooperative_matrix_properties_khr=*/ArrayAttr{},
166  /*cooperative_matrix_properties_nv=*/ArrayAttr{});
167 }
168 
169 StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
170 
172  auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
173  {spirv::Capability::Shader},
174  ArrayRef<Extension>(), context);
176  triple, spirv::getDefaultResourceLimits(context),
177  spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
178  spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
179 }
180 
182  while (op) {
184  if (!op)
185  break;
186 
187  if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
189  return attr;
190 
191  op = op->getParentOp();
192  }
193 
194  return {};
195 }
196 
199  return attr;
200 
201  return getDefaultTargetEnv(op->getContext());
202 }
203 
204 spirv::AddressingModel
206  bool use64bitAddress) {
207  for (spirv::Capability cap : targetAttr.getCapabilities()) {
208  if (cap == Capability::Kernel)
209  return use64bitAddress ? spirv::AddressingModel::Physical64
210  : spirv::AddressingModel::Physical32;
211  // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
212  // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
213  // and PhysicalStorageBuffer64EXT
214  if (cap == Capability::PhysicalStorageBufferAddresses)
215  return spirv::AddressingModel::PhysicalStorageBuffer64;
216  }
217  // Logical addressing doesn't need any capabilities so return it as default.
218  return spirv::AddressingModel::Logical;
219 }
220 
221 FailureOr<spirv::ExecutionModel>
223  for (spirv::Capability cap : targetAttr.getCapabilities()) {
224  if (cap == spirv::Capability::Kernel)
225  return spirv::ExecutionModel::Kernel;
226  if (cap == spirv::Capability::Shader)
227  return spirv::ExecutionModel::GLCompute;
228  }
229  return failure();
230 }
231 
232 FailureOr<spirv::MemoryModel>
234  for (spirv::Capability cap : targetAttr.getCapabilities()) {
235  if (cap == spirv::Capability::Kernel)
237  if (cap == spirv::Capability::Shader)
238  return spirv::MemoryModel::GLSL450;
239  }
240  return failure();
241 }
constexpr unsigned subgroupSize
HW dependent constants.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:272
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
An attribute that specifies the information regarding the interface variable: descriptor set,...
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities,...
Version getVersion() const
Returns the target version.
VerCapExtAttr::cap_range getCapabilities()
Returns the target capabilities.
VerCapExtAttr::ext_range getExtensions()
Returns the target extensions.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
DeviceType getDeviceType() const
Returns the device type.
Version getVersion() const
Vendor getVendorID() const
Returns the vendor ID.
bool allows(Capability) const
Returns true if the given capability is allowed.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
TargetEnv(TargetEnvAttr targetAttr)
MLIRContext * getContext() const
Returns the MLIRContext.
uint32_t getDeviceID() const
Returns the device ID.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
SmallVector< Capability, 0 > getRecursiveImpliedCapabilities(Capability cap)
Returns the recursively implied capabilities for the given capability.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interface variables.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
ArrayRef< Extension > getImpliedExtensions(Version version)
Returns the implied extensions for the given version.
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
EntryPointABIAttr getEntryPointABIAttr(MLIRContext *context, ArrayRef< int32_t > workgroupSize={}, std::optional< int > subgroupSize={}, std::optional< int > targetWidth={})
Gets the EntryPointABIAttr given its fields.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...