MLIR  19.0.0git
MapMemRefStorageClassPass.cpp
Go to the documentation of this file.
1 //===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to map numeric MemRef memory spaces to
10 // symbolic ones defined in the SPIR-V specification.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/Visitors.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/Support/Debug.h"
30 #include <optional>
31 
32 namespace mlir {
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
36 
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
38 
39 using namespace mlir;
40 
41 //===----------------------------------------------------------------------===//
42 // Mappings
43 //===----------------------------------------------------------------------===//
44 
/// Default correspondence between Vulkan-flavored SPIR-V storage classes and
/// numeric memref memory spaces, written as an X-macro: every entry has the
/// form `MAP_FN(storageClass, memorySpace)`, and the caller supplies MAP_FN.
///
/// Note: memref itself attaches no fixed meaning to a numeric memory space;
/// the meaning depends on where the memref is used. The specific numbers are
/// not significant: they loosely follow NVVM conventions and tend to give the
/// more common storage classes smaller values.
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
  MAP_FN(spirv::StorageClass::Generic, 1) \
  MAP_FN(spirv::StorageClass::Workgroup, 3) \
  MAP_FN(spirv::StorageClass::Uniform, 4) \
  MAP_FN(spirv::StorageClass::Private, 5) \
  MAP_FN(spirv::StorageClass::Function, 6) \
  MAP_FN(spirv::StorageClass::PushConstant, 7) \
  MAP_FN(spirv::StorageClass::UniformConstant, 8) \
  MAP_FN(spirv::StorageClass::Input, 9) \
  MAP_FN(spirv::StorageClass::Output, 10) \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
63 
/// Maps a memref memory-space attribute to a Vulkan-flavored SPIR-V storage
/// class using the default numbering:
///  - a null attribute maps to StorageBuffer;
///  - a non-IntegerAttr attribute yields std::nullopt;
///  - an integer value is matched against the `case` labels generated with
///    STORAGE_SPACE_MAP_FN, and yields std::nullopt when nothing matches.
64 std::optional<spirv::StorageClass>
66  // Handle null memory space attribute specially.
67  if (!memorySpaceAttr)
68  return spirv::StorageClass::StorageBuffer;
69 
70  // Unknown dialect custom attributes are not supported by default.
71  // Downstream callers should plug in more specialized ones.
72  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
73  if (!intAttr)
74  return std::nullopt;
  // NOTE(review): getInt() presumably returns a 64-bit value. Narrowing it
  // to `unsigned` could let an out-of-range memory space alias a valid
  // entry; consider range-checking before the switch. Confirm.
75  unsigned memorySpace = intAttr.getInt();
76 
  // Each list entry MAP_FN(storage, space) becomes `case space: return storage;`.
77 #define STORAGE_SPACE_MAP_FN(storage, space) \
78  case space: \
79  return storage;
80 
81  switch (memorySpace) {
83  default:
84  break;
85  }
86  return std::nullopt;
87 
88 #undef STORAGE_SPACE_MAP_FN
89 }
90 
/// Reverse of the Vulkan memory-space mapping. It returns the numeric memref
/// memory space assigned to `storageClass`, or std::nullopt when the storage
/// class has no entry.
91 std::optional<unsigned>
92 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
  // Each list entry MAP_FN(storage, space) becomes `case storage: return space;`.
93 #define STORAGE_SPACE_MAP_FN(storage, space) \
94  case storage: \
95  return space;
96 
97  switch (storageClass) {
99  default:
100  break;
101  }
102  return std::nullopt;
103 
104 #undef STORAGE_SPACE_MAP_FN
105 }
106 
// The Vulkan list is no longer needed past this point, so undefine it to keep
// it from leaking into the rest of the file.
107 #undef VULKAN_STORAGE_SPACE_MAP_LIST
108 
/// Default correspondence between OpenCL-flavored SPIR-V storage classes and
/// numeric memref memory spaces, written as an X-macro: every entry has the
/// form `MAP_FN(storageClass, memorySpace)`, and the caller supplies MAP_FN.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \
  MAP_FN(spirv::StorageClass::Generic, 1) \
  MAP_FN(spirv::StorageClass::Workgroup, 3) \
  MAP_FN(spirv::StorageClass::UniformConstant, 4) \
  MAP_FN(spirv::StorageClass::Private, 5) \
  MAP_FN(spirv::StorageClass::Function, 6) \
  MAP_FN(spirv::StorageClass::Image, 7)
