MLIR 23.0.0git
TargetAndABI.cpp
Go to the documentation of this file.
1//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/SymbolTable.h"
15#include <optional>
16
17using namespace mlir;
18
19//===----------------------------------------------------------------------===//
20// TargetEnv
21//===----------------------------------------------------------------------===//
22
24 : targetAttr(targetAttr) {
25 givenExtensions.insert_range(targetAttr.getExtensions());
26
27 // Add extensions implied by the current version.
28 givenExtensions.insert_range(
29 spirv::getImpliedExtensions(targetAttr.getVersion()));
30
31 for (spirv::Capability cap : targetAttr.getCapabilities()) {
32 givenCapabilities.insert(cap);
33
34 // Add capabilities implied by the current capability.
35 givenCapabilities.insert_range(spirv::getRecursiveImpliedCapabilities(cap));
36 }
37}
38
39spirv::Version spirv::TargetEnv::getVersion() const {
40 return targetAttr.getVersion();
41}
42
43bool spirv::TargetEnv::allows(spirv::Capability capability) const {
44 return givenCapabilities.count(capability);
45}
46
47std::optional<spirv::Capability>
49 const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
50 return givenCapabilities.count(cap);
51 });
52 if (chosen != caps.end())
53 return *chosen;
54 return std::nullopt;
55}
56
57bool spirv::TargetEnv::allows(spirv::Extension extension) const {
58 return givenExtensions.count(extension);
59}
60
61std::optional<spirv::Extension>
62spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
63 const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
64 return givenExtensions.count(ext);
65 });
66 if (chosen != exts.end())
67 return *chosen;
68 return std::nullopt;
69}
70
71spirv::Vendor spirv::TargetEnv::getVendorID() const {
72 return targetAttr.getVendorID();
73}
74
75spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
76 return targetAttr.getDeviceType();
77}
78
80 return targetAttr.getDeviceID();
81}
82
83spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
84 return targetAttr.getResourceLimits();
85}
86
88 return targetAttr.getContext();
89}
90
91//===----------------------------------------------------------------------===//
92// Utility functions
93//===----------------------------------------------------------------------===//
94
96 return "spirv.interface_var_abi";
97}
98
100spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
101 std::optional<spirv::StorageClass> storageClass,
102 MLIRContext *context) {
103 return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
104 context);
105}
106
108 for (spirv::Capability cap : targetAttr.getCapabilities()) {
109 if (cap == spirv::Capability::Kernel)
110 return false;
111 if (cap == spirv::Capability::Shader)
112 return true;
113 }
114 return false;
115}
116
117StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
118
119spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
120 MLIRContext *context, ArrayRef<int32_t> workgroupSize,
121 std::optional<int> subgroupSize, std::optional<int> targetWidth) {
122 DenseI32ArrayAttr workgroupSizeAttr;
123 if (!workgroupSize.empty()) {
124 assert(workgroupSize.size() == 3);
125 workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
126 }
127 return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
128 targetWidth);
129}
130
131spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
132 while (op && !isa<FunctionOpInterface>(op))
133 op = op->getParentOp();
134 if (!op)
135 return {};
136
137 if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
139 return attr;
140
141 return {};
142}
143
145 if (auto entryPoint = spirv::lookupEntryPointABI(op))
146 return entryPoint.getWorkgroupSize();
147
148 return {};
149}
150
151spirv::ResourceLimitsAttr
153 // All the fields have default values. Here we just provide a nicer way to
154 // construct a default resource limit attribute.
155 Builder b(context);
156 return spirv::ResourceLimitsAttr::get(
157 context,
158 /*max_compute_shared_memory_size=*/16384,
159 /*max_compute_workgroup_invocations=*/128,
160 /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
161 /*subgroup_size=*/32,
162 /*min_subgroup_size=*/std::nullopt,
163 /*max_subgroup_size=*/std::nullopt,
164 /*cooperative_matrix_properties_khr=*/ArrayAttr{},
165 /*cooperative_matrix_properties_nv=*/ArrayAttr{});
166}
167
168StringRef spirv::getLoopControlAttrName() { return "spirv.loop_control"; }
169
171 return "spirv.selection_control";
172}
173
174StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
175
177 auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
178 {spirv::Capability::Shader},
179 ArrayRef<Extension>(), context);
181 triple, spirv::getDefaultResourceLimits(context),
182 spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
183 spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
184}
185
187 while (op) {
189 if (!op)
190 break;
191
192 if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
194 return attr;
195
196 op = op->getParentOp();
197 }
198
199 return {};
200}
201
208
209spirv::AddressingModel
211 bool use64bitAddress) {
212 for (spirv::Capability cap : targetAttr.getCapabilities()) {
213 if (cap == Capability::Kernel)
214 return use64bitAddress ? spirv::AddressingModel::Physical64
215 : spirv::AddressingModel::Physical32;
216 // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
217 // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
218 // and PhysicalStorageBuffer64EXT
219 if (cap == Capability::PhysicalStorageBufferAddresses)
220 return spirv::AddressingModel::PhysicalStorageBuffer64;
221 }
222 // Logical addressing doesn't need any capabilities so return it as default.
223 return spirv::AddressingModel::Logical;
224}
225
226FailureOr<spirv::ExecutionModel>
228 for (spirv::Capability cap : targetAttr.getCapabilities()) {
229 if (cap == spirv::Capability::Kernel)
230 return spirv::ExecutionModel::Kernel;
231 if (cap == spirv::Capability::Shader)
232 return spirv::ExecutionModel::GLCompute;
233 }
234 return failure();
235}
236
237FailureOr<spirv::MemoryModel>
239 for (spirv::Capability cap : targetAttr.getCapabilities()) {
240 if (cap == spirv::Capability::Kernel)
241 return spirv::MemoryModel::OpenCL;
242 if (cap == spirv::Capability::Shader)
243 return spirv::MemoryModel::GLSL450;
244 }
245 return failure();
246}
for(Operation *op :ops)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition Operation.h:252
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
An attribute that specifies the information regarding the interface variable: descriptor set, binding, and storage class.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities, resource limits, and device information.
VerCapExtAttr::cap_range getCapabilities()
Returns the target capabilities.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
DeviceType getDeviceType() const
Returns the device type.
Version getVersion() const
Vendor getVendorID() const
Returns the vendor ID.
bool allows(Capability) const
Returns true if the given capability is allowed.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
TargetEnv(TargetEnvAttr targetAttr)
MLIRContext * getContext() const
Returns the MLIRContext.
uint32_t getDeviceID() const
Returns the device ID.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
ArrayRef< Extension > getImpliedExtensions(Version version)
Returns the implied extensions for the given version.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interface variables.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op, or returns the default target environment if none is found.
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getLoopControlAttrName()
Returns the attribute name for specifying loop control.
EntryPointABIAttr getEntryPointABIAttr(MLIRContext *context, ArrayRef< int32_t > workgroupSize={}, std::optional< int > subgroupSize={}, std::optional< int > targetWidth={})
Gets the EntryPointABIAttr given its fields.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from the entry point ABI on the nearest function-like op containing the given op.
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
StringRef getSelectionControlAttrName()
Returns the attribute name for specifying selection control.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr