MLIR  21.0.0git
SPIRVConversion.h
Go to the documentation of this file.
1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines utilities to use while converting to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
14 #define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
15 
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/Support/LogicalResult.h"
25 
26 namespace mlir {
27 
28 //===----------------------------------------------------------------------===//
29 // Type Converter
30 //===----------------------------------------------------------------------===//
31 
32 /// How sub-byte values are stored in memory.
34  /// Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
35  Packed,
36 };
37 
39  /// The number of bits to store a boolean value.
40  unsigned boolNumBits{8};
41 
42  /// How sub-byte values are stored in memory.
44 
45  /// Whether to emulate narrower scalar types with 32-bit scalar types if not
46  /// supported by the target.
47  ///
48  /// Non-32-bit scalar types require special hardware support that may not
49  /// exist on all GPUs. This is reflected in SPIR-V in that non-32-bit scalar
50  /// types require special capabilities or extensions. This option controls
51  /// whether to use 32-bit types to emulate < 32-bits-wide scalars, if a scalar
52  /// type of a certain bitwidth is not supported in the target environment.
53  /// This requires the runtime to also feed in data with a matched bitwidth and
54  /// layout for interface types. The runtime can do that by inspecting the
55  /// SPIR-V module.
56  ///
57  /// If the original scalar type is narrower than 32 bits, multiple of its
58  /// values will be packed into one 32-bit value for memory efficiency.
60 
61  /// Use 64-bit integers when converting index types.
62  bool use64bitIndex{false};
63 };
64 
65 /// Type conversion from builtin types to SPIR-V types for shader interface.
66 ///
67 /// For memref types, this converter additionally performs type wrapping to
68 /// satisfy shader interface requirements: shader interface types must be
69 /// pointers to structs.
71 public:
72  explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
73  const SPIRVConversionOptions &options = {});
74 
75  /// Gets the SPIR-V correspondence for the standard index type.
76  Type getIndexType() const;
77 
78  /// Gets the bitwidth of the index type when converted to SPIR-V.
79  unsigned getIndexTypeBitwidth() const {
80  return options.use64bitIndex ? 64 : 32;
81  }
82 
83  const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
84 
85  /// Returns the options controlling the SPIR-V type converter.
86  const SPIRVConversionOptions &getOptions() const { return options; }
87 
88  /// Checks if the SPIR-V capability inquired is supported.
89  bool allows(spirv::Capability capability) const;
90 
91 private:
92  spirv::TargetEnv targetEnv;
93  SPIRVConversionOptions options;
94 
95  MLIRContext *getContext() const;
96 };
97 
98 //===----------------------------------------------------------------------===//
99 // Conversion Target
100 //===----------------------------------------------------------------------===//
101 
102 // The default SPIR-V conversion target.
103 //
104 // It takes a SPIR-V target environment and controls operation legality based on
105 // their availability in the target environment.
107 public:
108  /// Creates a SPIR-V conversion target for the given target environment.
109  static std::unique_ptr<SPIRVConversionTarget>
110  get(spirv::TargetEnvAttr targetAttr);
111 
112 private:
113  explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr);
114 
115  // Be explicit that instances of this class cannot be copied or moved: there
116  // are lambdas capturing fields of the instance.
119  SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
120  SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
121 
122  /// Returns true if the given `op` is legal to use under the current target
123  /// environment.
124  bool isLegalOp(Operation *op);
125 
126  spirv::TargetEnv targetEnv;
127 };
128 
129 //===----------------------------------------------------------------------===//
130 // Patterns and Utility Functions
131 //===----------------------------------------------------------------------===//
132 
133 /// Appends to a pattern list additional patterns for translating the builtin
134 /// `func` op to the SPIR-V dialect. These patterns do not handle shader
135 /// interface/ABI; they convert function parameters to be of SPIR-V allowed
136 /// types.
139 
141 
143 
144 namespace spirv {
145 class AccessChainOp;
146 
147 /// Returns the value for the given `builtin` variable. This function gets or
148 /// inserts the global variable associated with the builtin within the nearest
149 /// symbol table enclosing `op`. Returns null Value on error.
150 ///
151 /// The global name being generated will be mangled using `prefix` and
152 /// `suffix` (by default "__builtin__" and "__" respectively).
153 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType,
154  OpBuilder &builder,
155  StringRef prefix = "__builtin__",
156  StringRef suffix = "__");
157 
158 /// Gets the value at the given `offset` of the push constant storage with a
159 /// total of `elementCount` `integerType` integers. A global variable will be
160 /// created in the nearest symbol table enclosing `op` for the push constant
161 /// storage if one does not already exist. Load ops are created via `builder` to
162 /// load values from the push constant. Returns null Value on error.
163 Value getPushConstantValue(Operation *op, unsigned elementCount,
164  unsigned offset, Type integerType,
165  OpBuilder &builder);
166 
167 /// Generates IR to perform index linearization with the given `indices` and
168 /// their corresponding `strides`, adding an initial `offset`.
170  int64_t offset, Type integerType, Location loc,
171  OpBuilder &builder);
172 
173 /// Performs the index computation to get to the element at `indices` of the
174 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
175 /// Returns null if index computation cannot be performed.
176 
177 // TODO: This function assumes that `baseType` is a MemRefType whose AffineMap
178 // layout has static strides. Extend it to handle dynamic strides.
179 Value getElementPtr(const SPIRVTypeConverter &typeConverter,
180  MemRefType baseType, Value basePtr, ValueRange indices,
181  Location loc, OpBuilder &builder);
182 
183 /// `getElementPtr` implementation for Kernel/OpenCL-flavored SPIR-V.
184 Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
185  MemRefType baseType, Value basePtr,
186  ValueRange indices, Location loc, OpBuilder &builder);
187 
188 /// `getElementPtr` implementation for Vulkan/Shader-flavored SPIR-V.
189 Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
190  MemRefType baseType, Value basePtr,
191  ValueRange indices, Location loc, OpBuilder &builder);
192 
193 /// Finds the largest factor of `size` among {2, 3, 4}, for use as the lowest
194 /// dimension of the target shape.
195 int getComputeVectorSize(int64_t size);
196 
197 /// `getNativeVectorShape` implementation for vector reduction ops.
198 SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
199 
200 /// `getNativeVectorShape` implementation for vector transpose ops.
201 SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
202 
203 /// Native vector shape query for general ops; the result is optional.
204 std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
205 
206 /// Unrolls vectors in function signatures to their native size.
207 LogicalResult unrollVectorsInSignatures(Operation *op);
208 
209 /// Unrolls vectors in function bodies to their native size.
210 LogicalResult unrollVectorsInFuncBodies(Operation *op);
211 
212 } // namespace spirv
213 } // namespace mlir
214 
215 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:204
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
const SPIRVConversionOptions & getOptions() const
Returns the options controlling the SPIR-V type converter.
const spirv::TargetEnv & getTargetEnv() const
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to SPIR-V.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
SPIRVSubByteTypeStorage
How sub-byte values are stored in memory.
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
bool use64bitIndex
Use 64-bit integers when converting index types.
unsigned boolNumBits
The number of bits to store a boolean value.
bool emulateLT32BitScalarTypes
Whether to emulate narrower scalar types with 32-bit scalar types if not supported by the target.
SPIRVSubByteTypeStorage subByteTypeStorage
How sub-byte values are stored in memory.