26#define GEN_PASS_DEF_TOSATOSPIRVTOSA
27#include "mlir/Conversion/Passes.h.inc"
33 spirv::Version::V_1_5,
35 spirv::Capability::VulkanMemoryModel,
36 spirv::Capability::Shader,
37 spirv::Capability::Int8,
38 spirv::Capability::Int16,
39 spirv::Capability::Int64,
40 spirv::Capability::Float16,
41 spirv::Capability::BFloat16TypeKHR,
42 spirv::Capability::Float8EXT,
43 spirv::Capability::TensorsARM,
44 spirv::Capability::GraphARM,
45 spirv::Capability::ReplicatedCompositesEXT,
48 spirv::Extension::SPV_ARM_tensors,
49 spirv::Extension::SPV_ARM_graph,
50 spirv::Extension::SPV_KHR_vulkan_memory_model,
51 spirv::Extension::SPV_EXT_replicated_composites,
52 spirv::Extension::SPV_KHR_bfloat16,
53 spirv::Extension::SPV_EXT_float8,
59 MLIRContext *context, spirv::ResourceLimitsAttr limits,
60 spirv::ClientAPI clientAPI, spirv::Vendor vendorID,
61 spirv::DeviceType deviceType, uint32_t deviceID) {
66 clientAPI, vendorID, deviceType, deviceID);
71LogicalResult verifyGraphTargetEnv(
Operation *op,
74 if (targetEnv.allows(spirv::Capability::GraphARM) &&
75 targetEnv.allows(spirv::Extension::SPV_ARM_graph) &&
76 targetEnv.allows(spirv::Extension::SPV_ARM_tensors)) {
81 <<
"requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors "
82 "extensions in spirv.target_env";
85LogicalResult verifyNoUnsupportedFuncOps(Operation *op) {
86 WalkResult
result = op->walk([](Operation *op) -> WalkResult {
87 if (isa<func::CallOp, func::CallIndirectOp>(op)) {
89 <<
"is not supported in TOSA to SPIR-V Graph conversion; inline "
90 "calls before running this pass";
93 if (
auto funcOp = dyn_cast<func::FuncOp>(op)) {
94 if (funcOp->getParentOfType<func::FuncOp>()) {
96 <<
"nesting is not supported in TOSA to SPIR-V Graph conversion";
103 return failure(
result.wasInterrupted());
106struct TosaToSPIRVTosa final : impl::TosaToSPIRVTosaBase<TosaToSPIRVTosa> {
107 void runOnOperation()
override {
109 RewritePatternSet patterns(context);
110 Operation *op = getOperation();
117 if (
failed(verifyGraphTargetEnv(op, targetAttr)) ||
118 failed(verifyNoUnsupportedFuncOps(op))) {
123 std::unique_ptr<ConversionTarget>
target =
126 target->addIllegalDialect<tosa::TosaDialect>();
127 target->addIllegalOp<func::CallOp, func::CallIndirectOp>();
129 SPIRVTypeConverter typeConverter(targetAttr);
130 typeConverter.addConversion([
this](IntegerType integerType) {
131 return this->convertIntegerType(integerType);
133 typeConverter.addConversion([
this](TensorType tensorType) {
134 return this->convertTensorType(tensorType);
136 typeConverter.addConversion([
this](tosa::shapeType shapeType) {
137 return this->convertShapeType(shapeType);
144 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
146 if (
failed(applyPartialConversion(op, *
target, frozenPatterns))) {
152 IntegerType convertIntegerType(IntegerType integerType) {
153 if (integerType.getWidth() == 48) {
154 return IntegerType::get(&
getContext(), 64, integerType.getSignedness());
157 if (integerType.getWidth() == 4) {
158 return IntegerType::get(&
getContext(), 8, integerType.getSignedness());
164 std::optional<SmallVector<int64_t>> convertShape(ArrayRef<int64_t> shape) {
168 return SmallVector<int64_t>({1});
170 if (llvm::is_contained(shape, 0))
173 bool isPartiallyDynamic =
174 llvm::is_contained(shape, ShapedType::kDynamic) &&
175 llvm::any_of(shape, [](int64_t dim) {
return dim > 0; });
178 if (isPartiallyDynamic)
179 return SmallVector<int64_t>(shape.size(), ShapedType::kDynamic);
180 return SmallVector<int64_t>(shape);
183 std::optional<spirv::TensorArmType> convertTensorType(TensorType tensorType) {
185 if (elementType.isIndex())
186 elementType = IntegerType::get(&
getContext(), 32);
187 if (
auto integerType = dyn_cast<IntegerType>(elementType))
188 elementType = convertIntegerType(integerType);
190 SmallVector<int64_t> shape;
191 if (tensorType.hasRank()) {
192 std::optional<SmallVector<int64_t>> convertedShape =
193 convertShape(tensorType.getShape());
196 shape = std::move(*convertedShape);
202 spirv::TensorArmType convertShapeType(tosa::shapeType shapeType) {
203 const int64_t rank = std::max(shapeType.getRank(), 1);
211 return std::make_unique<TosaToSPIRVTosa>();
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
static WalkResult advance()
static WalkResult interrupt()
An attribute that specifies the target version, allowed extensions and capabilities,...
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(MLIRContext *context, spirv::ResourceLimitsAttr limits={}, spirv::ClientAPI clientAPI=spirv::ClientAPI::Unknown, spirv::Vendor vendorID=spirv::Vendor::Unknown, spirv::DeviceType deviceType=spirv::DeviceType::Unknown, uint32_t deviceID=spirv::TargetEnvAttr::kUnknownDeviceID)
spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context)
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)
std::unique_ptr< Pass > createTosaToSPIRVTosa()
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.