MLIR 22.0.0git
SPIRVConversion.h
Go to the documentation of this file.
1//===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines utilities to use while converting to the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
14#define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
15
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/Support/LogicalResult.h"
25
26namespace mlir {
27
28//===----------------------------------------------------------------------===//
29// Type Converter
30//===----------------------------------------------------------------------===//
31
32/// How sub-byte values are stored in memory.
34 /// Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
36};
37
39 /// The number of bits to store a boolean value.
40 unsigned boolNumBits{8};
41
42 /// Whether to emulate unsupported floats with integer types of same bit
43 /// width.
45
46 /// How sub-byte values are stored in memory.
48
49 /// Whether to emulate narrower scalar types with 32-bit scalar types if not
50 /// supported by the target.
51 ///
52 /// Non-32-bit scalar types require special hardware support that may not
53 /// exist on all GPUs. This is reflected in SPIR-V in that non-32-bit scalar
54 /// types require special capabilities or extensions. This option controls
55 /// whether to use 32-bit types to emulate < 32-bits-wide scalars, if a scalar
56 /// type of a certain bitwidth is not supported in the target environment.
57 /// This requires the runtime to also feed in data with a matched bitwidth and
58 /// layout for interface types. The runtime can do that by inspecting the
59 /// SPIR-V module.
60 ///
61 /// If the original scalar type is narrower than 32 bits, multiple values of
62 /// it will be packed into one 32-bit value to be memory efficient.
64
65 /// Use 64-bit integers when converting index types.
66 bool use64bitIndex{false};
67};
68
69/// Type conversion from builtin types to SPIR-V types for shader interface.
70///
71/// For memref types, this converter additionally performs type wrapping to
72/// satisfy shader interface requirements: shader interface types must be
73/// pointers to structs.
75public:
76 explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
77 const SPIRVConversionOptions &options = {});
78
79 /// Gets the SPIR-V correspondence for the standard index type.
80 Type getIndexType() const;
81
82 /// Gets the bitwidth of the index type when converted to SPIR-V.
83 unsigned getIndexTypeBitwidth() const {
84 return options.use64bitIndex ? 64 : 32;
85 }
86
87 const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
88
89 /// Returns the options controlling the SPIR-V type converter.
90 const SPIRVConversionOptions &getOptions() const { return options; }
91
92 /// Checks if the SPIR-V capability inquired is supported.
93 bool allows(spirv::Capability capability) const;
94
95private:
96 spirv::TargetEnv targetEnv;
98
99 MLIRContext *getContext() const;
100};
101
102//===----------------------------------------------------------------------===//
103// Conversion Target
104//===----------------------------------------------------------------------===//
105
106// The default SPIR-V conversion target.
107//
108// It takes a SPIR-V target environment and controls operation legality based on
109// the their availability in the target environment.
110class SPIRVConversionTarget : public ConversionTarget {
111public:
112 /// Creates a SPIR-V conversion target for the given target environment.
113 static std::unique_ptr<SPIRVConversionTarget>
114 get(spirv::TargetEnvAttr targetAttr);
115
116private:
117 explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr);
118
119 // Be explicit that instance of this class cannot be copied or moved: there
120 // are lambdas capturing fields of the instance.
121 SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
122 SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
123 SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
124 SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
125
126 /// Returns true if the given `op` is legal to use under the current target
127 /// environment.
128 bool isLegalOp(Operation *op);
129
130 spirv::TargetEnv targetEnv;
131};
132
133//===----------------------------------------------------------------------===//
134// Patterns and Utility Functions
135//===----------------------------------------------------------------------===//
136
137/// Appends to a pattern list additional patterns for translating the builtin
138/// `func` op to the SPIR-V dialect. These patterns do not handle shader
139/// interface/ABI; they convert function parameters to be of SPIR-V allowed
140/// types.
143
145
147
148namespace spirv {
// Forward declaration of `spirv::AccessChainOp`; the full definition is not
// needed by this header.
class AccessChainOp;
150
151/// Returns the value for the given `builtin` variable. This function gets or
152/// inserts the global variable associated with the builtin within the nearest
153/// symbol table enclosing `op`. Returns null Value on error.
154///
155/// The global name being generated will be mangled using `prefix` and
156/// `suffix`.
158 OpBuilder &builder,
159 StringRef prefix = "__builtin__",
160 StringRef suffix = "__");
161
162/// Gets the value at the given `offset` of the push constant storage with a
163/// total of `elementCount` `integerType` integers. A global variable will be
164/// created in the nearest symbol table enclosing `op` for the push constant
165/// storage if not existing. Load ops will be created via the given `builder` to
166/// load values from the push constant. Returns null Value on error.
167Value getPushConstantValue(Operation *op, unsigned elementCount,
168 unsigned offset, Type integerType,
169 OpBuilder &builder);
170
171/// Generates IR to perform index linearization with the given `indices` and
172/// their corresponding `strides`, adding an initial `offset`.
174 int64_t offset, Type integerType, Location loc,
175 OpBuilder &builder);
176
177/// Performs the index computation to get to the element at `indices` of the
178/// memory pointed to by `basePtr`, using the layout map of `baseType`.
179/// Returns null if index computation cannot be performed.
180
181// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
182// that has static strides. Extend to handle dynamic strides.
183Value getElementPtr(const SPIRVTypeConverter &typeConverter,
184 MemRefType baseType, Value basePtr, ValueRange indices,
185 Location loc, OpBuilder &builder);
186
187// GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V.
188Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
189 MemRefType baseType, Value basePtr,
190 ValueRange indices, Location loc, OpBuilder &builder);
191
192// GetElementPtr implementation for Vulkan/Shader flavored SPIR-V.
193Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
194 MemRefType baseType, Value basePtr,
195 ValueRange indices, Location loc, OpBuilder &builder);
196
197// Find the largest factor of size among {2,3,4} for the lowest dimension of
198// the target shape.
200
201// GetNativeVectorShape implementation for reduction ops.
202SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
203
204// GetNativeVectorShape implementation for transpose ops.
205SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
206
207// For general ops.
208std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
209
210// Unroll vectors in function signatures to native size.
211LogicalResult unrollVectorsInSignatures(Operation *op);
212
213// Unroll vectors in function bodies to native size.
214LogicalResult unrollVectorsInFuncBodies(Operation *op);
215
216} // namespace spirv
217} // namespace mlir
218
219#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
const SPIRVConversionOptions & getOptions() const
Returns the options controlling the SPIR-V type converter.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to SPIR-V.
const spirv::TargetEnv & getTargetEnv() const
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
SPIRVSubByteTypeStorage
How sub-byte values are stored in memory.
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
bool use64bitIndex
Use 64-bit integers when converting index types.
bool emulateUnsupportedFloatTypes
Whether to emulate unsupported floats with integer types of same bit width.
unsigned boolNumBits
The number of bits to store a boolean value.
bool emulateLT32BitScalarTypes
Whether to emulate narrower scalar types with 32-bit scalar types if not supported by the target.
SPIRVSubByteTypeStorage subByteTypeStorage
How sub-byte values are stored in memory.