MLIR  18.0.0git
MapMemRefStorageClassPass.cpp
Go to the documentation of this file.
1 //===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to map numeric MemRef memory spaces to
10 // symbolic ones defined in the SPIR-V specification.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 #include "mlir/IR/BuiltinTypes.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "mlir-map-memref-storage-class"
34 
35 using namespace mlir;
36 
37 //===----------------------------------------------------------------------===//
38 // Mappings
39 //===----------------------------------------------------------------------===//
40 
/// Table pairing SPIR-V storage classes (Vulkan flavor) with memref memory
/// spaces.
///
/// This is an X-macro: `ENTRY(storageClass, memorySpace)` is expanded once per
/// row, so a caller chooses what each row turns into (for example a `case`
/// label) by passing a two-argument macro.
///
/// Note: memref does not have defined semantics for each memory space; the
/// meaning depends on the context where it is used. There are no particular
/// reasons behind the number assignments; we try to follow NVVM conventions and
/// largely give common storage classes a smaller number.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                 \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::Uniform, 4)                                       \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                  \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                               \
  ENTRY(spirv::StorageClass::Input, 9)                                         \
  ENTRY(spirv::StorageClass::Output, 10)
58 
59 std::optional<spirv::StorageClass>
61  // Handle null memory space attribute specially.
62  if (!memorySpaceAttr)
63  return spirv::StorageClass::StorageBuffer;
64 
65  // Unknown dialect custom attributes are not supported by default.
66  // Downstream callers should plug in more specialized ones.
67  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
68  if (!intAttr)
69  return std::nullopt;
70  unsigned memorySpace = intAttr.getInt();
71 
72 #define STORAGE_SPACE_MAP_FN(storage, space) \
73  case space: \
74  return storage;
75 
76  switch (memorySpace) {
78  default:
79  break;
80  }
81  return std::nullopt;
82 
83 #undef STORAGE_SPACE_MAP_FN
84 }
85 
86 std::optional<unsigned>
87 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
88 #define STORAGE_SPACE_MAP_FN(storage, space) \
89  case storage: \
90  return space;
91 
92  switch (storageClass) {
94  default:
95  break;
96  }
97  return std::nullopt;
98 
99 #undef STORAGE_SPACE_MAP_FN
100 }
101 
102 #undef VULKAN_STORAGE_SPACE_MAP_LIST
103 
/// Table pairing SPIR-V storage classes (OpenCL flavor) with memref memory
/// spaces.
///
/// This is an X-macro: `ENTRY(storageClass, memorySpace)` is expanded once per
/// row, so a caller chooses what each row turns into (for example a `case`
/// label) by passing a two-argument macro.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 0)                                \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::UniformConstant, 4)                               \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::Image, 7)
112 
113 std::optional<spirv::StorageClass>
115  // Handle null memory space attribute specially.
116  if (!memorySpaceAttr)
117  return spirv::StorageClass::CrossWorkgroup;
118 
119  // Unknown dialect custom attributes are not supported by default.
120  // Downstream callers should plug in more specialized ones.
121  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
122  if (!intAttr)
123  return std::nullopt;
124  unsigned memorySpace = intAttr.getInt();
125 
126 #define STORAGE_SPACE_MAP_FN(storage, space) \
127  case space: \
128  return storage;
129 
130  switch (memorySpace) {
132  default:
133  break;
134  }
135  return std::nullopt;
136 
137 #undef STORAGE_SPACE_MAP_FN
138 }
139 
140 std::optional<unsigned>
141 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
142 #define STORAGE_SPACE_MAP_FN(storage, space) \
143  case storage: \
144  return space;
145 
146  switch (storageClass) {
148  default:
149  break;
150  }
151  return std::nullopt;
152 
153 #undef STORAGE_SPACE_MAP_FN
154 }
155 
156 #undef OPENCL_STORAGE_SPACE_MAP_LIST
157 
158 //===----------------------------------------------------------------------===//
159 // Type Converter
160 //===----------------------------------------------------------------------===//
161 
163  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
164  : memorySpaceMap(memorySpaceMap) {
165  // Pass through for all other types.
166  addConversion([](Type type) { return type; });
167 
168  addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
169  std::optional<spirv::StorageClass> storage =
170  this->memorySpaceMap(memRefType.getMemorySpace());
171  if (!storage) {
172  LLVM_DEBUG(llvm::dbgs()
173  << "cannot convert " << memRefType
174  << " due to being unable to find memory space in map\n");
175  return std::nullopt;
176  }
177 
178  auto storageAttr =
179  spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
180  if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
181  return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
182  rankedType.getLayout(), storageAttr);
183  }
184  return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
185  });
186 
187  addConversion([this](FunctionType type) {
188  SmallVector<Type> inputs, results;
189  inputs.reserve(type.getNumInputs());
190  results.reserve(type.getNumResults());
191  for (Type input : type.getInputs())
192  inputs.push_back(convertType(input));
193  for (Type result : type.getResults())
194  results.push_back(convertType(result));
195  return FunctionType::get(type.getContext(), inputs, results);
196  });
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // Conversion Target
201 //===----------------------------------------------------------------------===//
202 
203 /// Returns true if the given `type` is considered as legal for SPIR-V
204 /// conversion.
205 static bool isLegalType(Type type) {
206  if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
207  Attribute spaceAttr = memRefType.getMemorySpace();
208  return spaceAttr && isa<spirv::StorageClassAttr>(spaceAttr);
209  }
210  return true;
211 }
212 
213 /// Returns true if the given `attr` is considered as legal for SPIR-V
214 /// conversion.
215 static bool isLegalAttr(Attribute attr) {
216  if (auto typeAttr = dyn_cast<TypeAttr>(attr))
217  return isLegalType(typeAttr.getValue());
218  return true;
219 }
220 
221 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
222 static bool isLegalOp(Operation *op) {
223  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
224  return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
225  llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
226  llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
227  isLegalType);
228  }
229 
230  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
231  return attr.getValue();
232  });
233 
234  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
235  llvm::all_of(op->getResultTypes(), isLegalType) &&
236  llvm::all_of(attrs, isLegalAttr);
237 }
238 
239 std::unique_ptr<ConversionTarget>
241  auto target = std::make_unique<ConversionTarget>(context);
242  target->markUnknownOpDynamicallyLegal(isLegalOp);
243  return target;
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // Conversion Pattern
248 //===----------------------------------------------------------------------===//
249 
250 namespace {
251 /// Converts any op that has operands/results/attributes with numeric MemRef
252 /// memory spaces.
253 struct MapMemRefStoragePattern final : public ConversionPattern {
254  MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
255  : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
256 
258  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
259  ConversionPatternRewriter &rewriter) const override;
260 };
261 } // namespace
262 
263 LogicalResult MapMemRefStoragePattern::matchAndRewrite(
264  Operation *op, ArrayRef<Value> operands,
265  ConversionPatternRewriter &rewriter) const {
267  newAttrs.reserve(op->getAttrs().size());
268  for (auto attr : op->getAttrs()) {
269  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
270  auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
271  newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
272  } else {
273  newAttrs.push_back(attr);
274  }
275  }
276 
277  llvm::SmallVector<Type, 4> newResults;
278  (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
279 
280  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
281  newResults, newAttrs, op->getSuccessors());
282 
283  for (Region &region : op->getRegions()) {
284  Region *newRegion = state.addRegion();
285  rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
287  (void)getTypeConverter()->convertSignatureArgs(
288  newRegion->getArgumentTypes(), result);
289  rewriter.applySignatureConversion(newRegion, result);
290  }
291 
292  Operation *newOp = rewriter.create(state);
293  rewriter.replaceOp(op, newOp->getResults());
294  return success();
295 }
296 
299  RewritePatternSet &patterns) {
300  patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Conversion Pass
305 //===----------------------------------------------------------------------===//
306 
307 namespace {
308 class MapMemRefStorageClassPass final
309  : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
310 public:
311  explicit MapMemRefStorageClassPass() {
313  }
314  explicit MapMemRefStorageClassPass(
315  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
316  : memorySpaceMap(memorySpaceMap) {}
317 
318  LogicalResult initializeOptions(StringRef options) override;
319 
320  void runOnOperation() override;
321 
322 private:
324 };
325 } // namespace
326 
327 LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
329  return failure();
330 
331  if (clientAPI == "opencl") {
333  }
334 
335  if (clientAPI != "vulkan" && clientAPI != "opencl")
336  return failure();
337 
338  return success();
339 }
340 
341 void MapMemRefStorageClassPass::runOnOperation() {
342  MLIRContext *context = &getContext();
343  Operation *op = getOperation();
344 
346  spirv::TargetEnv targetEnv(attr);
347  if (targetEnv.allows(spirv::Capability::Kernel)) {
349  } else if (targetEnv.allows(spirv::Capability::Shader)) {
351  }
352  }
353 
354  auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
355  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
356 
357  RewritePatternSet patterns(context);
359 
360  if (failed(applyFullConversion(op, *target, std::move(patterns))))
361  return signalPassFailure();
362 }
363 
364 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
365  return std::make_unique<MapMemRefStorageClassPass>();
366 }
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes and memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
Base class for the conversion patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
SuccessorRange getSuccessors()
Definition: Operation.h:682
result_range getResults()
Definition: Operation.h:410
virtual LogicalResult initializeOptions(StringRef options)
Attempt to initialize the options of this pass from the given string.
Definition: Pass.cpp:53
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
Definition: Region.cpp:36
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:48
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
void populateMemorySpaceToStorageClassPatterns(MemorySpaceToStorageClassConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for converting numeric MemRef memory spaces into SPIR-V...
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:26
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
std::optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.