MLIR  22.0.0git
MapMemRefStorageClassPass.cpp
Go to the documentation of this file.
1 //===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to map numeric MemRef memory spaces to
10 // symbolic ones defined in the SPIR-V specification.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/Visitors.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/Support/Debug.h"
30 #include <optional>
31 
32 namespace mlir {
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
36 
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
38 
39 using namespace mlir;
40 
41 //===----------------------------------------------------------------------===//
42 // Mappings
43 //===----------------------------------------------------------------------===//
44 
/// X-macro table pairing SPIR-V storage classes with numeric memref memory
/// spaces for Vulkan-flavored SPIR-V. Every row is expanded as
/// `ENTRY(storageClass, memorySpace)`, so a caller supplies `ENTRY` to build
/// whatever construct it needs (e.g. switch cases in either direction).
///
/// Note: memref gives individual memory spaces no fixed meaning; what a number
/// means depends on the context in which the memref is used. The particular
/// numbers chosen here are not mandated by anything: they loosely follow NVVM
/// conventions and mostly give commonly used storage classes smaller numbers.
/// No row uses memory space 2.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                 \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::Uniform, 4)                                       \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                  \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                               \
  ENTRY(spirv::StorageClass::Input, 9)                                         \
  ENTRY(spirv::StorageClass::Output, 10)                                       \
  ENTRY(spirv::StorageClass::PhysicalStorageBuffer, 11)                        \
  ENTRY(spirv::StorageClass::Image, 12)
64 
65 std::optional<spirv::StorageClass>
67  // Handle null memory space attribute specially.
68  if (!memorySpaceAttr)
69  return spirv::StorageClass::StorageBuffer;
70 
71  // Unknown dialect custom attributes are not supported by default.
72  // Downstream callers should plug in more specialized ones.
73  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
74  if (!intAttr)
75  return std::nullopt;
76  unsigned memorySpace = intAttr.getInt();
77 
78 #define STORAGE_SPACE_MAP_FN(storage, space) \
79  case space: \
80  return storage;
81 
82  switch (memorySpace) {
84  default:
85  break;
86  }
87  return std::nullopt;
88 
89 #undef STORAGE_SPACE_MAP_FN
90 }
91 
92 std::optional<unsigned>
93 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
94 #define STORAGE_SPACE_MAP_FN(storage, space) \
95  case storage: \
96  return space;
97 
98  switch (storageClass) {
100  default:
101  break;
102  }
103  return std::nullopt;
104 
105 #undef STORAGE_SPACE_MAP_FN
106 }
107 
108 #undef VULKAN_STORAGE_SPACE_MAP_LIST
109 
/// X-macro table pairing SPIR-V storage classes with numeric memref memory
/// spaces for OpenCL-flavored SPIR-V. Every row is expanded as
/// `ENTRY(storageClass, memorySpace)`; callers provide `ENTRY` to generate the
/// construct they need. No row uses memory space 2.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 0)                                \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::UniformConstant, 4)                               \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::Image, 7)
117  MAP_FN(spirv::StorageClass::Image, 7)
118 
119 std::optional<spirv::StorageClass>
121  // Handle null memory space attribute specially.
122  if (!memorySpaceAttr)
123  return spirv::StorageClass::CrossWorkgroup;
124 
125  // Unknown dialect custom attributes are not supported by default.
126  // Downstream callers should plug in more specialized ones.
127  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
128  if (!intAttr)
129  return std::nullopt;
130  unsigned memorySpace = intAttr.getInt();
131 
132 #define STORAGE_SPACE_MAP_FN(storage, space) \
133  case space: \
134  return storage;
135 
136  switch (memorySpace) {
138  default:
139  break;
140  }
141  return std::nullopt;
142 
143 #undef STORAGE_SPACE_MAP_FN
144 }
145 
146 std::optional<unsigned>
147 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
148 #define STORAGE_SPACE_MAP_FN(storage, space) \
149  case storage: \
150  return space;
151 
152  switch (storageClass) {
154  default:
155  break;
156  }
157  return std::nullopt;
158 
159 #undef STORAGE_SPACE_MAP_FN
160 }
161 
162 #undef OPENCL_STORAGE_SPACE_MAP_LIST
163 
164 //===----------------------------------------------------------------------===//
165 // Type Converter
166 //===----------------------------------------------------------------------===//
167 
169  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
170  : memorySpaceMap(memorySpaceMap) {
171  // Pass through for all other types.
172  addConversion([](Type type) { return type; });
173 
174  addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
175  std::optional<spirv::StorageClass> storage =
176  this->memorySpaceMap(memRefType.getMemorySpace());
177  if (!storage) {
178  LLVM_DEBUG(llvm::dbgs()
179  << "cannot convert " << memRefType
180  << " due to being unable to find memory space in map\n");
181  return std::nullopt;
182  }
183 
184  auto storageAttr =
185  spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
186  if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
187  return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
188  rankedType.getLayout(), storageAttr);
189  }
190  return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
191  });
192 
193  addConversion([this](FunctionType type) {
194  auto inputs = llvm::map_to_vector(
195  type.getInputs(), [this](Type ty) { return convertType(ty); });
196  auto results = llvm::map_to_vector(
197  type.getResults(), [this](Type ty) { return convertType(ty); });
198  return FunctionType::get(type.getContext(), inputs, results);
199  });
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // Conversion Target
204 //===----------------------------------------------------------------------===//
205 
206 /// Returns true if the given `type` is considered as legal for SPIR-V
207 /// conversion.
208 static bool isLegalType(Type type) {
209  if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
210  Attribute spaceAttr = memRefType.getMemorySpace();
211  return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
212  }
213  return true;
214 }
215 
216 /// Returns true if the given `attr` is considered as legal for SPIR-V
217 /// conversion.
218 static bool isLegalAttr(Attribute attr) {
219  if (auto typeAttr = dyn_cast<TypeAttr>(attr))
220  return isLegalType(typeAttr.getValue());
221  return true;
222 }
223 
224 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
225 static bool isLegalOp(Operation *op) {
226  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
227  return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
228  llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
229  llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
230  isLegalType);
231  }
232 
233  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
234  return attr.getValue();
235  });
236 
237  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
238  llvm::all_of(op->getResultTypes(), isLegalType) &&
239  llvm::all_of(attrs, isLegalAttr);
240 }
241 
242 std::unique_ptr<ConversionTarget>
244  auto target = std::make_unique<ConversionTarget>(context);
245  target->markUnknownOpDynamicallyLegal(isLegalOp);
246  return target;
247 }
248 
250  Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
251  AttrTypeReplacer replacer;
252  replacer.addReplacement([&typeConverter](BaseMemRefType origType)
253  -> std::optional<BaseMemRefType> {
254  return typeConverter.convertType<BaseMemRefType>(origType);
255  });
256 
257  replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
258  /*replaceLocs=*/false,
259  /*replaceTypes=*/true);
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // Conversion Pass
264 //===----------------------------------------------------------------------===//
265 
266 namespace {
267 class MapMemRefStorageClassPass final
268  : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
269 public:
270  MapMemRefStorageClassPass() = default;
271 
272  explicit MapMemRefStorageClassPass(
273  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
274  : memorySpaceMap(memorySpaceMap) {}
275 
276  LogicalResult initializeOptions(
277  StringRef options,
278  function_ref<LogicalResult(const Twine &)> errorHandler) override {
279  if (failed(Pass::initializeOptions(options, errorHandler)))
280  return failure();
281 
282  if (clientAPI == "opencl")
284  else if (clientAPI != "vulkan")
285  return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
286 
287  return success();
288  }
289 
290  void runOnOperation() override {
291  MLIRContext *context = &getContext();
292  Operation *op = getOperation();
293 
294  spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
296  spirv::TargetEnv targetEnv(attr);
297  if (targetEnv.allows(spirv::Capability::Kernel)) {
299  } else if (targetEnv.allows(spirv::Capability::Shader)) {
301  }
302  }
303 
304  spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
305  // Perform the replacement.
306  spirv::convertMemRefTypesAndAttrs(op, converter);
307 
308  // Check if there are any illegal ops remaining.
309  std::unique_ptr<ConversionTarget> target =
311  op->walk([&target, this](Operation *childOp) {
312  if (target->isIllegal(childOp)) {
313  childOp->emitOpError("failed to legalize memory space");
314  signalPassFailure();
315  return WalkResult::interrupt();
316  }
317  return WalkResult::advance();
318  });
319  }
320 
321 private:
322  spirv::MemorySpaceToStorageClassMap memorySpaceMap =
324 };
325 } // namespace
326 
327 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
328  return std::make_unique<MapMemRefStorageClassPass>();
329 }
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes to memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
This is an attribute/type replacer that is naively cached.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
virtual LogicalResult initializeOptions(StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)
Attempt to initialize the options of this pass from the given string.
Definition: Pass.cpp:62
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
static WalkResult advance()
Definition: WalkResult.h:47
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:48
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:26
void convertMemRefTypesAndAttrs(Operation *op, MemorySpaceToStorageClassConverter &typeConverter)
Converts all MemRef types and attributes in the op, as decided by the typeConverter.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...