MLIR 23.0.0git
TosaToSPIRVTosaOps.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosaOps.cpp - TOSA to SPIR-V Graph/TOSA patterns --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
10//
11//===----------------------------------------------------------------------===//
12
17
18#define DEBUG_TYPE "tosa-to-spirv-tosa-ops-pattern"
19
20namespace mlir::tosa {
21namespace {
22
23template <typename OpAdaptor>
24Value getInput1(OpAdaptor adaptor) {
25 return adaptor.getInput1();
26}
27
28Value getInput1(tosa::ErfOpAdaptor adaptor) { return adaptor.getInput(); }
29
30Value getInput1(tosa::SigmoidOpAdaptor adaptor) { return adaptor.getInput(); }
31
32Value getInput1(tosa::TanhOpAdaptor adaptor) { return adaptor.getInput(); }
33
34template <typename SourceOp, typename TargetOp>
35struct UnaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
36 using OpConversionPattern<SourceOp>::OpConversionPattern;
37
38 LogicalResult
39 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
40 ConversionPatternRewriter &rewriter) const override {
41 Type type = this->getTypeConverter()->convertType(op.getType());
42 if (!type)
43 return rewriter.notifyMatchFailure(op, "type conversion failed");
44 rewriter.replaceOpWithNewOp<TargetOp>(op, type, getInput1(adaptor));
45 return success();
46 }
47};
48
49template <typename SourceOp, typename TargetOp>
50struct BinaryElementwiseOpConvert final : public OpConversionPattern<SourceOp> {
51 using OpConversionPattern<SourceOp>::OpConversionPattern;
52
53 LogicalResult
54 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
55 ConversionPatternRewriter &rewriter) const override {
56 Type type = this->getTypeConverter()->convertType(op.getType());
57 if (!type)
58 return rewriter.notifyMatchFailure(op, "type conversion failed");
59 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1(),
60 adaptor.getInput2());
61 return success();
62 }
63};
64
65template <typename SourceOp, typename TargetOp>
66struct BinaryNanModeElementwiseOpConvert final
67 : public OpConversionPattern<SourceOp> {
68 using OpConversionPattern<SourceOp>::OpConversionPattern;
69
70 LogicalResult
71 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73 auto nanMode =
74 static_cast<spirv::TosaExtNaNPropagationModeType>(adaptor.getNanMode());
75 Type type = this->getTypeConverter()->convertType(op.getType());
76 if (!type)
77 return rewriter.notifyMatchFailure(op, "type conversion failed");
78 rewriter.replaceOpWithNewOp<TargetOp>(
79 op, type, nanMode, adaptor.getInput1(), adaptor.getInput2());
80 return success();
81 }
82};
83
84} // namespace
85
87 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
88 patterns.add<
89 UnaryElementwiseOpConvert<tosa::ErfOp, spirv::TosaErfOp>,
90 UnaryElementwiseOpConvert<tosa::SigmoidOp, spirv::TosaSigmoidOp>,
91 UnaryElementwiseOpConvert<tosa::TanhOp, spirv::TosaTanhOp>,
92 BinaryElementwiseOpConvert<tosa::AddOp, spirv::TosaAddOp>,
93 BinaryElementwiseOpConvert<tosa::BitwiseAndOp, spirv::TosaBitwiseAndOp>,
94 BinaryElementwiseOpConvert<tosa::BitwiseOrOp, spirv::TosaBitwiseOrOp>,
95 BinaryElementwiseOpConvert<tosa::BitwiseXorOp, spirv::TosaBitwiseXorOp>,
96 BinaryElementwiseOpConvert<tosa::IntDivOp, spirv::TosaIntDivOp>,
97 BinaryElementwiseOpConvert<tosa::LogicalAndOp, spirv::TosaLogicalAndOp>,
98 BinaryElementwiseOpConvert<tosa::LogicalLeftShiftOp,
99 spirv::TosaLogicalLeftShiftOp>,
100 BinaryElementwiseOpConvert<tosa::LogicalRightShiftOp,
101 spirv::TosaLogicalRightShiftOp>,
102 BinaryElementwiseOpConvert<tosa::LogicalOrOp, spirv::TosaLogicalOrOp>,
103 BinaryElementwiseOpConvert<tosa::LogicalXorOp, spirv::TosaLogicalXorOp>,
104 BinaryNanModeElementwiseOpConvert<tosa::MaximumOp, spirv::TosaMaximumOp>,
105 BinaryNanModeElementwiseOpConvert<tosa::MinimumOp, spirv::TosaMinimumOp>,
106 BinaryElementwiseOpConvert<tosa::PowOp, spirv::TosaPowOp>,
107 BinaryElementwiseOpConvert<tosa::SubOp, spirv::TosaSubOp>,
108 UnaryElementwiseOpConvert<tosa::AbsOp, spirv::TosaAbsOp>,
109 UnaryElementwiseOpConvert<tosa::BitwiseNotOp, spirv::TosaBitwiseNotOp>,
110 UnaryElementwiseOpConvert<tosa::CeilOp, spirv::TosaCeilOp>,
111 UnaryElementwiseOpConvert<tosa::ClzOp, spirv::TosaClzOp>,
112 UnaryElementwiseOpConvert<tosa::CosOp, spirv::TosaCosOp>,
113 UnaryElementwiseOpConvert<tosa::ExpOp, spirv::TosaExpOp>,
114 UnaryElementwiseOpConvert<tosa::FloorOp, spirv::TosaFloorOp>,
115 UnaryElementwiseOpConvert<tosa::LogOp, spirv::TosaLogOp>,
116 UnaryElementwiseOpConvert<tosa::LogicalNotOp, spirv::TosaLogicalNotOp>,
117 UnaryElementwiseOpConvert<tosa::ReciprocalOp, spirv::TosaReciprocalOp>,
118 UnaryElementwiseOpConvert<tosa::RsqrtOp, spirv::TosaRsqrtOp>,
119 UnaryElementwiseOpConvert<tosa::SinOp, spirv::TosaSinOp>,
120 BinaryElementwiseOpConvert<tosa::EqualOp, spirv::TosaEqualOp>,
121 BinaryElementwiseOpConvert<tosa::GreaterOp, spirv::TosaGreaterOp>,
122 BinaryElementwiseOpConvert<tosa::GreaterEqualOp,
123 spirv::TosaGreaterEqualOp>>(
124 typeConverter, patterns.getContext());
125}
126
127} // namespace mlir::tosa
return success()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)