MLIR 23.0.0git
TosaToSPIRVTosaPass.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosaPass.cpp - Lower TOSA to SPIR-V Graph/TOSA ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers TOSA IR to the SPIR-V Graph/TOSA representation.
10//
11//===----------------------------------------------------------------------===//
12
14
22
23#include <algorithm>
24
25namespace mlir {
26#define GEN_PASS_DEF_TOSATOSPIRVTOSA
27#include "mlir/Conversion/Passes.h.inc"
28
29namespace tosa {
30
33 spirv::Version::V_1_5,
34 {
35 spirv::Capability::VulkanMemoryModel,
36 spirv::Capability::Shader,
37 spirv::Capability::Int8,
38 spirv::Capability::Int16,
39 spirv::Capability::Int64,
40 spirv::Capability::Float16,
41 spirv::Capability::BFloat16TypeKHR,
42 spirv::Capability::Float8EXT,
43 spirv::Capability::TensorsARM,
44 spirv::Capability::GraphARM,
45 spirv::Capability::ReplicatedCompositesEXT,
46 },
47 {
48 spirv::Extension::SPV_ARM_tensors,
49 spirv::Extension::SPV_ARM_graph,
50 spirv::Extension::SPV_KHR_vulkan_memory_model,
51 spirv::Extension::SPV_EXT_replicated_composites,
52 spirv::Extension::SPV_KHR_bfloat16,
53 spirv::Extension::SPV_EXT_float8,
54 },
55 context);
56}
57
59 MLIRContext *context, spirv::ResourceLimitsAttr limits,
60 spirv::ClientAPI clientAPI, spirv::Vendor vendorID,
61 spirv::DeviceType deviceType, uint32_t deviceID) {
62 if (!limits)
63 limits = spirv::getDefaultResourceLimits(context);
64
66 clientAPI, vendorID, deviceType, deviceID);
67}
68
69namespace {
70
71LogicalResult verifyGraphTargetEnv(Operation *op,
72 spirv::TargetEnvAttr targetAttr) {
73 spirv::TargetEnv targetEnv(targetAttr);
74 if (targetEnv.allows(spirv::Capability::GraphARM) &&
75 targetEnv.allows(spirv::Extension::SPV_ARM_graph) &&
76 targetEnv.allows(spirv::Extension::SPV_ARM_tensors)) {
77 return success();
78 }
79
80 return op->emitOpError()
81 << "requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors "
82 "extensions in spirv.target_env";
83}
84
85LogicalResult verifyNoUnsupportedFuncOps(Operation *op) {
86 WalkResult result = op->walk([](Operation *op) -> WalkResult {
87 if (isa<func::CallOp, func::CallIndirectOp>(op)) {
88 op->emitOpError()
89 << "is not supported in TOSA to SPIR-V Graph conversion; inline "
90 "calls before running this pass";
91 return WalkResult::interrupt();
92 }
93 if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
94 if (funcOp->getParentOfType<func::FuncOp>()) {
95 funcOp.emitOpError()
96 << "nesting is not supported in TOSA to SPIR-V Graph conversion";
97 return WalkResult::interrupt();
98 }
99 }
100 return WalkResult::advance();
101 });
102
103 return failure(result.wasInterrupted());
104}
105
106struct TosaToSPIRVTosa final : impl::TosaToSPIRVTosaBase<TosaToSPIRVTosa> {
107 void runOnOperation() override {
108 MLIRContext *context = &getContext();
109 RewritePatternSet patterns(context);
110 Operation *op = getOperation();
111
112 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op);
113 if (!targetAttr) {
114 targetAttr = constructTargetEnvAttrWithCapExtDefaults(context);
115 }
116
117 if (failed(verifyGraphTargetEnv(op, targetAttr)) ||
118 failed(verifyNoUnsupportedFuncOps(op))) {
119 signalPassFailure();
120 return;
121 }
122
123 std::unique_ptr<ConversionTarget> target =
124 SPIRVConversionTarget::get(targetAttr);
125
126 target->addIllegalDialect<tosa::TosaDialect>();
127 target->addIllegalOp<func::CallOp, func::CallIndirectOp>();
128
129 SPIRVTypeConverter typeConverter(targetAttr);
130 typeConverter.addConversion([this](IntegerType integerType) {
131 return this->convertIntegerType(integerType);
132 });
133 typeConverter.addConversion([this](TensorType tensorType) {
134 return this->convertTensorType(tensorType);
135 });
136 typeConverter.addConversion([this](tosa::shapeType shapeType) {
137 return this->convertShapeType(shapeType);
138 });
139
140 populateTosaToSPIRVTosaConversionPatterns(typeConverter, patterns,
141 targetAttr);
142 populateTosaToSPIRVTosaOpsConversionPatterns(typeConverter, patterns);
143
144 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
145
146 if (failed(applyPartialConversion(op, *target, frozenPatterns))) {
147 signalPassFailure();
148 }
149 }
150
151private:
152 IntegerType convertIntegerType(IntegerType integerType) {
153 if (integerType.getWidth() == 48) {
154 return IntegerType::get(&getContext(), 64, integerType.getSignedness());
155 }
156
157 if (integerType.getWidth() == 4) {
158 return IntegerType::get(&getContext(), 8, integerType.getSignedness());
159 }
160
161 return integerType;
162 }
163
164 std::optional<SmallVector<int64_t>> convertShape(ArrayRef<int64_t> shape) {
165 // Scalar ARM tensors are not supported, so convert them to
166 // tensors with shape [1].
167 if (shape.empty())
168 return SmallVector<int64_t>({1});
169
170 if (llvm::is_contained(shape, 0))
171 return std::nullopt;
172
173 bool isPartiallyDynamic =
174 llvm::is_contained(shape, ShapedType::kDynamic) &&
175 llvm::any_of(shape, [](int64_t dim) { return dim > 0; });
176 // Partially shaped ARM tensors are not supported, so convert them to
177 // unshaped tensors.
178 if (isPartiallyDynamic)
179 return SmallVector<int64_t>(shape.size(), ShapedType::kDynamic);
180 return SmallVector<int64_t>(shape);
181 }
182
183 std::optional<spirv::TensorArmType> convertTensorType(TensorType tensorType) {
184 Type elementType = getElementTypeOrSelf(tensorType);
185 if (elementType.isIndex())
186 elementType = IntegerType::get(&getContext(), 32);
187 if (auto integerType = dyn_cast<IntegerType>(elementType))
188 elementType = convertIntegerType(integerType);
189
190 SmallVector<int64_t> shape;
191 if (tensorType.hasRank()) {
192 std::optional<SmallVector<int64_t>> convertedShape =
193 convertShape(tensorType.getShape());
194 if (!convertedShape)
195 return std::nullopt;
196 shape = std::move(*convertedShape);
197 }
198
199 return spirv::TensorArmType::get(shape, elementType);
200 }
201
202 spirv::TensorArmType convertShapeType(tosa::shapeType shapeType) {
203 const int64_t rank = std::max(shapeType.getRank(), 1);
204 return spirv::TensorArmType::get({rank},
205 IntegerType::get(&getContext(), 32));
206 }
207};
208} // namespace
209
210std::unique_ptr<Pass> createTosaToSPIRVTosa() {
211 return std::make_unique<TosaToSPIRVTosa>();
212}
213
214} // namespace tosa
215} // namespace mlir
return success()
b getContext())
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
An attribute that specifies the target version, allowed extensions and capabilities,...
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(MLIRContext *context, spirv::ResourceLimitsAttr limits={}, spirv::ClientAPI clientAPI=spirv::ClientAPI::Unknown, spirv::Vendor vendorID=spirv::Vendor::Unknown, spirv::DeviceType deviceType=spirv::DeviceType::Unknown, uint32_t deviceID=spirv::TargetEnvAttr::kUnknownDeviceID)
spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context)
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)
std::unique_ptr< Pass > createTosaToSPIRVTosa()
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.