MLIR 22.0.0git
MapMemRefStorageClassPass.cpp
Go to the documentation of this file.
//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to map numeric MemRef memory spaces to
10// symbolic ones defined in the SPIR-V specification.
11//
12//===----------------------------------------------------------------------===//
13
15
21#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/Visitors.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/StringExtras.h"
29#include "llvm/Support/Debug.h"
30#include <optional>
31
32namespace mlir {
33#define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34#include "mlir/Conversion/Passes.h.inc"
35} // namespace mlir
36
37#define DEBUG_TYPE "mlir-map-memref-storage-class"
38
39using namespace mlir;
40
41//===----------------------------------------------------------------------===//
42// Mappings
43//===----------------------------------------------------------------------===//
44
/// Mapping between SPIR-V storage classes and memref memory spaces for
/// Vulkan-flavored SPIR-V.
///
/// Each entry is `MAP_FN(storageClass, memorySpace)`. Every storage class and
/// every memory space appears at most once, so the list can be expanded into a
/// switch in either direction. Memory space 2 has no entry.
///
/// Note: memref does not have a defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number.
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)                       \
  MAP_FN(spirv::StorageClass::Image, 12)
64
65std::optional<spirv::StorageClass>
67 // Handle null memory space attribute specially.
68 if (!memorySpaceAttr)
69 return spirv::StorageClass::StorageBuffer;
70
71 // Unknown dialect custom attributes are not supported by default.
72 // Downstream callers should plug in more specialized ones.
73 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
74 if (!intAttr)
75 return std::nullopt;
76 unsigned memorySpace = intAttr.getInt();
77
78#define STORAGE_SPACE_MAP_FN(storage, space) \
79 case space: \
80 return storage;
81
82 switch (memorySpace) {
84 default:
85 break;
86 }
87 return std::nullopt;
88
89#undef STORAGE_SPACE_MAP_FN
90}
91
92std::optional<unsigned>
93spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
94#define STORAGE_SPACE_MAP_FN(storage, space) \
95 case storage: \
96 return space;
97
98 switch (storageClass) {
100 default:
101 break;
102 }
103 return std::nullopt;
104
105#undef STORAGE_SPACE_MAP_FN
106}
107
108#undef VULKAN_STORAGE_SPACE_MAP_LIST
109
/// Mapping between SPIR-V storage classes and memref memory spaces for
/// OpenCL-flavored SPIR-V.
///
/// Each entry is `MAP_FN(storageClass, memorySpace)`. Every storage class and
/// every memory space appears at most once, so the list can be expanded into a
/// switch in either direction. Memory space 2 has no entry.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 0)                               \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::UniformConstant, 4)                              \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::Image, 7)
118
119std::optional<spirv::StorageClass>
121 // Handle null memory space attribute specially.
122 if (!memorySpaceAttr)
123 return spirv::StorageClass::CrossWorkgroup;
124
125 // Unknown dialect custom attributes are not supported by default.
126 // Downstream callers should plug in more specialized ones.
127 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
128 if (!intAttr)
129 return std::nullopt;
130 unsigned memorySpace = intAttr.getInt();
131
132#define STORAGE_SPACE_MAP_FN(storage, space) \
133 case space: \
134 return storage;
135
136 switch (memorySpace) {
138 default:
139 break;
140 }
141 return std::nullopt;
142
143#undef STORAGE_SPACE_MAP_FN
144}
145
146std::optional<unsigned>
147spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
148#define STORAGE_SPACE_MAP_FN(storage, space) \
149 case storage: \
150 return space;
151
152 switch (storageClass) {
154 default:
155 break;
156 }
157 return std::nullopt;
158
159#undef STORAGE_SPACE_MAP_FN
160}
161
162#undef OPENCL_STORAGE_SPACE_MAP_LIST
163
164//===----------------------------------------------------------------------===//
165// Type Converter
166//===----------------------------------------------------------------------===//
167
169 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
170 : memorySpaceMap(memorySpaceMap) {
171 // Pass through for all other types.
172 addConversion([](Type type) { return type; });
173
174 addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
175 std::optional<spirv::StorageClass> storage =
176 this->memorySpaceMap(memRefType.getMemorySpace());
177 if (!storage) {
178 LLVM_DEBUG(llvm::dbgs()
179 << "cannot convert " << memRefType
180 << " due to being unable to find memory space in map\n");
181 return std::nullopt;
182 }
183
184 auto storageAttr =
185 spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
186 if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
187 return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
188 rankedType.getLayout(), storageAttr);
189 }
190 return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
191 });
192
193 addConversion([this](FunctionType type) {
194 auto inputs = llvm::map_to_vector(
195 type.getInputs(), [this](Type ty) { return convertType(ty); });
196 auto results = llvm::map_to_vector(
197 type.getResults(), [this](Type ty) { return convertType(ty); });
198 return FunctionType::get(type.getContext(), inputs, results);
199 });
200}
201
202//===----------------------------------------------------------------------===//
203// Conversion Target
204//===----------------------------------------------------------------------===//
205
206/// Returns true if the given `type` is considered as legal for SPIR-V
207/// conversion.
208static bool isLegalType(Type type) {
209 if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
210 Attribute spaceAttr = memRefType.getMemorySpace();
211 return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
212 }
213 return true;
214}
215
216/// Returns true if the given `attr` is considered as legal for SPIR-V
217/// conversion.
218static bool isLegalAttr(Attribute attr) {
219 if (auto typeAttr = dyn_cast<TypeAttr>(attr))
220 return isLegalType(typeAttr.getValue());
221 return true;
222}
223
224/// Returns true if the given `op` is considered as legal for SPIR-V conversion.
225static bool isLegalOp(Operation *op) {
226 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
227 return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
228 llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
229 llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
231 }
232
233 auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
234 return attr.getValue();
235 });
236
237 return llvm::all_of(op->getOperandTypes(), isLegalType) &&
238 llvm::all_of(op->getResultTypes(), isLegalType) &&
239 llvm::all_of(attrs, isLegalAttr);
240}
241
242std::unique_ptr<ConversionTarget>
244 auto target = std::make_unique<ConversionTarget>(context);
245 target->markUnknownOpDynamicallyLegal(isLegalOp);
246 return target;
247}
248
250 Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
251 AttrTypeReplacer replacer;
252 replacer.addReplacement([&typeConverter](BaseMemRefType origType)
253 -> std::optional<BaseMemRefType> {
254 return typeConverter.convertType<BaseMemRefType>(origType);
255 });
256
257 replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
258 /*replaceLocs=*/false,
259 /*replaceTypes=*/true);
260}
261
262//===----------------------------------------------------------------------===//
263// Conversion Pass
264//===----------------------------------------------------------------------===//
265
266namespace {
267class MapMemRefStorageClassPass final
268 : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
269public:
270 MapMemRefStorageClassPass() = default;
271
272 explicit MapMemRefStorageClassPass(
273 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
274 : memorySpaceMap(memorySpaceMap) {}
275
276 LogicalResult initializeOptions(
277 StringRef options,
278 function_ref<LogicalResult(const Twine &)> errorHandler) override {
279 if (failed(Pass::initializeOptions(options, errorHandler)))
280 return failure();
281
282 if (clientAPI == "opencl")
284 else if (clientAPI != "vulkan")
285 return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
286
287 return success();
288 }
289
290 void runOnOperation() override {
291 MLIRContext *context = &getContext();
292 Operation *op = getOperation();
293
294 spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
295 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
296 spirv::TargetEnv targetEnv(attr);
297 if (targetEnv.allows(spirv::Capability::Kernel)) {
299 } else if (targetEnv.allows(spirv::Capability::Shader)) {
301 }
302 }
303
304 spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
305 // Perform the replacement.
307
308 // Check if there are any illegal ops remaining.
309 std::unique_ptr<ConversionTarget> target =
311 op->walk([&target, this](Operation *childOp) {
312 if (target->isIllegal(childOp)) {
313 childOp->emitOpError("failed to legalize memory space");
314 signalPassFailure();
315 return WalkResult::interrupt();
316 }
317 return WalkResult::advance();
318 });
319 }
320
321private:
324};
325} // namespace
326
327std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
328 return std::make_unique<MapMemRefStorageClassPass>();
329}
return success()
b getContext())
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes and memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
virtual LogicalResult initializeOptions(StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)
Attempt to initialize the options of this pass from the given string.
Definition Pass.cpp:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
AttrTypeReplacerBase.
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
void convertMemRefTypesAndAttrs(Operation *op, MemorySpaceToStorageClassConverter &typeConverter)
Converts all MemRef types and attributes in the op, as decided by the typeConverter.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152