18#include "llvm/ADT/STLExtras.h"
20#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
// Argument/result attribute name under which this file looks up a
// spirv::InterfaceVarABIAttr on a spirv::GraphARMOp. It is read with
// get{Arg,Result}AttrOfType<spirv::InterfaceVarABIAttr> and later removed with
// remove{Arg,Result}Attr in FuncGraphConvert::normalizeInterfaceVarABIAttr.
29constexpr StringLiteral graphARMInterfaceVarABIAttrName =
30 "grapharm.interface_var_abi";
// Copies the attributes of `funcOp` (read through `adaptor`) onto `graphOp`
// via graphOp->setAttr.
//
// The names listed in the loop body appear to form an exclusion list:
//  - the func::FuncOp function-type, arg-attrs and res-attrs attribute names;
//  - the spirv::GraphARMOp entry-point attribute name.
// That information is instead passed directly to spirv::GraphARMOp::create by
// FuncGraphConvert::matchAndRewrite.
//
// NOTE(review): this listing is a lossy extraction. The embedded source line
// numbers skip 36 and 41-43, so the start of the exclusion check and the lines
// that close it are not shown. The complete excluded set and the exact control
// flow before setAttr cannot be confirmed from this view.
32void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
33 spirv::GraphARMOp graphOp) {
34 for (NamedAttribute attr : adaptor.getAttributes()) {
35 StringRef attrName = attr.getName().getValue();
37 funcOp.getFunctionTypeAttrName().getValue(),
38 funcOp.getArgAttrsAttrName().getValue(),
39 funcOp.getResAttrsAttrName().getValue(),
40 graphOp.getEntryPointAttrName().getValue()},
44 graphOp->setAttr(attr.getName(), attr.getValue());
// Conversion pattern that turns a func::FuncOp into a spirv::GraphARMOp.
// The visible steps are:
//  1. create a new spirv::ModuleOp;
//  2. build the graph op inside it from the type-converted signature;
//  3. move the function body into the graph;
//  4. normalize the per-argument/result interface ABI attributes;
//  5. erase the original function.
// NOTE(review): this listing is a lossy extraction. The embedded source line
// numbers skip (e.g. 52 -> 55 -> 60, 69 -> 73 -> 75 -> 81), so several lines
// of this struct, including its closing lines, are not visible here.
48struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
49 FuncGraphConvert(SPIRVTypeConverter &typeConverter, MLIRContext *context,
50 spirv::TargetEnvAttr targetAttr)
51 : OpConversionPattern<func::FuncOp>(typeConverter, context),
52 targetAttr(targetAttr) {}
// Target environment supplied to the constructor.
// NOTE(review): no visible line reads this member; it may be used on one of
// the missing lines. Confirm whether it is still needed.
55 spirv::TargetEnvAttr targetAttr;
// Normalizes the spirv::InterfaceVarABIAttr of one graph argument, or of one
// result when `isResult` is true, at position `index`. Visible behaviour:
//  - looks up the attribute stored under graphARMInterfaceVarABIAttrName;
//  - appears to build a fresh attribute from `defaultDescriptorSet` and
//    `defaultBinding` with no storage class (std::nullopt);
//  - removes the graphARMInterfaceVarABIAttrName entry from the
//    argument/result.
// NOTE(review): several lines of this method are not visible here:
//  - the declaration of the `bool isResult` parameter;
//  - the local that receives the first lookup;
//  - the lines that choose between the existing and the default attribute
//    (source lines 70-72, 76-80, 82-86, 88-89).
// TODO confirm the selection logic and which attribute name the normalized
// value is written under.
60 void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
61 MLIRContext *context,
unsigned index,
63 uint32_t defaultDescriptorSet,
64 uint32_t defaultBinding)
const {
66 isResult ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
67 index, graphARMInterfaceVarABIAttrName)
68 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
69 index, graphARMInterfaceVarABIAttrName);
73 ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
75 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
81 defaultDescriptorSet, defaultBinding, std::nullopt, context);
// Drop the graphARMInterfaceVarABIAttrName entry from the result (this call)
// or from the argument (the removeArgAttr call that follows).
87 graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
90 graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
// Applies normalizeInterfaceVarABIAttr to every graph argument and result.
// Defaults passed down:
//  - descriptor set: 0 for every interface;
//  - argument i: binding i;
//  - result j: binding `inputs + j`.
// The default bindings of arguments and results therefore form one
// contiguous, non-overlapping range [0, inputs + outputs).
94 void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
95 MLIRContext *context,
unsigned inputs,
96 unsigned outputs)
const {
97 constexpr uint32_t defaultDescriptorSet = 0;
98 for (
auto argIndex : llvm::seq<unsigned>(0, inputs)) {
99 normalizeInterfaceVarABIAttr(graphOp, context, argIndex,
false,
100 defaultDescriptorSet, argIndex);
// Result bindings start right after the argument bindings.
102 for (
auto resIndex : llvm::seq<unsigned>(0, outputs)) {
103 normalizeInterfaceVarABIAttr(graphOp, context, resIndex,
true,
104 defaultDescriptorSet, resIndex + inputs);
// Rewrites `funcOp` into a spirv::GraphARMOp inside a new spirv::ModuleOp
// named "_spirv_tosa_" + the function's symbol name. The module uses Logical
// addressing and the Vulkan memory model.
// Argument and result types go through the type converter. If a conversion
// step fails, the diagnostic from funcOp.emitError is returned.
// NOTE(review): the return-type line (source line 109) and the final return
// statement are not shown in this listing.
110 matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter)
const override {
112 MLIRContext *context = rewriter.getContext();
114 StringRef name = adaptor.getSymName();
// Create the enclosing SPIR-V module. The graph op is then built at the start
// of the module body.
115 auto spvModule = spirv::ModuleOp::create(
116 rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
117 spirv::MemoryModel::Vulkan, std::nullopt,
118 (
"_spirv_tosa_" + name).str());
121 rewriter.setInsertionPoint(spvModule.getBody(), spvModule.begin());
123 FunctionType ftype = adaptor.getFunctionType();
124 ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
125 ArrayAttr resAttrs = adaptor.getResAttrsAttr();
// Convert the argument types. The mapping is recorded in signatureConverter,
// which convertRegionTypes reuses below to retype the body's block arguments.
127 TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
128 if (failed(typeConverter->convertSignatureArgs(ftype.getInputs(),
129 signatureConverter))) {
130 return funcOp.emitError(
"failed to convert function argument types");
// Convert the result types.
134 SmallVector<Type, 2> newResultTypes;
135 if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
137 return funcOp.emitError(
"failed to convert function result types");
// Build the graph type from the converted types and create the graph op.
// The original arg/res attributes, entry-point flag and symbol name are
// passed to create.
// NOTE(review): the declarations of `graphOp` (source line 145) and
// `entryPointAttr` are not visible. entryPointAttr is presumably a BoolAttr;
// confirm.
143 auto graphTy = GraphType::get(
144 context, signatureConverter.getConvertedTypes(), newResultTypes);
146 spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
147 resAttrs, entryPointAttr, name);
// Carry over the remaining func attributes (see copyFuncAttrsToGraph).
148 copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
// Move the function body into the graph, then retype its block arguments
// using the recorded signature conversion.
150 rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
152 if (failed(rewriter.convertRegionTypes(
153 &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
154 return funcOp.emitError(
"failed to convert function regions");
// The argument/result counts come from the pre-conversion `ftype`.
// NOTE(review): this assumes the type converter maps each argument and result
// 1:1. Confirm, otherwise indices and bindings could drift.
157 normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
158 ftype.getNumResults());
// Remove the original func::FuncOp now that the graph op replaces it.
160 rewriter.eraseOp(funcOp);
// Conversion pattern that replaces a func::ReturnOp with a
// spirv::GraphOutputsARMOp. The new op takes the adaptor's operands, i.e. the
// values as remapped by the conversion framework.
// NOTE(review): the visible lines do not check the enclosing op. The pattern
// presumably relies on FuncGraphConvert having moved the body into a
// spirv::GraphARMOp; confirm.
// The `using`/constructor line, the return-type line (source lines 167-169)
// and the final return plus closing lines (174 onward) are missing from this
// listing.
166struct ReturnGraphOutputConvert final : OpConversionPattern<func::ReturnOp> {
170 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter)
const override {
172 rewriter.replaceOpWithNewOp<spirv::GraphOutputsARMOp>(
173 returnOp, adaptor.getOperands());
// Body of populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &,
// RewritePatternSet &, spirv::TargetEnvAttr). It registers the FuncGraphConvert
// and ReturnGraphOutputConvert patterns with `patterns`, using `typeConverter`
// and the pattern set's context.
// NOTE(review): the signature line and source line 184 are missing from this
// listing. Line 184 presumably passes targetAttr as the last
// FuncGraphConvert constructor argument; confirm.
183 patterns.
add<FuncGraphConvert>(typeConverter, patterns.
getContext(),
185 patterns.
add<ReturnGraphOutputConvert>(typeConverter, patterns.
getContext());
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets an InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities,...
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)