MLIR 23.0.0git
TosaToSPIRVTosa.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosa.cpp - TOSA to SPIR-V Graph/TOSA patterns -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/ADT/STLExtras.h"
19
20#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
21
22namespace mlir::tosa {
23namespace {
24
25void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
26 spirv::GraphARMOp graphOp) {
27 for (NamedAttribute attr : adaptor.getAttributes()) {
28 StringRef attrName = attr.getName().getValue();
29 if (llvm::is_contained({SymbolTable::getSymbolAttrName(),
30 funcOp.getFunctionTypeAttrName().getValue(),
31 funcOp.getArgAttrsAttrName().getValue(),
32 funcOp.getResAttrsAttrName().getValue(),
33 graphOp.getEntryPointAttrName().getValue()},
34 attrName))
35 continue;
36
37 graphOp->setAttr(attr.getName(), attr.getValue());
38 }
39}
40
41struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
42 FuncGraphConvert(SPIRVTypeConverter &typeConverter, MLIRContext *context,
43 spirv::TargetEnvAttr targetAttr)
44 : OpConversionPattern<func::FuncOp>(typeConverter, context),
45 targetAttr(targetAttr) {}
46
47private:
48 spirv::TargetEnvAttr targetAttr;
49
50 // Prefer an explicit source-side GraphARM ABI annotation, then preserve an
51 // already-canonical SPIR-V ABI annotation, and otherwise synthesize the
52 // default descriptor set and binding id.
//
// Normalizes the interface-variable ABI attribute of argument
// (isResult == false) or result (isResult == true) number `index` of
// `graphOp`. The selected spirv::InterfaceVarABIAttr is stored under
// spirv::getInterfaceVarABIAttrName(), and the attribute named
// graphARMInterfaceVarABIAttrName is removed from that argument/result.
//   defaultDescriptorSet / defaultBinding - used to build a fresh attribute
//     (with no storage class) only when neither lookup below finds one.
53 void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
54 MLIRContext *context, unsigned index,
55 bool isResult,
56 uint32_t defaultDescriptorSet,
57 uint32_t defaultBinding) const {
// First lookup. Per the comment above, this presumably reads the
// GraphARM-specific attribute (graphARMInterfaceVarABIAttrName) — confirm.
58 auto abiInfo =
59 isResult ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
61 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
63
// Fallback lookup. This presumably reads the canonical
// spirv::getInterfaceVarABIAttrName() attribute — confirm.
64 if (!abiInfo) {
65 abiInfo = isResult
66 ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
68 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
70 }
71
// No annotation found: synthesize one from the caller-supplied defaults.
72 if (!abiInfo) {
74 defaultDescriptorSet, defaultBinding, std::nullopt, context);
75 }
76
// Store the result under the canonical SPIR-V name and drop the
// GraphARM-specific attribute so only one annotation remains.
77 if (isResult) {
78 graphOp.setResultAttr(index, spirv::getInterfaceVarABIAttrName(),
79 abiInfo);
80 graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
81 } else {
82 graphOp.setArgAttr(index, spirv::getInterfaceVarABIAttrName(), abiInfo);
83 graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
84 }
85 }
86
87 void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
88 MLIRContext *context, unsigned inputs,
89 unsigned outputs) const {
90 constexpr uint32_t defaultDescriptorSet = 0;
91 for (auto argIndex : llvm::seq<unsigned>(0, inputs)) {
92 normalizeInterfaceVarABIAttr(graphOp, context, argIndex, false,
93 defaultDescriptorSet, argIndex);
94 }
95 for (auto resIndex : llvm::seq<unsigned>(0, outputs)) {
96 normalizeInterfaceVarABIAttr(graphOp, context, resIndex, true,
97 defaultDescriptorSet, resIndex + inputs);
98 }
99 }
100
101public:
102 LogicalResult
103 matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter) const override {
105 MLIRContext *context = rewriter.getContext();
106
107 StringRef name = adaptor.getSymName();
108 auto spvModule = spirv::ModuleOp::create(
109 rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
110 spirv::MemoryModel::Vulkan, std::nullopt,
111 ("_spirv_tosa_" + name).str());
112 spvModule->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
113
114 rewriter.setInsertionPoint(spvModule.getBody(), spvModule.begin());
115
116 FunctionType ftype = adaptor.getFunctionType();
117 ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
118 ArrayAttr resAttrs = adaptor.getResAttrsAttr();
119
120 TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
121 if (failed(typeConverter->convertSignatureArgs(ftype.getInputs(),
122 signatureConverter))) {
123 return funcOp.emitError("failed to convert function argument types");
124 }
125
126 // Update the signature of the function.
127 SmallVector<Type, 2> newResultTypes;
128 if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
129 newResultTypes))) {
130 return funcOp.emitError("failed to convert function result types");
131 }
132
133 // TOSA graphs cannot contain nested funcs, so the converted GraphARM op is
134 // an entry point.
135 auto entryPointAttr = BoolAttr::get(context, true);
136 auto graphTy = GraphType::get(
137 context, signatureConverter.getConvertedTypes(), newResultTypes);
138 auto graphOp =
139 spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
140 resAttrs, entryPointAttr, name);
141 copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
142
143 rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
144 graphOp.end());
145 if (failed(rewriter.convertRegionTypes(
146 &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
147 return funcOp.emitError("failed to convert function regions");
148 }
149
150 normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
151 ftype.getNumResults());
152
153 rewriter.eraseOp(funcOp);
154 return success();
155 }
156};
157
158/// Converts func.return to spirv.ARM.GraphOutputs.
159struct ReturnGraphOutputConvert final : OpConversionPattern<func::ReturnOp> {
160 using Base::Base;
161
162 LogicalResult
163 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter) const override {
165 rewriter.replaceOpWithNewOp<spirv::GraphOutputsARMOp>(
166 returnOp, adaptor.getOperands());
167 return success();
168 }
169};
170
171} // namespace
172
174 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
175 spirv::TargetEnvAttr targetAttr) {
176 patterns.add<FuncGraphConvert>(typeConverter, patterns.getContext(),
177 targetAttr);
178 patterns.add<ReturnGraphOutputConvert>(typeConverter, patterns.getContext());
179}
180
181} // namespace mlir::tosa
return success()
ArrayAttr()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
An attribute that specifies the target version, allowed extensions and capabilities,...
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
void populateTosaToSPIRVTosaConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::TargetEnvAttr targetAttr)
constexpr llvm::StringLiteral graphARMInterfaceVarABIAttrName