MLIR  16.0.0git
MapMemRefStorageClassPass.cpp
Go to the documentation of this file.
1 //===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to map numeric MemRef memory spaces to
10 // symbolic ones defined in the SPIR-V specification.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 #include "mlir/IR/BuiltinTypes.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "mlir-map-memref-storage-class"
34 
35 using namespace mlir;
36 
37 //===----------------------------------------------------------------------===//
38 // Mappings
39 //===----------------------------------------------------------------------===//
40 
41 /// Mapping between SPIR-V storage classes and memref memory spaces for
/// Vulkan-flavored SPIR-V. Each `MAP_FN(storageClass, memorySpace)` entry is
/// turned into a `case` label by a local `STORAGE_SPACE_MAP_FN` definition in
/// the conversion functions below; the list is `#undef`'d right after them.
/// Memory space 2 has no entry.
42 ///
43 /// Note: memref does not define semantics for individual memory spaces; their
44 /// meaning depends on the context where they are used. There is no particular
45 /// reason behind the number assignments; they try to follow NVVM conventions
46 /// and largely give commonly used storage classes smaller numbers.
47 #define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \
48  MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
49  MAP_FN(spirv::StorageClass::Generic, 1) \
50  MAP_FN(spirv::StorageClass::Workgroup, 3) \
51  MAP_FN(spirv::StorageClass::Uniform, 4) \
52  MAP_FN(spirv::StorageClass::Private, 5) \
53  MAP_FN(spirv::StorageClass::Function, 6) \
54  MAP_FN(spirv::StorageClass::PushConstant, 7) \
55  MAP_FN(spirv::StorageClass::UniformConstant, 8) \
56  MAP_FN(spirv::StorageClass::Input, 9) \
57  MAP_FN(spirv::StorageClass::Output, 10)
58 
// Maps a memref memory space attribute to a Vulkan-flavored SPIR-V storage
// class using VULKAN_STORAGE_SPACE_MAP_LIST; returns std::nullopt when no
// mapping applies.
// NOTE(review): the signature (original lines 59-60, which declare
// `memorySpaceAttr`) and the list expansion inside the switch (original
// line 77) are not part of this listing.
61  // A null memory space attribute (no explicit memory space) is treated as
// StorageBuffer, the class the list also assigns to memory space 0.
62  if (!memorySpaceAttr)
63  return spirv::StorageClass::StorageBuffer;
64 
65  // Unknown dialect custom attributes are not supported by default.
66  // Downstream callers should plug in more specialized ones.
67  auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
68  if (!intAttr)
69  return std::nullopt;
// NOTE(review): getInt() yields a 64-bit value that is narrowed to unsigned
// here, so values outside the unsigned range wrap around and may alias a
// listed space (e.g. 2^32 would become 0 -> StorageBuffer). Consider
// rejecting values that do not fit.
70  unsigned memorySpace = intAttr.getInt();
71 
// Each list entry expands to `case space: return storage;`.
72 #define STORAGE_SPACE_MAP_FN(storage, space) \
73  case space: \
74  return storage;
75 
76  switch (memorySpace) {
78  default:
79  break;
80  }
// Numeric memory spaces without a list entry (e.g. 2) are not mapped.
81  return std::nullopt;
82 
83 #undef STORAGE_SPACE_MAP_FN
84 }
85 
87 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
  // Reverse direction of the Vulkan rule: returns the memref memory space that
  // VULKAN_STORAGE_SPACE_MAP_LIST assigns to `storageClass`, or std::nullopt
  // for storage classes without an entry (e.g. CrossWorkgroup or Image).
  // Here each list entry expands to `case storage: return space;`.
  // NOTE(review): the return type (original line 86) and the list expansion
  // inside the switch (original line 93) are not part of this listing.
