27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/Support/Debug.h"
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
// X-macro table for Vulkan-flavored SPIR-V: each row pairs a SPIR-V storage
// class with the numeric memref memory space it is associated with.
// Invoke it with a two-parameter macro `ENTRY(storageClass, memorySpace)`;
// ENTRY is expanded once per row, in the order listed below.
// Note: no row uses memory space 2.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY)                                  \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                \
  ENTRY(spirv::StorageClass::Generic, 1)                                      \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                    \
  ENTRY(spirv::StorageClass::Uniform, 4)                                      \
  ENTRY(spirv::StorageClass::Private, 5)                                      \
  ENTRY(spirv::StorageClass::Function, 6)                                     \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                 \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                              \
  ENTRY(spirv::StorageClass::Input, 9)                                        \
  ENTRY(spirv::StorageClass::Output, 10)                                      \
  ENTRY(spirv::StorageClass::PhysicalStorageBuffer, 11)
64 std::optional<spirv::StorageClass>
68 return spirv::StorageClass::StorageBuffer;
72 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
75 unsigned memorySpace = intAttr.getInt();
77 #define STORAGE_SPACE_MAP_FN(storage, space) \
81 switch (memorySpace) {
88 #undef STORAGE_SPACE_MAP_FN
91 std::optional<unsigned>
93 #define STORAGE_SPACE_MAP_FN(storage, space) \
97 switch (storageClass) {
104 #undef STORAGE_SPACE_MAP_FN
107 #undef VULKAN_STORAGE_SPACE_MAP_LIST
// X-macro table for OpenCL-flavored SPIR-V: each row pairs a SPIR-V storage
// class with the numeric memref memory space it is associated with.
// Invoke it with a two-parameter macro `ENTRY(storageClass, memorySpace)`;
// ENTRY is expanded once per row, in the order listed below.
// Note: no row uses memory space 2.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY)                                  \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 0)                               \
  ENTRY(spirv::StorageClass::Generic, 1)                                      \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                    \
  ENTRY(spirv::StorageClass::UniformConstant, 4)                              \
  ENTRY(spirv::StorageClass::Private, 5)                                      \
  ENTRY(spirv::StorageClass::Function, 6)                                     \
  ENTRY(spirv::StorageClass::Image, 7)
118 std::optional<spirv::StorageClass>
121 if (!memorySpaceAttr)
122 return spirv::StorageClass::CrossWorkgroup;
126 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
129 unsigned memorySpace = intAttr.getInt();
131 #define STORAGE_SPACE_MAP_FN(storage, space) \
135 switch (memorySpace) {
142 #undef STORAGE_SPACE_MAP_FN
145 std::optional<unsigned>
147 #define STORAGE_SPACE_MAP_FN(storage, space) \
151 switch (storageClass) {
158 #undef STORAGE_SPACE_MAP_FN
161 #undef OPENCL_STORAGE_SPACE_MAP_LIST
169 : memorySpaceMap(memorySpaceMap) {
174 std::optional<spirv::StorageClass> storage =
177 LLVM_DEBUG(llvm::dbgs()
178 <<
"cannot convert " << memRefType
179 <<
" due to being unable to find memory space in map\n");
185 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
187 rankedType.getLayout(), storageAttr);
193 auto inputs = llvm::map_to_vector(
194 type.getInputs(), [
this](
Type ty) { return convertType(ty); });
195 auto results = llvm::map_to_vector(
196 type.getResults(), [
this](
Type ty) { return convertType(ty); });
208 if (
auto memRefType = dyn_cast<BaseMemRefType>(type)) {
209 Attribute spaceAttr = memRefType.getMemorySpace();
210 return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
218 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
225 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
226 return llvm::all_of(funcOp.getArgumentTypes(),
isLegalType) &&
227 llvm::all_of(funcOp.getResultTypes(),
isLegalType) &&
228 llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
233 return attr.getValue();
241 std::unique_ptr<ConversionTarget>
243 auto target = std::make_unique<ConversionTarget>(context);
244 target->markUnknownOpDynamicallyLegal(
isLegalOp);
252 -> std::optional<BaseMemRefType> {
266 class MapMemRefStorageClassPass final
267 :
public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
269 MapMemRefStorageClassPass() =
default;
271 explicit MapMemRefStorageClassPass(
273 : memorySpaceMap(memorySpaceMap) {}
275 LogicalResult initializeOptions(
277 function_ref<LogicalResult(
const Twine &)> errorHandler)
override {
281 if (clientAPI ==
"opencl")
283 else if (clientAPI !=
"vulkan")
284 return errorHandler(llvm::Twine(
"Invalid clienAPI: ") + clientAPI);
289 void runOnOperation()
override {
296 if (targetEnv.allows(spirv::Capability::Kernel)) {
298 }
else if (targetEnv.allows(spirv::Capability::Shader)) {
308 std::unique_ptr<ConversionTarget> target =
311 if (target->isIllegal(childOp)) {
312 childOp->emitOpError(
"failed to legalize memory space");
314 return WalkResult::interrupt();
327 return std::make_unique<MapMemRefStorageClassPass>();
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalType(Type type)
Returns true if the given type is considered legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping from SPIR-V storage classes to memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
virtual LogicalResult initializeOptions(StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)
Attempt to initialize the options of this pass from the given string.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
void convertMemRefTypesAndAttrs(Operation *op, MemorySpaceToStorageClassConverter &typeConverter)
Converts all MemRef types and attributes in the op, as decided by the typeConverter.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...