MLIR  14.0.0git
TargetAndABI.cpp
Go to the documentation of this file.
1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/SymbolTable.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // TargetEnv
20 //===----------------------------------------------------------------------===//
21 
23  : targetAttr(targetAttr) {
24  for (spirv::Extension ext : targetAttr.getExtensions())
25  givenExtensions.insert(ext);
26 
27  // Add extensions implied by the current version.
28  for (spirv::Extension ext :
30  givenExtensions.insert(ext);
31 
32  for (spirv::Capability cap : targetAttr.getCapabilities()) {
33  givenCapabilities.insert(cap);
34 
35  // Add capabilities implied by the current capability.
36  for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
37  givenCapabilities.insert(c);
38  }
39 }
40 
41 spirv::Version spirv::TargetEnv::getVersion() const {
42  return targetAttr.getVersion();
43 }
44 
45 bool spirv::TargetEnv::allows(spirv::Capability capability) const {
46  return givenCapabilities.count(capability);
47 }
48 
51  const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
52  return givenCapabilities.count(cap);
53  });
54  if (chosen != caps.end())
55  return *chosen;
56  return llvm::None;
57 }
58 
59 bool spirv::TargetEnv::allows(spirv::Extension extension) const {
60  return givenExtensions.count(extension);
61 }
62 
65  const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
66  return givenExtensions.count(ext);
67  });
68  if (chosen != exts.end())
69  return *chosen;
70  return llvm::None;
71 }
72 
73 spirv::Vendor spirv::TargetEnv::getVendorID() const {
74  return targetAttr.getVendorID();
75 }
76 
77 spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
78  return targetAttr.getDeviceType();
79 }
80 
82  return targetAttr.getDeviceID();
83 }
84 
85 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
86  return targetAttr.getResourceLimits();
87 }
88 
90  return targetAttr.getContext();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Utility functions
95 //===----------------------------------------------------------------------===//
96 
98  return "spv.interface_var_abi";
99 }
100 
102 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
103  Optional<spirv::StorageClass> storageClass,
104  MLIRContext *context) {
105  return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
106  context);
107 }
108 
110  for (spirv::Capability cap : targetAttr.getCapabilities()) {
111  if (cap == spirv::Capability::Kernel)
112  return false;
113  if (cap == spirv::Capability::Shader)
114  return true;
115  }
116  return false;
117 }
118 
119 StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }
120 
121 spirv::EntryPointABIAttr
123  assert(localSize.size() == 3);
124  return spirv::EntryPointABIAttr::get(
125  DenseElementsAttr::get<int32_t>(
126  VectorType::get(3, IntegerType::get(context, 32)), localSize)
127  .cast<DenseIntElementsAttr>(),
128  context);
129 }
130 
131 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
132  while (op && !isa<FunctionOpInterface>(op))
133  op = op->getParentOp();
134  if (!op)
135  return {};
136 
137  if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
139  return attr;
140 
141  return {};
142 }
143 
145  if (auto entryPoint = spirv::lookupEntryPointABI(op))
146  return entryPoint.local_size();
147 
148  return {};
149 }
150 
151 spirv::ResourceLimitsAttr
153  // All the fields have default values. Here we just provide a nicer way to
154  // construct a default resource limit attribute.
155  return spirv::ResourceLimitsAttr ::get(
156  /*max_compute_shared_memory_size=*/nullptr,
157  /*max_compute_workgroup_invocations=*/nullptr,
158  /*max_compute_workgroup_size=*/nullptr,
159  /*subgroup_size=*/nullptr,
160  /*cooperative_matrix_properties_nv=*/nullptr, context);
161 }
162 
163 StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }
164 
166  auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
167  {spirv::Capability::Shader},
168  ArrayRef<Extension>(), context);
169  return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
170  spirv::DeviceType::Unknown,
173 }
174 
176  while (op) {
178  if (!op)
179  break;
180 
181  if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
183  return attr;
184 
185  op = op->getParentOp();
186  }
187 
188  return {};
189 }
190 
193  return attr;
194 
195  return getDefaultTargetEnv(op->getContext());
196 }
197 
198 spirv::AddressingModel
200  for (spirv::Capability cap : targetAttr.getCapabilities()) {
201  // TODO: Physical64 is hard-coded here, but some information should come
202  // from TargetEnvAttr to selected between Physical32 and Physical64.
203  if (cap == Capability::Kernel)
204  return spirv::AddressingModel::Physical64;
205  }
206  // Logical addressing doesn't need any capabilities so return it as default.
207  return spirv::AddressingModel::Logical;
208 }
209 
212  for (spirv::Capability cap : targetAttr.getCapabilities()) {
213  if (cap == spirv::Capability::Kernel)
214  return spirv::ExecutionModel::Kernel;
215  if (cap == spirv::Capability::Shader)
216  return spirv::ExecutionModel::GLCompute;
217  }
218  return failure();
219 }
220 
223  for (spirv::Capability cap : targetAttr.getCapabilities()) {
224  if (cap == spirv::Capability::Addresses)
226  if (cap == spirv::Capability::Shader)
227  return spirv::MemoryModel::GLSL450;
228  }
229  return failure();
230 }
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op, or returns the default target environment (see getDefaultTargetEnv) if none is found.
Include the generated interface declarations.
DenseIntElementsAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from the entry point ABI on the nearest function-like op containing the given op; returns a null attribute if no entry point ABI is found.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op; returns a null attribute if none is found.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
An attribute that specifies the information regarding the interface variable: descriptor set, binding, and (optional) storage class.
static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID, DeviceType deviceType, uint32_t deviceId, DictionaryAttr limits)
Gets a TargetEnvAttr instance.
TargetEnv(TargetEnvAttr targetAttr)
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:327
DeviceType getDeviceType() const
Returns the device type.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interface variables.
EntryPointABIAttr getEntryPointABIAttr(ArrayRef< int32_t > localSize, MLIRContext *context)
Gets the EntryPointABIAttr given its fields.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
bool allows(Capability) const
Returns true if the given capability is allowed.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:99
DeviceType getDeviceType() const
Returns the device type.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< Capability, 0 > getRecursiveImpliedCapabilities(Capability cap)
Returns the recursively implied capabilities for the given capability.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:77
MLIRContext * getContext() const
Returns the MLIRContext.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:117
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, Optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Vendor getVendorID() const
Returns the vendor ID.
AddressingModel getAddressingModel(TargetEnvAttr targetAttr)
Returns addressing model selected based on target environment.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
uint32_t getDeviceID() const
Returns the device ID.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
VerCapExtAttr::cap_range getCapabilities()
Returns the target capabilities.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with the Shader capability, no extra extensions, and unknown vendor and device type.
uint32_t getDeviceID() const
Returns the device ID.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, Optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
ArrayRef< Extension > getImpliedExtensions(Version version)
Returns the implied extensions for the given version.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Version getVersion() const
Returns the target version.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
Version getVersion() const
Vendor getVendorID() const
Returns the vendor ID.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
An attribute that represents a reference to a dense integer vector or tensor object.
VerCapExtAttr::ext_range getExtensions()
Returns the target extensions.