88 #define STORAGE_SPACE_MAP_FN(storage, space) \
89  case storage: \
90  return space;
91 
92  switch (storageClass) {
94  default:
95  break;
96  }
97  return std::nullopt;
98 
99 #undef STORAGE_SPACE_MAP_FN
100 }
101 
102 #undef VULKAN_STORAGE_SPACE_MAP_LIST
103 
/// Mapping between SPIR-V storage classes and memref memory spaces for
/// OpenCL-flavored SPIR-V, in the same `MAP_FN(storageClass, memorySpace)`
/// form as the Vulkan list. Memory spaces 1, 3, 5 and 6 map to the same storage
/// classes as in the Vulkan list, while 0, 4 and 7 differ; memory space 2 has
/// no entry.
104 #define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \
105  MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \
106  MAP_FN(spirv::StorageClass::Generic, 1) \
107  MAP_FN(spirv::StorageClass::Workgroup, 3) \
108  MAP_FN(spirv::StorageClass::UniformConstant, 4) \
109  MAP_FN(spirv::StorageClass::Private, 5) \
110  MAP_FN(spirv::StorageClass::Function, 6) \
111  MAP_FN(spirv::StorageClass::Image, 7)
112 
// Maps a memref memory space attribute to an OpenCL-flavored SPIR-V storage
// class using OPENCL_STORAGE_SPACE_MAP_LIST; returns std::nullopt when no
// mapping applies.
// NOTE(review): the signature (original lines 113-114, which declare
// `memorySpaceAttr`) and the list expansion inside the switch (original
// line 131) are not part of this listing.
115  // A null memory space attribute (no explicit memory space) is treated as
// CrossWorkgroup, the class the list also assigns to memory space 0.
116  if (!memorySpaceAttr)
117  return spirv::StorageClass::CrossWorkgroup;
118 
119  // Unknown dialect custom attributes are not supported by default.
120  // Downstream callers should plug in more specialized ones.
121  auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
122  if (!intAttr)
123  return std::nullopt;
// NOTE(review): getInt() yields a 64-bit value that is narrowed to unsigned
// here, so values outside the unsigned range wrap around and may alias a
// listed space (e.g. 2^32 would become 0 -> CrossWorkgroup). Consider
// rejecting values that do not fit.
124  unsigned memorySpace = intAttr.getInt();
125 
// Each list entry expands to `case space: return storage;`.
126 #define STORAGE_SPACE_MAP_FN(storage, space) \
127  case space: \
128  return storage;
129 
130  switch (memorySpace) {
132  default:
133  break;
134  }
// Numeric memory spaces without a list entry (e.g. 2) are not mapped.
135  return std::nullopt;
136 
137 #undef STORAGE_SPACE_MAP_FN
138 }
139 
141 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
  // Reverse direction of the OpenCL rule: returns the memref memory space that
  // OPENCL_STORAGE_SPACE_MAP_LIST assigns to `storageClass`, or std::nullopt
  // for storage classes without an entry (e.g. StorageBuffer or PushConstant).
  // Here each list entry expands to `case storage: return space;`.
  // NOTE(review): the return type (original line 140) and the list expansion
  // inside the switch (original line 147) are not part of this listing.
142 #define STORAGE_SPACE_MAP_FN(storage, space) \
143  case storage: \
144  return space;
145 
146  switch (storageClass) {
148  default:
149  break;
150  }
151  return std::nullopt;
152 
153 #undef STORAGE_SPACE_MAP_FN
154 }
155 
156 #undef OPENCL_STORAGE_SPACE_MAP_LIST
157 
158 //===----------------------------------------------------------------------===//
159 // Type Converter
160 //===----------------------------------------------------------------------===//
161 
163  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
164  : memorySpaceMap(memorySpaceMap) {
  // The type converter tries the most recently registered conversion first.
  // This identity conversion is registered first, so it is the fallback used
  // when none of the more specific conversions below applies.
165  // Pass through for all other types.
166  addConversion([](Type type) { return type; });
167 
  // Rewrites a ranked or unranked memref so that its memory space becomes the
  // SPIR-V storage class chosen by `memorySpaceMap`. Shape, element type and,
  // for ranked memrefs, the layout are kept. When the map has no entry,
  // returning std::nullopt lets the converter fall through to the identity
  // conversion above, which leaves the memref unchanged.
  // NOTE(review): the declaration of `storage` (original line 169) is not part
  // of this listing.
168  addConversion([this](BaseMemRefType memRefType) -> Optional<Type> {
170  this->memorySpaceMap(memRefType.getMemorySpace());
171  if (!storage) {
172  LLVM_DEBUG(llvm::dbgs()
173  << "cannot convert " << memRefType
174  << " due to being unable to find memory space in map\n");
175  return std::nullopt;
176  }
177 
178  auto storageAttr =
179  spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
180  if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
181  return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
182  rankedType.getLayout(), storageAttr);
183  }
184  return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
185  });
186 
  // Converts function types element-wise, so memref inputs and results inside
  // a FunctionType are converted as well. This also covers function types
  // stored in TypeAttr attributes.
