18#include "llvm/ADT/STLExtras.h"
20#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
/// Copies attributes from the source func.func op (read through its adaptor)
/// onto the newly created spirv::GraphARMOp.
///
/// \param funcOp   The function being converted; used here only to query the
///                 names of its function-type / arg-attrs / res-attrs
///                 attributes.
/// \param adaptor  Adaptor whose attribute list is iterated.
/// \param graphOp  Destination op that receives the copied attributes.
25void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
26 spirv::GraphARMOp graphOp) {
27 for (NamedAttribute attr : adaptor.getAttributes()) {
28 StringRef attrName = attr.getName().getValue();
// NOTE(review): the condition that consumes the name list below is not
// visible here. The listed names (function type, arg attrs, result attrs,
// graph entry point) look like attributes that the graph op already carries
// or rebuilds itself, so they are presumably skipped rather than copied.
// Confirm against the full source.
30 funcOp.getFunctionTypeAttrName().getValue(),
31 funcOp.getArgAttrsAttrName().getValue(),
32 funcOp.getResAttrsAttrName().getValue(),
33 graphOp.getEntryPointAttrName().getValue()},
// Any attribute that reaches this point is transferred verbatim (same name,
// same value) onto the graph op.
37 graphOp->setAttr(attr.getName(), attr.getValue());
/// Conversion pattern that rewrites a func.func into a spirv.ARM.Graph
/// (spirv::GraphARMOp) created inside a freshly built spirv.module.
41struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
  /// \param typeConverter  SPIR-V type converter forwarded to the base pattern.
  /// \param context        MLIR context for the pattern.
  /// \param targetAttr     SPIR-V target environment, stored in `targetAttr`.
42 FuncGraphConvert(SPIRVTypeConverter &typeConverter, MLIRContext *context,
43 spirv::TargetEnvAttr targetAttr)
44 : OpConversionPattern<func::FuncOp>(typeConverter, context),
45 targetAttr(targetAttr) {}
  // Target environment captured at construction time. Its uses are not
  // visible in this section.
48 spirv::TargetEnvAttr targetAttr;
  /// Normalizes the spirv::InterfaceVarABIAttr attached to one argument or
  /// result of `graphOp`.
  ///
  /// \param graphOp               Graph op whose arg/result attributes are
  ///                              inspected.
  /// \param context               Context used when building a fallback
  ///                              attribute.
  /// \param index                 Argument or result index to normalize.
  /// \param defaultDescriptorSet  Descriptor set for the fallback attribute.
  /// \param defaultBinding        Binding for the fallback attribute.
  ///
  /// `isResult` is used below to choose result versus argument attributes.
  /// NOTE(review): its declaration is not visible in the signature shown here.
53 void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
54 MLIRContext *context,
unsigned index,
56 uint32_t defaultDescriptorSet,
57 uint32_t defaultBinding)
const {
    // First lookup: the InterfaceVarABIAttr already attached to the result
    // (when isResult) or to the argument at `index`. The attribute name
    // passed to the lookup is not visible here.
59 isResult ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
61 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
    // Second lookup of the same kind. It presumably uses an alternate
    // attribute name as a fallback; the name and the guarding condition are
    // not visible here, so confirm against the full source.
66 ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
68 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
    // Fallback ABI attribute built from the default descriptor set and binding,
    // with no storage class (std::nullopt). This is presumably
    // spirv::InterfaceVarABIAttr::get; the callee line is not visible here.
74 defaultDescriptorSet, defaultBinding, std::nullopt, context);
  /// Normalizes the interface-variable ABI attribute of every graph input and
  /// output.
  ///
  /// All interfaces use descriptor set 0. Inputs get bindings
  /// [0, inputs) in argument order. Outputs get bindings
  /// [inputs, inputs + outputs), so an output's default binding never
  /// collides with an input's.
  ///
  /// \param graphOp  Graph op to update.
  /// \param context  MLIR context forwarded to the per-interface helper.
  /// \param inputs   Number of graph arguments.
  /// \param outputs  Number of graph results.