117 
/// Maps a memref memory-space attribute to an OpenCL-flavored SPIR-V storage
/// class using the default numbering:
///  - a null attribute maps to CrossWorkgroup;
///  - a non-IntegerAttr attribute yields std::nullopt;
///  - an integer value is matched against the `case` labels generated with
///    STORAGE_SPACE_MAP_FN, and yields std::nullopt when nothing matches.
118 std::optional<spirv::StorageClass>
120  // Handle null memory space attribute specially.
121  if (!memorySpaceAttr)
122  return spirv::StorageClass::CrossWorkgroup;
123 
124  // Unknown dialect custom attributes are not supported by default.
125  // Downstream callers should plug in more specialized ones.
126  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
127  if (!intAttr)
128  return std::nullopt;
  // NOTE(review): getInt() presumably returns a 64-bit value. Narrowing it
  // to `unsigned` could let an out-of-range memory space alias a valid
  // entry; consider range-checking before the switch. Confirm.
129  unsigned memorySpace = intAttr.getInt();
130 
  // Each list entry MAP_FN(storage, space) becomes `case space: return storage;`.
131 #define STORAGE_SPACE_MAP_FN(storage, space) \
132  case space: \
133  return storage;
134 
135  switch (memorySpace) {
137  default:
138  break;
139  }
140  return std::nullopt;
141 
142 #undef STORAGE_SPACE_MAP_FN
143 }
144 
/// Reverse of the OpenCL memory-space mapping. It returns the numeric memref
/// memory space assigned to `storageClass`, or std::nullopt when the storage
/// class has no entry.
145 std::optional<unsigned>
146 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
  // Each list entry MAP_FN(storage, space) becomes `case storage: return space;`.
147 #define STORAGE_SPACE_MAP_FN(storage, space) \
148  case storage: \
149  return space;
150 
151  switch (storageClass) {
153  default:
154  break;
155  }
156  return std::nullopt;
157 
158 #undef STORAGE_SPACE_MAP_FN
159 }
160 
// The OpenCL list is no longer needed past this point, so undefine it to keep
// it from leaking into the rest of the file.
161 #undef OPENCL_STORAGE_SPACE_MAP_LIST
162 
163 //===----------------------------------------------------------------------===//
164 // Type Converter
165 //===----------------------------------------------------------------------===//
166 
168  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
169  : memorySpaceMap(memorySpaceMap) {
  // Registers the conversions used to rewrite memref memory spaces:
  //  - an identity conversion for any other type;
  //  - memref types: the memory space is looked up via `memorySpaceMap` and
  //    replaced with the matching spirv::StorageClassAttr (std::nullopt, plus
  //    a debug message, when the map has no answer);
  //  - function types: inputs and results are converted element-wise.
170  // Pass through for all other types.
171  addConversion([](Type type) { return type; });
172 
173  addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
174  std::optional<spirv::StorageClass> storage =
175  this->memorySpaceMap(memRefType.getMemorySpace());
176  if (!storage) {
177  LLVM_DEBUG(llvm::dbgs()
178  << "cannot convert " << memRefType
179  << " due to being unable to find memory space in map\n");
180  return std::nullopt;
181  }
182 
    // Only the memory space changes. Ranked memrefs keep their shape, element
    // type and layout; unranked memrefs keep their element type.
183  auto storageAttr =
184  spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
185  if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
186  return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
187  rankedType.getLayout(), storageAttr);
188  }
189  return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
190  });
191 
  // NOTE(review): if convertType(ty) yields a null Type when a nested type
  // cannot be converted (confirm), that null is passed straight into
  // FunctionType::get. Bailing out on any failed element would be safer.
