25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/Support/Debug.h"
29 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
30 #include "mlir/Conversion/Passes.h.inc"
33 #define DEBUG_TYPE "mlir-map-memref-storage-class"
47 #define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \
48 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
49 MAP_FN(spirv::StorageClass::Generic, 1) \
50 MAP_FN(spirv::StorageClass::Workgroup, 3) \
51 MAP_FN(spirv::StorageClass::Uniform, 4) \
52 MAP_FN(spirv::StorageClass::Private, 5) \
53 MAP_FN(spirv::StorageClass::Function, 6) \
54 MAP_FN(spirv::StorageClass::PushConstant, 7) \
55 MAP_FN(spirv::StorageClass::UniformConstant, 8) \
56 MAP_FN(spirv::StorageClass::Input, 9) \
57 MAP_FN(spirv::StorageClass::Output, 10)
59 std::optional<spirv::StorageClass>
63 return spirv::StorageClass::StorageBuffer;
67 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
70 unsigned memorySpace = intAttr.getInt();
72 #define STORAGE_SPACE_MAP_FN(storage, space) \
76 switch (memorySpace) {
83 #undef STORAGE_SPACE_MAP_FN
86 std::optional<unsigned>
88 #define STORAGE_SPACE_MAP_FN(storage, space) \
92 switch (storageClass) {
99 #undef STORAGE_SPACE_MAP_FN
102 #undef VULKAN_STORAGE_SPACE_MAP_LIST
104 #define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \
105 MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \
106 MAP_FN(spirv::StorageClass::Generic, 1) \
107 MAP_FN(spirv::StorageClass::Workgroup, 3) \
108 MAP_FN(spirv::StorageClass::UniformConstant, 4) \
109 MAP_FN(spirv::StorageClass::Private, 5) \
110 MAP_FN(spirv::StorageClass::Function, 6) \
111 MAP_FN(spirv::StorageClass::Image, 7)
113 std::optional<spirv::StorageClass>
116 if (!memorySpaceAttr)
117 return spirv::StorageClass::CrossWorkgroup;
121 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
124 unsigned memorySpace = intAttr.getInt();
126 #define STORAGE_SPACE_MAP_FN(storage, space) \
130 switch (memorySpace) {
137 #undef STORAGE_SPACE_MAP_FN
140 std::optional<unsigned>
142 #define STORAGE_SPACE_MAP_FN(storage, space) \
146 switch (storageClass) {
153 #undef STORAGE_SPACE_MAP_FN
156 #undef OPENCL_STORAGE_SPACE_MAP_LIST
164 : memorySpaceMap(memorySpaceMap) {
169 std::optional<spirv::StorageClass> storage =
172 LLVM_DEBUG(llvm::dbgs()
173 <<
"cannot convert " << memRefType
174 <<
" due to being unable to find memory space in map\n");
180 if (
auto rankedType = dyn_cast<MemRefType>(memRefType)) {
182 rankedType.getLayout(), storageAttr);
189 inputs.reserve(type.getNumInputs());
190 results.reserve(type.getNumResults());
191 for (
Type input : type.getInputs())
193 for (
Type result : type.getResults())
206 if (
auto memRefType = dyn_cast<BaseMemRefType>(type)) {
207 Attribute spaceAttr = memRefType.getMemorySpace();
208 return spaceAttr && isa<spirv::StorageClassAttr>(spaceAttr);
216 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
223 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
224 return llvm::all_of(funcOp.getArgumentTypes(),
isLegalType) &&
225 llvm::all_of(funcOp.getResultTypes(),
isLegalType) &&
226 llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
231 return attr.getValue();
239 std::unique_ptr<ConversionTarget>
241 auto target = std::make_unique<ConversionTarget>(context);
242 target->markUnknownOpDynamicallyLegal(
isLegalOp);
267 newAttrs.reserve(op->
getAttrs().size());
269 if (
auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
270 auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
271 newAttrs.emplace_back(attr.getName(),
TypeAttr::get(newAttr));
273 newAttrs.push_back(attr);
278 (void)getTypeConverter()->convertTypes(op->
getResultTypes(), newResults);
284 Region *newRegion = state.addRegion();
287 (void)getTypeConverter()->convertSignatureArgs(
300 patterns.
add<MapMemRefStoragePattern>(patterns.
getContext(), typeConverter);
308 class MapMemRefStorageClassPass final
309 :
public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
311 explicit MapMemRefStorageClassPass() {
314 explicit MapMemRefStorageClassPass(
316 : memorySpaceMap(memorySpaceMap) {}
320 void runOnOperation()
override;
331 if (clientAPI ==
"opencl") {
335 if (clientAPI !=
"vulkan" && clientAPI !=
"opencl")
341 void MapMemRefStorageClassPass::runOnOperation() {
347 if (targetEnv.allows(spirv::Capability::Kernel)) {
349 }
else if (targetEnv.allows(spirv::Capability::Shader)) {
361 return signalPassFailure();
365 return std::make_unique<MapMemRefStorageClassPass>();
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes to memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
Base class for the conversion patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
virtual LogicalResult initializeOptions(StringRef options)
Attempt to initialize the options of this pass from the given string.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
void populateMemorySpaceToStorageClassPatterns(MemorySpaceToStorageClassConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for converting numeric MemRef memory spaces into SPIR-V...
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.