MLIR 23.0.0git
TosaToSPIRVTosaPass.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosaPass.cpp - Lower TOSA to SPIR-V Graph/TOSA ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers TOSA IR to the SPIR-V Graph/TOSA representation.
10//
11//===----------------------------------------------------------------------===//
12
14
22#include "llvm/ADT/StringMap.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/CommandLine.h"
25
26#include <algorithm>
27#include <string>
28#include <utility>
29
30namespace llvm::cl {
31template <>
32class parser<std::pair<std::string, int32_t>>
33 : public basic_parser<std::pair<std::string, int32_t>> {
34public:
35 parser(Option &option) : basic_parser(option) {}
36
37 bool parse(Option &option, StringRef argName, StringRef arg,
38 std::pair<std::string, int32_t> &value) {
39 auto [domain, opcodeString] = arg.rsplit(":");
40 if (domain.empty() || opcodeString.empty())
41 return option.error("expected <domain>:<opcode>", argName);
42
43 int32_t opcode;
44 if (opcodeString.getAsInteger(0, opcode))
45 return option.error("invalid opcode in custom op domain mapping",
46 argName);
47
48 value = {domain.str(), opcode};
49 return false;
50 }
51
52 StringRef getValueName() const override { return "domain:opcode"; }
53
54 static void print(raw_ostream &os,
55 const std::pair<std::string, int32_t> &value) {
56 os << value.first << ":" << value.second;
57 }
58};
59} // namespace llvm::cl
60
61namespace mlir {
62#define GEN_PASS_DEF_TOSATOSPIRVTOSA
63#include "mlir/Conversion/Passes.h.inc"
64
65namespace tosa {
66
69 spirv::Version::V_1_5,
70 {
71 spirv::Capability::VulkanMemoryModel,
72 spirv::Capability::Shader,
73 spirv::Capability::Int8,
74 spirv::Capability::Int16,
75 spirv::Capability::Int64,
76 spirv::Capability::Float16,
77 spirv::Capability::BFloat16TypeKHR,
78 spirv::Capability::Float8EXT,
79 spirv::Capability::TensorsARM,
80 spirv::Capability::GraphARM,
81 spirv::Capability::ReplicatedCompositesEXT,
82 },
83 {
84 spirv::Extension::SPV_ARM_tensors,
85 spirv::Extension::SPV_ARM_graph,
86 spirv::Extension::SPV_KHR_vulkan_memory_model,
87 spirv::Extension::SPV_EXT_replicated_composites,
88 spirv::Extension::SPV_KHR_bfloat16,
89 spirv::Extension::SPV_EXT_float8,
90 spirv::Extension::SPV_KHR_non_semantic_info,
91 },
92 context);
93}
94
96 MLIRContext *context, spirv::ResourceLimitsAttr limits,
97 spirv::ClientAPI clientAPI, spirv::Vendor vendorID,
98 spirv::DeviceType deviceType, uint32_t deviceID) {
99 if (!limits)
100 limits = spirv::getDefaultResourceLimits(context);
101
103 clientAPI, vendorID, deviceType, deviceID);
104}
105
106namespace {
107
108LogicalResult verifyGraphTargetEnv(Operation *op,
109 spirv::TargetEnvAttr targetAttr) {
110 spirv::TargetEnv targetEnv(targetAttr);
111 if (targetEnv.allows(spirv::Capability::GraphARM) &&
112 targetEnv.allows(spirv::Extension::SPV_ARM_graph) &&
113 targetEnv.allows(spirv::Extension::SPV_ARM_tensors)) {
114 return success();
115 }
116
117 return op->emitOpError()
118 << "requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors "
119 "extensions in spirv.target_env";
120}
121
122LogicalResult verifyNoUnsupportedFuncOps(Operation *op) {
123 WalkResult result = op->walk([](Operation *op) -> WalkResult {
124 if (isa<func::CallOp, func::CallIndirectOp>(op)) {
125 op->emitOpError()
126 << "is not supported in TOSA to SPIR-V Graph conversion; inline "
127 "calls before running this pass";
128 return WalkResult::interrupt();
129 }
130 if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
131 if (funcOp->getParentOfType<func::FuncOp>()) {
132 funcOp.emitOpError()
133 << "nesting is not supported in TOSA to SPIR-V Graph conversion";
134 return WalkResult::interrupt();
135 }
136 }
137 return WalkResult::advance();
138 });
139
140 return failure(result.wasInterrupted());
141}
142
143LogicalResult verifyGraphConstantIdAttrs(Operation *op) {
144 WalkResult result = op->walk([](Operation *op) -> WalkResult {
145 if (!isa<tosa::ConstOp, tosa::ConstShapeOp>(op))
146 return WalkResult::advance();
147
148 auto graphConstantId =
149 op->getAttrOfType<IntegerAttr>(graphARMGraphConstantIdAttrName);
150 if (!graphConstantId)
151 return WalkResult::advance();
152
153 if (graphConstantId.getType().isSignlessInteger(32))
154 return WalkResult::advance();
155
156 op->emitOpError() << "requires `" << graphARMGraphConstantIdAttrName
157 << "` to be a signless i32 integer attribute";
158 return WalkResult::interrupt();
159 });
160
161 return failure(result.wasInterrupted());
162}
163
164struct TosaToSPIRVTosa final : impl::TosaToSPIRVTosaBase<TosaToSPIRVTosa> {
165 void runOnOperation() override {
166 MLIRContext *context = &getContext();
167 RewritePatternSet patterns(context);
168 Operation *op = getOperation();
169 llvm::StringMap<int32_t> domainToOpcode;
170 for (const auto &[domain, opcode] : customOpDomainToOpcode) {
171 // Allow later entries to override earlier ones, matching command-line
172 // option precedence when the same key is specified multiple times.
173 domainToOpcode[domain] = opcode;
174 }
175
176 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op);
177 if (!targetAttr) {
178 targetAttr = constructTargetEnvAttrWithCapExtDefaults(context);
179 }
180
181 if (failed(verifyGraphTargetEnv(op, targetAttr)) ||
182 failed(verifyNoUnsupportedFuncOps(op)) ||
183 failed(verifyGraphConstantIdAttrs(op))) {
184 signalPassFailure();
185 return;
186 }
187
188 std::unique_ptr<ConversionTarget> target =
189 SPIRVConversionTarget::get(targetAttr);
190
191 target->addIllegalDialect<tosa::TosaDialect>();
192 target->addIllegalOp<func::CallOp, func::CallIndirectOp>();
193
194 SPIRVTypeConverter typeConverter(targetAttr);
195 typeConverter.addConversion([this](IntegerType integerType) {
196 return this->convertIntegerType(integerType);
197 });
198 typeConverter.addConversion([this](TensorType tensorType) {
199 return this->convertTensorType(tensorType);
200 });
201 typeConverter.addConversion([this](tosa::shapeType shapeType) {
202 return this->convertShapeType(shapeType);
203 });
204
205 populateTosaToSPIRVTosaConversionPatterns(typeConverter, patterns,
206 targetAttr);
207 populateTosaToSPIRVTosaOpsConversionPatterns(typeConverter, patterns);
208
209 if (!domainToOpcode.empty())
211 typeConverter, patterns, std::move(domainToOpcode));
212
213 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
214
215 if (failed(applyPartialConversion(op, *target, frozenPatterns))) {
216 signalPassFailure();
217 }
218 }
219
220private:
221 IntegerType convertIntegerType(IntegerType integerType) {
222 if (integerType.getWidth() == 48) {
223 return IntegerType::get(&getContext(), 64, integerType.getSignedness());
224 }
225
226 if (integerType.getWidth() == 4) {
227 return IntegerType::get(&getContext(), 8, integerType.getSignedness());
228 }
229
230 return integerType;
231 }
232
233 std::optional<SmallVector<int64_t>> convertShape(ArrayRef<int64_t> shape) {
234 // Scalar ARM tensors are not supported, so convert them to
235 // tensors with shape [1].
236 if (shape.empty())
237 return SmallVector<int64_t>({1});
238
239 if (llvm::is_contained(shape, 0))
240 return std::nullopt;
241
242 bool isPartiallyDynamic =
243 llvm::is_contained(shape, ShapedType::kDynamic) &&
244 llvm::any_of(shape, [](int64_t dim) { return dim > 0; });
245 // Partially shaped ARM tensors are not supported, so convert them to
246 // unshaped tensors.
247 if (isPartiallyDynamic)
248 return SmallVector<int64_t>(shape.size(), ShapedType::kDynamic);
249 return SmallVector<int64_t>(shape);
250 }
251
252 std::optional<spirv::TensorArmType> convertTensorType(TensorType tensorType) {
253 Type elementType = getElementTypeOrSelf(tensorType);
254 if (elementType.isIndex())
255 elementType = IntegerType::get(&getContext(), 32);
256 if (auto integerType = dyn_cast<IntegerType>(elementType))
257 elementType = convertIntegerType(integerType);
258
259 SmallVector<int64_t> shape;
260 if (tensorType.hasRank()) {
261 std::optional<SmallVector<int64_t>> convertedShape =
262 convertShape(tensorType.getShape());
263 if (!convertedShape)
264 return std::nullopt;
265 shape = std::move(*convertedShape);
266 }
267
268 return spirv::TensorArmType::get(shape, elementType);
269 }
270
271 spirv::TensorArmType convertShapeType(tosa::shapeType shapeType) {
272 const int64_t rank = std::max(shapeType.getRank(), 1);
273 return spirv::TensorArmType::get({rank},
274 IntegerType::get(&getContext(), 32));
275 }
276};
277} // namespace
278
279std::unique_ptr<Pass> createTosaToSPIRVTosa() {
280 return std::make_unique<TosaToSPIRVTosa>();
281}
282
283} // namespace tosa
284} // namespace mlir
return success()
b getContext())
bool parse(Option &option, StringRef argName, StringRef arg, std::pair< std::string, int32_t > &value)
static void print(raw_ostream &os, const std::pair< std::string, int32_t > &value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabilities/extensions.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
constexpr llvm::StringLiteral graphARMGraphConstantIdAttrName
spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(MLIRContext *context, spirv::ResourceLimitsAttr limits={}, spirv::ClientAPI clientAPI=spirv::ClientAPI::Unknown, spirv::Vendor vendorID=spirv::Vendor::Unknown, spirv::DeviceType deviceType=spirv::DeviceType::Unknown, uint32_t deviceID=spirv::TargetEnvAttr::kUnknownDeviceID)
spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context)
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)
std::unique_ptr< Pass > createTosaToSPIRVTosa()
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
void populateTosaToSPIRVTosaCustomConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, llvm::StringMap< int32_t > domainToOpcode)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.