17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/Sequence.h"
20#define DEBUG_TYPE "tosa-to-spirv-tosa-custom-pattern"
// NOTE(review): this is a fragmentary extract — source-browser line numbers are
// fused into the text (e.g. "25Value") and several original lines are missing,
// so it does not compile as shown.
//
// Encodes the string `value` as a spirv::ConstantOp whose value is an
// ArrayAttr of i8 IntegerAttrs, one attribute per byte of the string.
//
// An empty `value` is encoded as a single NUL byte rather than as zero bytes
// (presumably because a zero-length array constant is not usable — confirm).
//
// Parameters:
//   value    - the string to encode.
//   loc      - location attached to the created constant op.
//   rewriter - used to build the i8 type, the attributes and the constant op.
// Returns the Value produced by the created spirv::ConstantOp.
25Value encodeStringAsI8Array(StringRef value, Location loc,
26 ConversionPatternRewriter &rewriter) {
27 Type i8Type = rewriter.getIntegerType(8);
30 StringRef encodedValue = value.empty() ? StringRef(
"\0", 1) : value;
32 SmallVector<Attribute> bytes;
33 bytes.reserve(encodedValue.size());
// NOTE(review): the call head on original line 34 is missing; the arguments
// below map each byte of `encodedValue` into `bytes` via a back_inserter
// (looks like an llvm::transform — confirm against the full source).
35 encodedValue, std::back_inserter(bytes),
36 [&](
unsigned char byte) {
// Each byte is read as unsigned char, so the value handed to
// IntegerAttr::get is in [0, 255].
return IntegerAttr::get(i8Type,
byte); });
40 auto arrayValue = ArrayAttr::get(rewriter.getContext(), bytes);
// NOTE(review): `arrayType` is not defined on any visible line; the missing
// original lines (28-29, 31) presumably define it, likely as
// spirv::ArrayType::get(i8Type, encodedValue.size()) — confirm.
41 return spirv::ConstantOp::create(rewriter, loc, arrayType, arrayValue);
// NOTE(review): fragmentary extract — source-browser line numbers are fused
// into the text and several original lines (including the closing "};") are
// missing, so this does not compile as shown.
//
// Conversion pattern that lowers a tosa::CustomOp to a
// spirv::ExperimentalMLCallOp when the op's domain name is a key in
// `domainToOpcode`. The mapped opcode is attached as an i32 attribute. The
// operator name and implementation attributes are passed as leading i8-array
// constant operands, ahead of the adaptor's inputs.
44struct TosaCustomOpConvert final :
public OpConversionPattern<tosa::CustomOp> {
// `domainToOpcode` is taken by value and moved into the member of the same
// name.
45 TosaCustomOpConvert(
const TypeConverter &typeConverter, MLIRContext *context,
46 llvm::StringMap<int32_t> domainToOpcode)
47 : OpConversionPattern<tosa::CustomOp>(typeConverter, context),
48 domainToOpcode(std::move(domainToOpcode)) {}
// Rewrites `op` into an ExperimentalMLCallOp. Results are then replaced
// either directly (single result) or element-wise via CompositeExtractOp.
51 matchAndRewrite(tosa::CustomOp op, tosa::CustomOpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter)
const override {
// Only ops whose domain name is a key in `domainToOpcode` are handled.
// NOTE(review): the body of the `if` below (original lines 55-56) is not
// visible; presumably it returns a match failure for unmapped domains —
// confirm.
53 auto opCode = domainToOpcode.find(op.getDomainName());
54 if (opCode == domainToOpcode.end())
// A mapped op with no results is rejected by emitting an op error, rather
// than via notifyMatchFailure.
57 if (op->getResultTypes().empty())
58 return op.emitOpError(
"with mapped domain requires at least one result");
// Convert every result type with the pattern's type converter; on failure,
// report a match failure.
// NOTE(review): the trailing arguments of convertTypes (original line 62)
// are not visible; presumably `types` receives the converted types — confirm.
60 SmallVector<Type> types;
61 if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
63 return rewriter.notifyMatchFailure(op,
"type conversion failed");
// Encode the operator name and the implementation attributes as i8-array
// constants.
// NOTE(review): the declaration head of `operatorName` (original line 68)
// and the computation of `resultType` (original lines 64-67) are not
// visible. Presumably `resultType` is types[0] for a single result and an
// aggregate (e.g. a spirv::StructType of `types`) otherwise — confirm.
69 encodeStringAsI8Array(op.getOperatorName(), op.getLoc(), rewriter);
70 Value implementationAttrsBlob = encodeStringAsI8Array(
71 op.getImplementationAttrs(), op.getLoc(), rewriter);
// Call operands: the two encoded blobs first, then the adaptor's inputs in
// their original order.
73 SmallVector<Value> inputs = {operatorName, implementationAttrsBlob};
74 inputs.append(adaptor.getInputList().begin(), adaptor.getInputList().end());
// Emit the call, attaching the mapped opcode as an i32 attribute.
76 Value
result = spirv::ExperimentalMLCallOp::create(
77 rewriter, op.getLoc(), resultType,
78 rewriter.getI32IntegerAttr(opCode->second), inputs);
// A single result replaces `op` directly.
// NOTE(review): the rest of this branch (original lines 82-84) is not
// visible; presumably it returns success — confirm.
80 if (types.size() == 1) {
81 rewriter.replaceOp(op,
result);
// Otherwise one spirv::CompositeExtractOp is created per result index in
// the half-open range [0, types.size()).
// NOTE(review): the extract's remaining arguments (original lines 88-89)
// are not visible; presumably it extracts element `index` of type
// types[index] from `result` — confirm. llvm::seq<int32_t> narrows
// types.size() to int32_t, which is fine for realistic result counts.
85 SmallVector<Value> results;
86 for (
auto index : llvm::seq<int32_t>(0, types.size())) {
87 results.push_back(spirv::CompositeExtractOp::create(rewriter, op.getLoc(),
90 rewriter.replaceOp(op, results);
// Maps a tosa.custom domain name to the opcode attached to the emitted
// ExperimentalMLCallOp.
95 llvm::StringMap<int32_t> domainToOpcode;
102 llvm::StringMap<int32_t> domainToOpcode) {
// Registers a TosaCustomOpConvert pattern in `patterns`. The pattern is
// constructed with `typeConverter`, the pattern set's MLIRContext, and
// `domainToOpcode`, which is moved into the pattern.
// NOTE(review): the function header (original lines 100-101) and the closing
// brace are not visible in this extract. The tooltip residue in this file
// gives the signature as populateTosaToSPIRVTosaCustomConversionPatterns(
// SPIRVTypeConverter &, RewritePatternSet &, llvm::StringMap<int32_t>).
103 patterns.
add<TosaCustomOpConvert>(typeConverter, patterns.
getContext(),
104 std::move(domainToOpcode));
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static ArrayType get(Type elementType, unsigned elementCount)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
void populateTosaToSPIRVTosaCustomConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns, llvm::StringMap< int32_t > domainToOpcode)