192  addConversion([this](FunctionType type) {
193  auto inputs = llvm::map_to_vector(
194  type.getInputs(), [this](Type ty) { return convertType(ty); });
195  auto results = llvm::map_to_vector(
196  type.getResults(), [this](Type ty) { return convertType(ty); });
197  return FunctionType::get(type.getContext(), inputs, results);
198  });
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // Conversion Target
203 //===----------------------------------------------------------------------===//
204 
205 /// Returns true if the given `type` is considered as legal for SPIR-V
206 /// conversion.
207 static bool isLegalType(Type type) {
208  if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
209  Attribute spaceAttr = memRefType.getMemorySpace();
210  return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
211  }
212  return true;
213 }
214 
215 /// Returns true if the given `attr` is considered as legal for SPIR-V
216 /// conversion.
217 static bool isLegalAttr(Attribute attr) {
218  if (auto typeAttr = dyn_cast<TypeAttr>(attr))
219  return isLegalType(typeAttr.getValue());
220  return true;
221 }
222 
223 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
224 static bool isLegalOp(Operation *op) {
225  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
226  return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
227  llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
228  llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
229  isLegalType);
230  }
231 
232  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
233  return attr.getValue();
234  });
235 
236  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
237  llvm::all_of(op->getResultTypes(), isLegalType) &&
238  llvm::all_of(attrs, isLegalAttr);
239 }
240 
/// Creates the conversion target that populates legality of ops with memref
/// types. Ops without any other legality rule are treated as dynamically
/// legal, and the decision is delegated to isLegalOp().
241 std::unique_ptr<ConversionTarget>
243  auto target = std::make_unique<ConversionTarget>(context);
244  target->markUnknownOpDynamicallyLegal(isLegalOp);
245  return target;
246 }
247 
249  Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
  // Rewrites the memref types found in `op` and all nested ops, covering
  // both attributes and types but not locations. Each memref is converted
  // through `typeConverter`.
  // NOTE(review): when the converter cannot map a memory space, the replacer
  // presumably leaves that memref untouched (confirm). Callers such as the
  // pass check for leftovers with a legality walk afterwards.
250  AttrTypeReplacer replacer;
251  replacer.addReplacement([&typeConverter](BaseMemRefType origType)
252  -> std::optional<BaseMemRefType> {
253  return typeConverter.convertType<BaseMemRefType>(origType);
254  });
255 
256  replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
257  /*replaceLocs=*/false,
258  /*replaceTypes=*/true);
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Conversion Pass
263 //===----------------------------------------------------------------------===//
264 
265 namespace {
/// Pass that rewrites numeric memref memory spaces into SPIR-V storage-class
/// attributes. Afterwards it reports an error on the first op that is still
/// illegal with respect to memory spaces.
266 class MapMemRefStorageClassPass final
267  : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
268 public:
269  MapMemRefStorageClassPass() = default;
270 
271  explicit MapMemRefStorageClassPass(
272  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
273  : memorySpaceMap(memorySpaceMap) {}
274 
275  LogicalResult initializeOptions(
276  StringRef options,
277  function_ref<LogicalResult(const Twine &)> errorHandler) override {
278  if (failed(Pass::initializeOptions(options, errorHandler)))
279  return failure();
280 
    // Select the mapping from the `clientAPI` option:
    //  - "opencl" presumably switches to the OpenCL mapping (confirm);
    //  - "vulkan" keeps the current map;
    //  - any other value is reported through `errorHandler`.
    // NOTE(review): "clienAPI" in the error message below is a typo for
    // "clientAPI".
281  if (clientAPI == "opencl")
283  else if (clientAPI != "vulkan")
284  return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
285 
286  return success();
287  }
288 
289  void runOnOperation() override {
290  MLIRContext *context = &getContext();
291  Operation *op = getOperation();
292 
    // Start from the configured map. A target environment may override it:
    // the Kernel and Shader capabilities each select a different map,
    // presumably OpenCL and Vulkan respectively (confirm).
293  spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
295  spirv::TargetEnv targetEnv(attr);
296  if (targetEnv.allows(spirv::Capability::Kernel)) {
298  } else if (targetEnv.allows(spirv::Capability::Shader)) {
300  }
301  }
302 
303  spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
304  // Perform the replacement.
305  spirv::convertMemRefTypesAndAttrs(op, converter);
306 
    // Only the first illegal op is diagnosed: the walk is interrupted right
    // after reporting it and marking the pass as failed.
307  // Check if there are any illegal ops remaining.
308  std::unique_ptr<ConversionTarget> target =
310  op->walk([&target, this](Operation *childOp) {
311  if (target->isIllegal(childOp)) {
312  childOp->emitOpError("failed to legalize memory space");
313  signalPassFailure();
314  return WalkResult::interrupt();
315  }
316  return WalkResult::advance();
317  });
318  }
319 
320 private:
  // Mapping set at construction time or through the `clientAPI` option.
  // runOnOperation() may replace it based on the target environment.
321  spirv::MemorySpaceToStorageClassMap memorySpaceMap =
323 };
324 } // namespace
325 
326 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
327  return std::make_unique<MapMemRefStorageClassPass>();
328 }
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes and memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
virtual LogicalResult initializeOptions(StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)
Attempt to initialize the options of this pass from the given string.
Definition: Pass.cpp:63
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
static WalkResult advance()
Definition: Visitors.h:52
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:48
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:26
void convertMemRefTypesAndAttrs(Operation *op, MemorySpaceToStorageClassConverter &typeConverter)
Converts all MemRef types and attributes in the op, as decided by the typeConverter.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26