MLIR  20.0.0git
ConvertToSPIRVPass.cpp
Go to the documentation of this file.
1 //===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Pass/Pass.h"
31 #include <memory>
32 
33 #define DEBUG_TYPE "convert-to-spirv"
34 
35 namespace mlir {
36 #define GEN_PASS_DEF_CONVERTTOSPIRVPASS
37 #include "mlir/Conversion/Passes.h.inc"
38 } // namespace mlir
39 
40 using namespace mlir;
41 
42 namespace {
43 
44 /// Map memRef memory space to SPIR-V storage class.
45 void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
46  spirv::TargetEnv targetEnv(targetAttr);
47  bool targetEnvSupportsKernelCapability =
48  targetEnv.allows(spirv::Capability::Kernel);
50  targetEnvSupportsKernelCapability
53  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
54  spirv::convertMemRefTypesAndAttrs(op, converter);
55 }
56 
57 /// Populate patterns for each dialect.
58 void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
59  ScfToSPIRVContext &scfToSPIRVContext,
60  RewritePatternSet &patterns) {
62  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
63  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
64  populateFuncToSPIRVPatterns(typeConverter, patterns);
65  populateGPUToSPIRVPatterns(typeConverter, patterns);
66  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
67  populateMemRefToSPIRVPatterns(typeConverter, patterns);
68  populateVectorToSPIRVPatterns(typeConverter, patterns);
69  populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
70  ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
71 }
72 
73 /// A pass to perform the SPIR-V conversion.
74 struct ConvertToSPIRVPass final
75  : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
76  using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
77 
78  void runOnOperation() override {
79  Operation *op = getOperation();
80  MLIRContext *context = &getContext();
81 
82  // Unroll vectors in function signatures to native size.
83  if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
84  return signalPassFailure();
85 
86  // Unroll vectors in function bodies to native size.
87  if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
88  return signalPassFailure();
89 
90  // Generic conversion.
91  if (!convertGPUModules) {
93  std::unique_ptr<ConversionTarget> target =
94  SPIRVConversionTarget::get(targetAttr);
95  SPIRVTypeConverter typeConverter(targetAttr);
96  RewritePatternSet patterns(context);
97  ScfToSPIRVContext scfToSPIRVContext;
98  mapToMemRef(op, targetAttr);
99  populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
100  patterns);
101  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
102  return signalPassFailure();
103  return;
104  }
105 
106  // Clone each GPU kernel module for conversion, given that the GPU
107  // launch op still needs the original GPU kernel module.
108  SmallVector<Operation *, 1> gpuModules;
109  OpBuilder builder(context);
110  op->walk([&](gpu::GPUModuleOp gpuModule) {
111  builder.setInsertionPoint(gpuModule);
112  gpuModules.push_back(builder.clone(*gpuModule));
113  });
114  // Run conversion for each module independently as they can have
115  // different TargetEnv attributes.
116  for (Operation *gpuModule : gpuModules) {
117  spirv::TargetEnvAttr targetAttr =
119  std::unique_ptr<ConversionTarget> target =
120  SPIRVConversionTarget::get(targetAttr);
121  SPIRVTypeConverter typeConverter(targetAttr);
122  RewritePatternSet patterns(context);
123  ScfToSPIRVContext scfToSPIRVContext;
124  mapToMemRef(gpuModule, targetAttr);
125  populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
126  patterns);
127  if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
128  return signalPassFailure();
129  }
130  }
131 };
132 
133 } // namespace
static MLIRContext * getContext(OpFoldResult val)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type converter for converting numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:48
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
bool allows(Capability) const
Returns true if the given capability is allowed.
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:390
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::function< std::optional< spirv::StorageClass >(Attribute)> MemorySpaceToStorageClassMap
Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
Definition: MemRefToSPIRV.h:26
void convertMemRefTypesAndAttrs(Operation *op, MemorySpaceToStorageClassConverter &typeConverter)
Converts all MemRef types and attributes in the op, as decided by the typeConverter.
std::optional< spirv::StorageClass > mapMemorySpaceToOpenCLStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V using the default rule.
LogicalResult unrollVectorsInFuncBodies(Operation *op)
LogicalResult unrollVectorsInSignatures(Operation *op)
std::optional< spirv::StorageClass > mapMemorySpaceToVulkanStorageClass(Attribute)
Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V using the default rule.
void populateUBToSPIRVConversionPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
Definition: UBToSPIRV.cpp:81
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
void populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Func ops to SPIR-V ops.
Definition: FuncToSPIRV.cpp:90
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
Definition: SCFToSPIRV.cpp:439
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
Definition: GPUToSPIRV.cpp:732
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.