187  addConversion([this](FunctionType type) {
188  SmallVector<Type> inputs, results;
189  inputs.reserve(type.getNumInputs());
190  results.reserve(type.getNumResults());
191  for (Type input : type.getInputs())
192  inputs.push_back(convertType(input));
193  for (Type result : type.getResults())
194  results.push_back(convertType(result));
195  return FunctionType::get(type.getContext(), inputs, results);
196  });
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // Conversion Target
201 //===----------------------------------------------------------------------===//
202 
203 /// Returns true if the given `type` is considered as legal for SPIR-V
204 /// conversion.
205 static bool isLegalType(Type type) {
206  if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
207  Attribute spaceAttr = memRefType.getMemorySpace();
208  return spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>();
209  }
210  return true;
211 }
212 
213 /// Returns true if the given `attr` is considered as legal for SPIR-V
214 /// conversion.
215 static bool isLegalAttr(Attribute attr) {
216  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
217  return isLegalType(typeAttr.getValue());
218  return true;
219 }
220 
221 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
222 static bool isLegalOp(Operation *op) {
223  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
224  return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
225  llvm::all_of(funcOp.getResultTypes(), isLegalType);
226  }
227 
228  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
229  return attr.getValue();
230  });
231 
232  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
233  llvm::all_of(op->getResultTypes(), isLegalType) &&
234  llvm::all_of(attrs, isLegalAttr);
235 }
236 
/// Creates the conversion target that decides op legality for memory space
/// mapping. Legality is based on the memref types an op exposes.
237 std::unique_ptr<ConversionTarget>
239  auto target = std::make_unique<ConversionTarget>(context);
  // No op is registered with an explicit legality here. Every op is therefore
  // legal exactly when isLegalOp accepts it, i.e. when none of the types it
  // inspects is a memref whose memory space is not a SPIR-V storage class.
  // NOTE(review): `context` is declared on original line 238, which is not
  // part of this listing.
240  target->markUnknownOpDynamicallyLegal(isLegalOp);
241  return target;
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Conversion Pattern
246 //===----------------------------------------------------------------------===//
247 
248 namespace {
249 /// Converts any op whose operand/result types, TypeAttr attributes or region
250 /// argument types use numeric MemRef memory spaces, by recreating it with the
/// types produced by the type converter.
251 struct MapMemRefStoragePattern final : public ConversionPattern {
  // Matches ops of any name (MatchAnyOpTypeTag) with a pattern benefit of 1.
  // Which ops actually get rewritten is decided by the conversion target.
252  MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
253  : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
254 
  // NOTE(review): the return type line (original line 255) is not part of this
  // listing.
256  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
257  ConversionPatternRewriter &rewriter) const override;
258 };
259 } // namespace
260 
// Recreates `op` with converted types. The new op gets:
//  - TypeAttr attributes run through the type converter, other attributes
//    copied unchanged;
//  - result types run through the type converter;
//  - `operands` as provided to the pattern by the conversion framework;
//  - the original successors;
//  - the original regions moved over, with their entry-block signatures
//    converted.
// The old op is then replaced by the new one.
261 LogicalResult MapMemRefStoragePattern::matchAndRewrite(
262  Operation *op, ArrayRef<Value> operands,
263  ConversionPatternRewriter &rewriter) const {
  // NOTE(review): the declaration of `newAttrs` (original line 264) is not part
  // of this listing.
265  newAttrs.reserve(op->getAttrs().size());
266  for (auto attr : op->getAttrs()) {
267  if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
268  auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
269  newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
270  } else {
271  newAttrs.push_back(attr);
272  }
273  }
274 
  // NOTE(review): the results of convertType / convertTypes /
  // convertSignatureArgs are not checked. A type converter that can actually
  // fail would need explicit failure handling here, e.g. via
  // rewriter.notifyMatchFailure.
275  llvm::SmallVector<Type, 4> newResults;
276  (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
277 
278  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
279  newResults, newAttrs, op->getSuccessors());
280 
  // Move each region's blocks into the new op and convert the entry block
  // arguments. `result` is declared on original line 284, which is not part of
  // this listing.
