27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/Support/Debug.h"
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
51 #define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \
52 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
53 MAP_FN(spirv::StorageClass::Generic, 1) \
54 MAP_FN(spirv::StorageClass::Workgroup, 3) \
55 MAP_FN(spirv::StorageClass::Uniform, 4) \
56 MAP_FN(spirv::StorageClass::Private, 5) \
57 MAP_FN(spirv::StorageClass::Function, 6) \
58 MAP_FN(spirv::StorageClass::PushConstant, 7) \
59 MAP_FN(spirv::StorageClass::UniformConstant, 8) \
60 MAP_FN(spirv::StorageClass::Input, 9) \
61 MAP_FN(spirv::StorageClass::Output, 10) \
62 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11) \
63 MAP_FN(spirv::StorageClass::Image, 12)
65 std::optional<spirv::StorageClass>
69 return spirv::StorageClass::StorageBuffer;
73 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
76 unsigned memorySpace = intAttr.getInt();
78 #define STORAGE_SPACE_MAP_FN(storage, space) \
82 switch (memorySpace) {
89 #undef STORAGE_SPACE_MAP_FN
92 std::optional<unsigned>
94 #define STORAGE_SPACE_MAP_FN(storage, space) \
98 switch (storageClass) {
105 #undef STORAGE_SPACE_MAP_FN
108 #undef VULKAN_STORAGE_SPACE_MAP_LIST
110 #define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \
111 MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \
112 MAP_FN(spirv::StorageClass::Generic, 1) \
113 MAP_FN(spirv::StorageClass::Workgroup, 3) \
114 MAP_FN(spirv::StorageClass::UniformConstant, 4) \
115 MAP_FN(spirv::StorageClass::Private, 5) \
116 MAP_FN(spirv::StorageClass::Function, 6) \
117 MAP_FN(spirv::StorageClass::Image, 7)
119 std::optional<spirv::StorageClass>
122 if (!memorySpaceAttr)
123 return spirv::StorageClass::CrossWorkgroup;
127 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
130 unsigned memorySpace = intAttr.getInt();
132 #define STORAGE_SPACE_MAP_FN(storage, space) \
136 switch (memorySpace) {
143 #undef STORAGE_SPACE_MAP_FN
146 std::optional<unsigned>
148 #define STORAGE_SPACE_MAP_FN(storage, space) \
152 switch (storageClass) {
159 #undef STORAGE_SPACE_MAP_FN
162 #undef OPENCL_STORAGE_SPACE_MAP_LIST
170 : memorySpaceMap(memorySpaceMap) {
175 std::optional<spirv::StorageClass> storage =
178 LLVM_DEBUG(llvm::dbgs()
179 <<
"cannot convert " << memRefType
180 <<
" due to being unable to find memory space in map\n");
186 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
188 rankedType.getLayout(), storageAttr);
194 auto inputs = llvm::map_to_vector(
195 type.getInputs(), [
this](
Type ty) { return convertType(ty); });
196 auto results = llvm::map_to_vector(
197 type.getResults(), [
this](
Type ty) { return convertType(ty); });
209 if (
auto memRefType = dyn_cast<BaseMemRefType>(type)) {
210 Attribute spaceAttr = memRefType.getMemorySpace();
211 return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
219 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
226 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
227 return llvm::all_of(funcOp.getArgumentTypes(),
isLegalType) &&
228 llvm::all_of(funcOp.getResultTypes(),
isLegalType) &&
229 llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
234 return attr.getValue();
242 std::unique_ptr<ConversionTarget>
244 auto target = std::make_unique<ConversionTarget>(context);
245 target->markUnknownOpDynamicallyLegal(
isLegalOp);
253 -> std::optional<BaseMemRefType> {
267 class MapMemRefStorageClassPass final
268 :
public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
270 MapMemRefStorageClassPass() =
default;
272 explicit MapMemRefStorageClassPass(
274 : memorySpaceMap(memorySpaceMap) {}
276 LogicalResult initializeOptions(
278 function_ref<LogicalResult(
const Twine &)> errorHandler)
override {
282 if (clientAPI ==
"opencl")
284 else if (clientAPI !=
"vulkan")
285 return errorHandler(llvm::Twine(
"Invalid clienAPI: ") + clientAPI);
290 void runOnOperation()
override {
297 if (targetEnv.allows(spirv::Capability::Kernel)) {
299 }
else if (targetEnv.allows(spirv::Capability::Shader)) {
309 std::unique_ptr<ConversionTarget> target =
312 if (target->isIllegal(childOp)) {
313 childOp->emitOpError(
"failed to legalize memory space");
315 return WalkResult::interrupt();
328 return std::make_unique<MapMemRefStorageClassPass>();
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes to memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
virtual LogicalResult initializeOptions(StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)
Attempt to initialize the options of this pass from the given string.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
void convertMemRefTypesAndAttrs(Operation *op, MemorySpaceToStorageClassConverter &typeConverter)
Converts all MemRef types and attributes in the op, as decided by the typeConverter.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...