22#include "llvm/ADT/StringMap.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/CommandLine.h"
// Command-line option parser specialization for a (domain string, int32 opcode)
// pair written as "<domain>:<opcode>".
// NOTE(review): this listing omits several original lines (the numbers fused to
// the start of each line skip), e.g. the declaration of `opcode`, the tail of
// the second error call and the return after `value` is assigned.
32class parser<std::pair<std::string, int32_t>>
// Parses `arg` into `value`. Malformed input is reported through
// option.error(), whose result is returned to the caller.
37 bool parse(Option &option, StringRef argName, StringRef arg,
38 std::pair<std::string, int32_t> &value) {
// Split around ':' with rsplit — presumably at the last ':' so the domain part
// may itself contain ':'; confirm against the StringRef documentation.
39 auto [domain, opcodeString] = arg.rsplit(
":");
// Both halves must be non-empty.
40 if (domain.empty() || opcodeString.empty())
41 return option.error(
"expected <domain>:<opcode>", argName);
// A truthy result from getAsInteger is treated as a parse failure.
44 if (opcodeString.getAsInteger(0, opcode))
45 return option.error(
"invalid opcode in custom op domain mapping",
// Store the parsed pair; the domain is copied into an owning std::string.
48 value = {domain.str(), opcode};
// Placeholder name for the option's value.
52 StringRef
getValueName()
const override {
return "domain:opcode"; }
// Writes `value` back out in the same "domain:opcode" form that parse() accepts.
54 static void print(raw_ostream &os,
55 const std::pair<std::string, int32_t> &value) {
56 os << value.first <<
":" << value.second;
62#define GEN_PASS_DEF_TOSATOSPIRVTOSA
63#include "mlir/Conversion/Passes.h.inc"
// A SPIR-V version followed by a list of capabilities and a list of extensions.
// NOTE(review): the enclosing declaration is not visible in this listing;
// presumably this is the default (version, capabilities, extensions) triple
// built by getDefaultVerCapExtAttr — confirm against the full file.
69 spirv::Version::V_1_5,
// Capabilities.
71 spirv::Capability::VulkanMemoryModel,
72 spirv::Capability::Shader,
73 spirv::Capability::Int8,
74 spirv::Capability::Int16,
75 spirv::Capability::Int64,
76 spirv::Capability::Float16,
77 spirv::Capability::BFloat16TypeKHR,
78 spirv::Capability::Float8EXT,
79 spirv::Capability::TensorsARM,
80 spirv::Capability::GraphARM,
81 spirv::Capability::ReplicatedCompositesEXT,
// Extensions.
84 spirv::Extension::SPV_ARM_tensors,
85 spirv::Extension::SPV_ARM_graph,
86 spirv::Extension::SPV_KHR_vulkan_memory_model,
87 spirv::Extension::SPV_EXT_replicated_composites,
88 spirv::Extension::SPV_KHR_bfloat16,
89 spirv::Extension::SPV_EXT_float8,
90 spirv::Extension::SPV_KHR_non_semantic_info,
// Parameter list of a function whose name line is not visible in this listing.
// NOTE(review): the parameters match the constructTargetEnvAttrWithCapExtDefaults
// signature listed further down this page — confirm against the full file.
96 MLIRContext *context, spirv::ResourceLimitsAttr limits,
97 spirv::ClientAPI clientAPI, spirv::Vendor vendorID,
98 spirv::DeviceType deviceType, uint32_t deviceID) {
// The client API, vendor, device type and device ID are passed through
// unchanged; the callee and its leading arguments are on lines not shown.
103 clientAPI, vendorID, deviceType, deviceID);
// Checks that the target environment permits SPIR-V graph lowering.
// NOTE(review): the remaining parameters, the declaration of `targetEnv`, the
// body of the `if` and the call that emits the diagnostic are on lines not
// shown in this listing.
108LogicalResult verifyGraphTargetEnv(
Operation *op,
// All three must be allowed: the GraphARM capability plus the SPV_ARM_graph
// and SPV_ARM_tensors extensions.
111 if (targetEnv.allows(spirv::Capability::GraphARM) &&
112 targetEnv.allows(spirv::Extension::SPV_ARM_graph) &&
113 targetEnv.allows(spirv::Extension::SPV_ARM_tensors)) {
// Diagnostic text — presumably emitted when the check above fails; confirm.
118 <<
"requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors "
119 "extensions in spirv.target_env";
// Walks the operations under `op` looking for func constructs this conversion
// does not handle, and returns failure if the walk was interrupted.
// NOTE(review): the lines that emit the diagnostics and return
// interrupt()/advance() from the callback are not shown in this listing.
122LogicalResult verifyNoUnsupportedFuncOps(Operation *op) {
123 WalkResult
result = op->walk([](Operation *op) -> WalkResult {
// Direct and indirect calls are rejected; per the message, callers are
// expected to inline calls before running the pass.
124 if (isa<func::CallOp, func::CallIndirectOp>(op)) {
126 <<
"is not supported in TOSA to SPIR-V Graph conversion; inline "
127 "calls before running this pass";
// A func.func that has another func.func as an ancestor is rejected as well.
130 if (
auto funcOp = dyn_cast<func::FuncOp>(op)) {
131 if (funcOp->getParentOfType<func::FuncOp>()) {
133 <<
"nesting is not supported in TOSA to SPIR-V Graph conversion";
// An interrupted walk maps to failure, a completed walk to success.
140 return failure(
result.wasInterrupted());
// Walks the operations under `op` and checks a graph-constant-id attribute,
// returning failure if the walk was interrupted.
// NOTE(review): the attribute lookup, the bodies of the three `if`s and the
// start of the diagnostic are on lines not shown in this listing.
143LogicalResult verifyGraphConstantIdAttrs(Operation *op) {
144 WalkResult
result = op->walk([](Operation *op) -> WalkResult {
// Tests for tosa.const / tosa.const_shape — presumably every other op is
// skipped; confirm.
145 if (!isa<tosa::ConstOp, tosa::ConstShapeOp>(op))
148 auto graphConstantId =
// Null attribute — presumably the attribute is optional and the op is
// accepted; confirm.
150 if (!graphConstantId)
// The attribute's type is checked for being a signless 32-bit integer.
153 if (graphConstantId.getType().isSignlessInteger(32))
// Tail of the diagnostic — presumably emitted for any other attribute type.
157 <<
"` to be a signless i32 integer attribute";
// An interrupted walk maps to failure, a completed walk to success.
161 return failure(
result.wasInterrupted());
// Pass that converts TOSA operations using SPIR-V type conversion and a
// partial dialect conversion.
// NOTE(review): this listing omits many original lines (the numbers fused to
// the start of each line skip), so some declarations, early returns and
// closing braces are not visible here.
164struct TosaToSPIRVTosa final : impl::TosaToSPIRVTosaBase<TosaToSPIRVTosa> {
// Pass entry point: verifies preconditions, configures the conversion target
// and type converter, populates patterns and applies a partial conversion.
165 void runOnOperation()
override {
167 RewritePatternSet patterns(context);
168 Operation *op = getOperation();
// Collect the (domain, opcode) pairs from the customOpDomainToOpcode option
// into a string-keyed map.
169 llvm::StringMap<int32_t> domainToOpcode;
170 for (
const auto &[domain, opcode] : customOpDomainToOpcode) {
173 domainToOpcode[domain] = opcode;
// All three precondition checks must pass; the failure handling is on lines
// not shown.
181 if (
failed(verifyGraphTargetEnv(op, targetAttr)) ||
182 failed(verifyNoUnsupportedFuncOps(op)) ||
183 failed(verifyGraphConstantIdAttrs(op))) {
// The whole TOSA dialect and func call ops are marked illegal, so any left
// over after pattern application make the conversion fail.
188 std::unique_ptr<ConversionTarget>
target =
191 target->addIllegalDialect<tosa::TosaDialect>();
192 target->addIllegalOp<func::CallOp, func::CallIndirectOp>();
// Register this pass's own conversions for integer, tensor and TOSA shape
// types on top of the SPIR-V type converter.
194 SPIRVTypeConverter typeConverter(targetAttr);
195 typeConverter.addConversion([
this](IntegerType integerType) {
196 return this->convertIntegerType(integerType);
198 typeConverter.addConversion([
this](TensorType tensorType) {
199 return this->convertTensorType(tensorType);
201 typeConverter.addConversion([
this](tosa::shapeType shapeType) {
202 return this->convertShapeType(shapeType);
// The call taking the moved-in map is made only when at least one custom
// domain mapping was supplied — presumably the custom-op pattern population;
// the callee name is on a line not shown.
209 if (!domainToOpcode.empty())
211 typeConverter, patterns, std::move(domainToOpcode));
213 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
// Partial conversion: only ops marked illegal above are required to convert.
215 if (
failed(applyPartialConversion(op, *
target, frozenPatterns))) {
// Widens integer widths: 48-bit becomes 64-bit and 4-bit becomes 8-bit, keeping
// the signedness. Handling of other widths is on lines not shown.
221 IntegerType convertIntegerType(IntegerType integerType) {
222 if (integerType.getWidth() == 48) {
223 return IntegerType::get(&
getContext(), 64, integerType.getSignedness());
226 if (integerType.getWidth() == 4) {
227 return IntegerType::get(&
getContext(), 8, integerType.getSignedness());
// Maps a tensor shape to the shape used for the converted type; the optional
// return lets the conversion be rejected.
233 std::optional<SmallVector<int64_t>> convertShape(ArrayRef<int64_t> shape) {
// The guarding condition is on a line not shown — presumably an empty (rank-0)
// shape becomes {1}; confirm.
237 return SmallVector<int64_t>({1});
// Shapes containing a zero-sized dimension get special handling (result on a
// line not shown — presumably rejected).
239 if (llvm::is_contained(shape, 0))
// "Partially dynamic": at least one dynamic dimension together with at least
// one positive (static) dimension.
242 bool isPartiallyDynamic =
243 llvm::is_contained(shape, ShapedType::kDynamic) &&
244 llvm::any_of(shape, [](int64_t dim) {
return dim > 0; });
// Partially dynamic shapes become fully dynamic with the same rank; all other
// shapes are copied unchanged.
247 if (isPartiallyDynamic)
248 return SmallVector<int64_t>(shape.size(), ShapedType::kDynamic);
249 return SmallVector<int64_t>(shape);
// Converts a tensor type to a spirv::TensorArmType, adjusting the element type
// and shape first.
252 std::optional<spirv::TensorArmType> convertTensorType(TensorType tensorType) {
// An index element type is replaced by a 32-bit integer type.
254 if (elementType.isIndex())
255 elementType = IntegerType::get(&
getContext(), 32);
// Integer element types go through the same widening as convertIntegerType.
256 if (
auto integerType = dyn_cast<IntegerType>(elementType))
257 elementType = convertIntegerType(integerType);
// Only ranked tensors have their shape converted; for unranked tensors `shape`
// stays empty as far as the visible lines show.
259 SmallVector<int64_t> shape;
260 if (tensorType.hasRank()) {
261 std::optional<SmallVector<int64_t>> convertedShape =
262 convertShape(tensorType.getShape());
265 shape = std::move(*convertedShape);
// Converts a TOSA shape type to a spirv::TensorArmType; the rank is clamped to
// at least 1.
271 spirv::TensorArmType convertShapeType(tosa::shapeType shapeType) {
272 const int64_t rank = std::max(shapeType.getRank(), 1);
// Returns a newly allocated TosaToSPIRVTosa pass instance.
// NOTE(review): the enclosing function header is not visible in this listing;
// presumably createTosaToSPIRVTosa(), which is listed further down this page.
280 return std::make_unique<TosaToSPIRVTosa>();
bool parse(Option &option, StringRef argName, StringRef arg, std::pair< std::string, int32_t > &value)
StringRef getValueName() const override
static void print(raw_ostream &os, const std::pair< std::string, int32_t > &value)
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
static WalkResult advance()
static WalkResult interrupt()
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabilities/extensions.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the Vulkan spec.
constexpr llvm::StringLiteral graphARMGraphConstantIdAttrName
spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(MLIRContext *context, spirv::ResourceLimitsAttr limits={}, spirv::ClientAPI clientAPI=spirv::ClientAPI::Unknown, spirv::Vendor vendorID=spirv::Vendor::Unknown, spirv::DeviceType deviceType=spirv::DeviceType::Unknown, uint32_t deviceID=spirv::TargetEnvAttr::kUnknownDeviceID)
spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context)
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)
std::unique_ptr< Pass > createTosaToSPIRVTosa()
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
void populateTosaToSPIRVTosaCustomConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, llvm::StringMap< int32_t > domainToOpcode)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.