MLIR 23.0.0git
TosaToSPIRVTosa.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosa.cpp - TOSA to SPIR-V Graph/TOSA patterns -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/ADT/STLExtras.h"
19
20#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
21
22namespace mlir::tosa {
23namespace {
24
25// Allows users to specify descriptor sets and binding ids on the source
26// function inputs and outputs. Use a source-side GraphARM attribute because
27// `spirv.interface_var_abi` is verified by the SPIR-V dialect before this
28// conversion runs, and result attrs are only accepted on `spirv.ARM.Graph`.
29constexpr StringLiteral graphARMInterfaceVarABIAttrName =
30 "grapharm.interface_var_abi";
31
32void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
33 spirv::GraphARMOp graphOp) {
34 for (NamedAttribute attr : adaptor.getAttributes()) {
35 StringRef attrName = attr.getName().getValue();
36 if (llvm::is_contained({SymbolTable::getSymbolAttrName(),
37 funcOp.getFunctionTypeAttrName().getValue(),
38 funcOp.getArgAttrsAttrName().getValue(),
39 funcOp.getResAttrsAttrName().getValue(),
40 graphOp.getEntryPointAttrName().getValue()},
41 attrName))
42 continue;
43
44 graphOp->setAttr(attr.getName(), attr.getValue());
45 }
46}
47
48struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
49 FuncGraphConvert(SPIRVTypeConverter &typeConverter, MLIRContext *context,
50 spirv::TargetEnvAttr targetAttr)
51 : OpConversionPattern<func::FuncOp>(typeConverter, context),
52 targetAttr(targetAttr) {}
53
54private:
55 spirv::TargetEnvAttr targetAttr;
56
57 // Prefer an explicit source-side GraphARM ABI annotation, then preserve an
58 // already-canonical SPIR-V ABI annotation, and otherwise synthesize the
59 // default descriptor set and binding id.
60 void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
61 MLIRContext *context, unsigned index,
62 bool isResult,
63 uint32_t defaultDescriptorSet,
64 uint32_t defaultBinding) const {
65 auto abiInfo =
66 isResult ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
67 index, graphARMInterfaceVarABIAttrName)
68 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
69 index, graphARMInterfaceVarABIAttrName);
70
71 if (!abiInfo) {
72 abiInfo = isResult
73 ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
75 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
77 }
78
79 if (!abiInfo) {
81 defaultDescriptorSet, defaultBinding, std::nullopt, context);
82 }
83
84 if (isResult) {
85 graphOp.setResultAttr(index, spirv::getInterfaceVarABIAttrName(),
86 abiInfo);
87 graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
88 } else {
89 graphOp.setArgAttr(index, spirv::getInterfaceVarABIAttrName(), abiInfo);
90 graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
91 }
92 }
93
94 void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
95 MLIRContext *context, unsigned inputs,
96 unsigned outputs) const {
97 constexpr uint32_t defaultDescriptorSet = 0;
98 for (auto argIndex : llvm::seq<unsigned>(0, inputs)) {
99 normalizeInterfaceVarABIAttr(graphOp, context, argIndex, false,
100 defaultDescriptorSet, argIndex);
101 }
102 for (auto resIndex : llvm::seq<unsigned>(0, outputs)) {
103 normalizeInterfaceVarABIAttr(graphOp, context, resIndex, true,
104 defaultDescriptorSet, resIndex + inputs);
105 }
106 }
107
108public:
109 LogicalResult
110 matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter) const override {
112 MLIRContext *context = rewriter.getContext();
113
114 StringRef name = adaptor.getSymName();
115 auto spvModule = spirv::ModuleOp::create(
116 rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
117 spirv::MemoryModel::Vulkan, std::nullopt,
118 ("_spirv_tosa_" + name).str());
119 spvModule->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
120
121 rewriter.setInsertionPoint(spvModule.getBody(), spvModule.begin());
122
123 FunctionType ftype = adaptor.getFunctionType();
124 ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
125 ArrayAttr resAttrs = adaptor.getResAttrsAttr();
126
127 TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
128 if (failed(typeConverter->convertSignatureArgs(ftype.getInputs(),
129 signatureConverter))) {
130 return funcOp.emitError("failed to convert function argument types");
131 }
132
133 // Update the signature of the function.
134 SmallVector<Type, 2> newResultTypes;
135 if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
136 newResultTypes))) {
137 return funcOp.emitError("failed to convert function result types");
138 }
139
140 // TOSA graphs cannot contain nested funcs, so the converted GraphARM op is
141 // an entry point.
142 auto entryPointAttr = BoolAttr::get(context, true);
143 auto graphTy = GraphType::get(
144 context, signatureConverter.getConvertedTypes(), newResultTypes);
145 auto graphOp =
146 spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
147 resAttrs, entryPointAttr, name);
148 copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
149
150 rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
151 graphOp.end());
152 if (failed(rewriter.convertRegionTypes(
153 &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
154 return funcOp.emitError("failed to convert function regions");
155 }
156
157 normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
158 ftype.getNumResults());
159
160 rewriter.eraseOp(funcOp);
161 return success();
162 }
163};
164
165/// Converts func.return to spirv.ARM.GraphOutputs.
166struct ReturnGraphOutputConvert final : OpConversionPattern<func::ReturnOp> {
167 using Base::Base;
168
169 LogicalResult
170 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter) const override {
172 rewriter.replaceOpWithNewOp<spirv::GraphOutputsARMOp>(
173 returnOp, adaptor.getOperands());
174 return success();
175 }
176};
177
178} // namespace
179
181 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
182 spirv::TargetEnvAttr targetAttr) {
183 patterns.add<FuncGraphConvert>(typeConverter, patterns.getContext(),
184 targetAttr);
185 patterns.add<ReturnGraphOutputConvert>(typeConverter, patterns.getContext());
186}
187
188} // namespace mlir::tosa
return success()
ArrayAttr()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities,...
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)