281  for (Region &region : op->getRegions()) {
282  Region *newRegion = state.addRegion();
283  rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
285  (void)getTypeConverter()->convertSignatureArgs(
286  newRegion->getArgumentTypes(), result);
287  rewriter.applySignatureConversion(newRegion, result);
288  }
289 
290  Operation *newOp = rewriter.create(state);
291  rewriter.replaceOp(op, newOp->getResults());
292  return success();
293 }
294 
297  RewritePatternSet &patterns) {
  // A single catch-all pattern is added. When it is used with the target from
  // getMemorySpaceToStorageClassTarget, only ops that isLegalOp rejects are
  // rewritten.
  // NOTE(review): the signature lines (original 295-296, which declare
  // `typeConverter`) are not part of this listing.
298  patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // Conversion Pass
303 //===----------------------------------------------------------------------===//
304 
305 namespace {
/// Pass that maps numeric memref memory spaces to SPIR-V storage classes using
/// a MemorySpaceToStorageClassMap.
306 class MapMemRefStorageClassPass final
307  : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
308 public:
  // NOTE(review): the default constructor's body (original line 310) is not
  // part of this listing; presumably it installs the default (Vulkan) mapping.
309  explicit MapMemRefStorageClassPass() {
311  }
  // Uses a caller-provided mapping instead of a default rule.
312  explicit MapMemRefStorageClassPass(
313  const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
314  : memorySpaceMap(memorySpaceMap) {}
315 
316  LogicalResult initializeOptions(StringRef options) override;
317 
318  void runOnOperation() override;
319 
320 private:
  // NOTE(review): the `memorySpaceMap` member declaration (original line 321)
  // is not part of this listing.
322 };
323 } // namespace
324 
// Parses the pass options. Succeeds only when `clientAPI` is "vulkan" or
// "opencl".
325 LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
  // NOTE(review): the condition guarding this early failure (original line
  // 326) is not part of this listing; presumably it checks the base-class
  // option parsing.
327  return failure();
328 
  // NOTE(review): the body of this branch (original line 330) is not part of
  // this listing; presumably it switches to the OpenCL default mapping.
329  if (clientAPI == "opencl") {
331  }
332 
  // Any other client API value is rejected.
333  if (clientAPI != "vulkan" && clientAPI != "opencl")
334  return failure();
335 
336  return success();
337 }
338 
// Rewrites all memref types under the current operation so that their memory
// spaces become SPIR-V storage classes. This is a full conversion, so the pass
// fails if any op still has a memory space the mapping cannot translate.
339 void MapMemRefStorageClassPass::runOnOperation() {
340  MLIRContext *context = &getContext();
341  Operation *op = getOperation();
342 
  // NOTE(review): the enclosing `if` that defines `attr` (original line 343)
  // and the branch bodies (original lines 346 and 348) are not part of this
  // listing. Presumably a SPIR-V target environment found for `op` selects the
  // default mapping from its capabilities. Kernel is tested before Shader, so
  // Kernel wins when both are allowed.
344  spirv::TargetEnv targetEnv(attr);
345  if (targetEnv.allows(spirv::Capability::Kernel)) {
347  } else if (targetEnv.allows(spirv::Capability::Shader)) {
349  }
350  }
351 
352  auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
353  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
354 
  // NOTE(review): original line 356, presumably the call that populates
  // `patterns`, is not part of this listing.
355  RewritePatternSet patterns(context);
357 
358  if (failed(applyFullConversion(op, *target, std::move(patterns))))
359  return signalPassFailure();
360 }
361 
362 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
363  return std::make_unique<MapMemRefStorageClassPass>();
364 }
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
#define STORAGE_SPACE_MAP_FN(storage, space)
static bool isLegalAttr(Attribute attr)
Returns true if the given attr is considered as legal for SPIR-V conversion.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes and memref memory spaces.
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for SPIR-V conversion.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
bool isa() const
Casting utility functions.
Definition: Attributes.h:117
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:114
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Base class for the conversion patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
operand_type_range getOperandTypes()
Definition: Operation.h:314
result_type_range getResultTypes()
Definition: Operation.h:345
SuccessorRange getSuccessors()
Definition: Operation.h:503
result_range getResults()
Definition: Operation.h:332
virtual LogicalResult initializeOptions(StringRef options)
Attempt to initialize the options of this pass from the given string.
Definition: Pass.cpp:42
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
Definition: Region.cpp:36
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
U dyn_cast() const
Definition: Types.h:270
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:46
MemorySpaceToStorageClassConverter(const MemorySpaceToStorageClassMap &memorySpaceMap)
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabilities/extensions.
Definition: TargetAndABI.h:28
std::unique_ptr< ConversionTarget > getMemorySpaceToStorageClassTarget(MLIRContext &)
Creates the target that populates legality of ops with MemRef types.
Optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
Optional< unsigned > mapVulkanStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces using the default rule.
void populateMemorySpaceToStorageClassPatterns(MemorySpaceToStorageClassConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
Optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
std::function< Optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:26
Optional< unsigned > mapOpenCLStorageClassToMemorySpace(spirv::StorageClass)
Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces using the default rule.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::unique_ptr< OperationPass<> > createMapMemRefStorageClassPass()
Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V storage classes.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.