MLIR 23.0.0git
TosaToSPIRVTosaCustom.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosaCustom.cpp - TOSA to SPIR-V Graph/TOSA patterns -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
10//
11//===----------------------------------------------------------------------===//
12
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/Sequence.h"
19
20#define DEBUG_TYPE "tosa-to-spirv-tosa-custom-pattern"
21
22namespace mlir::tosa {
23namespace {
24
25Value encodeStringAsI8Array(StringRef value, Location loc,
26 ConversionPatternRewriter &rewriter) {
27 Type i8Type = rewriter.getIntegerType(8);
28 // Empty strings are encoded as a single NULL byte because SPIR-V array
29 // types require at least one element.
30 StringRef encodedValue = value.empty() ? StringRef("\0", 1) : value;
31
32 SmallVector<Attribute> bytes;
33 bytes.reserve(encodedValue.size());
34 llvm::transform(
35 encodedValue, std::back_inserter(bytes),
36 [&](unsigned char byte) { return IntegerAttr::get(i8Type, byte); });
37
38 auto arrayType =
39 spirv::ArrayType::get(i8Type, static_cast<unsigned>(bytes.size()));
40 auto arrayValue = ArrayAttr::get(rewriter.getContext(), bytes);
41 return spirv::ConstantOp::create(rewriter, loc, arrayType, arrayValue);
42}
43
44struct TosaCustomOpConvert final : public OpConversionPattern<tosa::CustomOp> {
45 TosaCustomOpConvert(const TypeConverter &typeConverter, MLIRContext *context,
46 llvm::StringMap<int32_t> domainToOpcode)
47 : OpConversionPattern<tosa::CustomOp>(typeConverter, context),
48 domainToOpcode(std::move(domainToOpcode)) {}
49
50 LogicalResult
51 matchAndRewrite(tosa::CustomOp op, tosa::CustomOpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter) const override {
53 auto opCode = domainToOpcode.find(op.getDomainName());
54 if (opCode == domainToOpcode.end())
55 return failure();
56
57 if (op->getResultTypes().empty())
58 return op.emitOpError("with mapped domain requires at least one result");
59
60 SmallVector<Type> types;
61 if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
62 types)))
63 return rewriter.notifyMatchFailure(op, "type conversion failed");
64
65 Type resultType =
66 types.size() == 1 ? types.front() : spirv::StructType::get(types);
67
68 Value operatorName =
69 encodeStringAsI8Array(op.getOperatorName(), op.getLoc(), rewriter);
70 Value implementationAttrsBlob = encodeStringAsI8Array(
71 op.getImplementationAttrs(), op.getLoc(), rewriter);
72
73 SmallVector<Value> inputs = {operatorName, implementationAttrsBlob};
74 inputs.append(adaptor.getInputList().begin(), adaptor.getInputList().end());
75
76 Value result = spirv::ExperimentalMLCallOp::create(
77 rewriter, op.getLoc(), resultType,
78 rewriter.getI32IntegerAttr(opCode->second), inputs);
79
80 if (types.size() == 1) {
81 rewriter.replaceOp(op, result);
82 return success();
83 }
84
85 SmallVector<Value> results;
86 for (auto index : llvm::seq<int32_t>(0, types.size())) {
87 results.push_back(spirv::CompositeExtractOp::create(rewriter, op.getLoc(),
88 result, {index}));
89 }
90 rewriter.replaceOp(op, results);
91 return success();
92 }
93
94private:
95 llvm::StringMap<int32_t> domainToOpcode;
96};
97
98} // namespace
99
101 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
102 llvm::StringMap<int32_t> domainToOpcode) {
103 patterns.add<TosaCustomOpConvert>(typeConverter, patterns.getContext(),
104 std::move(domainToOpcode));
105}
106
107} // namespace mlir::tosa
return success()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static ArrayType get(Type elementType, unsigned elementCount)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
void populateTosaToSPIRVTosaCustomConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, llvm::StringMap< int32_t > domainToOpcode)