87 void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
88 MLIRContext *context,
unsigned inputs,
89 unsigned outputs)
const {
90 constexpr uint32_t defaultDescriptorSet = 0;
    // Arguments: isResult = false, default binding equal to the argument index.
91 for (
auto argIndex : llvm::seq<unsigned>(0, inputs)) {
92 normalizeInterfaceVarABIAttr(graphOp, context, argIndex,
false,
93 defaultDescriptorSet, argIndex);
    // Results: isResult = true, default binding offset past all arguments.
95 for (
auto resIndex : llvm::seq<unsigned>(0, outputs)) {
96 normalizeInterfaceVarABIAttr(graphOp, context, resIndex,
true,
97 defaultDescriptorSet, resIndex + inputs);
  /// Rewrites `funcOp` into a spirv::GraphARMOp.
  ///
  /// Steps:
  ///  1. Create a spirv.module named "_spirv_tosa_<sym_name>" with the
  ///     Logical addressing model and the Vulkan memory model.
  ///  2. Convert the argument and result types with the type converter.
  ///  3. Build a graph op with the converted GraphType and the original
  ///     arg/result attribute arrays.
  ///  4. Copy the remaining function attributes onto the graph op.
  ///  5. Move the function body into the graph and convert its block
  ///     signatures.
  ///  6. Normalize the interface ABI attributes.
  ///  7. Erase the original function.
  ///
  /// NOTE(review): type-conversion failures are reported with
  /// funcOp.emitError, which issues a user-visible diagnostic. If these are
  /// meant to be recoverable match failures, rewriter.notifyMatchFailure
  /// would be the conventional choice in a conversion pattern.
103 matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter)
const override {
105 MLIRContext *context = rewriter.getContext();
107 StringRef name = adaptor.getSymName();
    // Each converted function gets its own SPIR-V module, named after the
    // function symbol with a "_spirv_tosa_" prefix.
108 auto spvModule = spirv::ModuleOp::create(
109 rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
110 spirv::MemoryModel::Vulkan, std::nullopt,
111 (
"_spirv_tosa_" + name).str());
    // Subsequent ops (the graph) are created inside the new module, presumably
    // at the start of its body.
114 rewriter.setInsertionPoint(spvModule.getBody(), spvModule.begin());
116 FunctionType ftype = adaptor.getFunctionType();
117 ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
118 ArrayAttr resAttrs = adaptor.getResAttrsAttr();
    // Record the converted argument types. The same signatureConverter is
    // later used to convert the inlined region's entry block.
120 TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
    // NOTE(review): this uses `typeConverter->`, while the result conversion
    // below uses `getTypeConverter()->`. Consider using one form for
    // consistency.
121 if (failed(typeConverter->convertSignatureArgs(ftype.getInputs(),
122 signatureConverter))) {
123 return funcOp.emitError(
"failed to convert function argument types");
127 SmallVector<Type, 2> newResultTypes;
128 if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
130 return funcOp.emitError(
"failed to convert function result types");
    // Graph type built from the converted input and result types.
136 auto graphTy = GraphType::get(
137 context, signatureConverter.getConvertedTypes(), newResultTypes);
    // `entryPointAttr` and the `graphOp` binding are declared on lines not
    // visible here.
139 spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
140 resAttrs, entryPointAttr, name);
141 copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
    // Move the function body into the graph, then apply the argument type
    // conversion to the moved region.
143 rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
145 if (failed(rewriter.convertRegionTypes(
146 &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
147 return funcOp.emitError(
"failed to convert function regions");
    // Counts come from the original function type. Conversion is assumed to
    // be 1:1, since signatureConverter was sized with ftype.getNumInputs().
150 normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
151 ftype.getNumResults());
153 rewriter.eraseOp(funcOp);
/// Conversion pattern that replaces func.return with
/// spirv::GraphOutputsARMOp. The graph outputs are the return's operands after
/// type conversion (taken from the adaptor).
159struct ReturnGraphOutputConvert final : OpConversionPattern<func::ReturnOp> {
163 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter)
const override {
    // Use adaptor operands, not returnOp's operands, so that the graph
    // outputs carry the converted values.
165 rewriter.replaceOpWithNewOp<spirv::GraphOutputsARMOp>(
166 returnOp, adaptor.getOperands());
// Body of the pattern-population entry point. The function header is not
// visible here; it is presumably populateTosaToSPIRVTosaConversionPatterns.
// It registers:
//  - FuncGraphConvert: func.func -> spirv.ARM.Graph inside a new spirv.module.
//    One constructor argument is on a line not visible here; presumably the
//    target environment attribute.
//  - ReturnGraphOutputConvert: func.return -> spirv.ARM.GraphOutputs.
176 patterns.
add<FuncGraphConvert>(typeConverter, patterns.
getContext(),
178 patterns.
add<ReturnGraphOutputConvert>(typeConverter, patterns.
getContext());
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities,...
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)
constexpr llvm::StringLiteral graphARMInterfaceVarABIAttrName