// Debug tag for this file's pattern code; presumably consumed by LLVM_DEBUG /
// `-debug-only=complex-to-spirv-pattern` — confirm.
19#define DEBUG_TYPE "complex-to-spirv-pattern"
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 30-32, 37, 40, and everything
// after 41), so this fragment is not compilable as-is.
//
// Lowers complex.constant to spirv.Constant. The result type is run through
// the type converter and must convert to a ShapedType; otherwise the pattern
// reports a match failure "unable to convert result type".
29struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
33 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
34 ConversionPatternRewriter &rewriter)
const override {
// The converted type is requested as a ShapedType (presumably a 2-element
// vector laid out as {real, imag} — confirm against the type converter).
36 getTypeConverter()->convertType<ShapedType>(constOp.getType());
38 return rewriter.notifyMatchFailure(constOp,
39 "unable to convert result type");
// The spirv.Constant attribute arguments are on elided lines; presumably a
// DenseElementsAttr built from the constant's value — TODO confirm.
41 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 49-51, 55, 58, and everything
// after 60), so this fragment is not compilable as-is.
//
// Lowers complex.create to spirv.CompositeConstruct. The composite is built
// directly from the already-converted operands (adaptor.getOperands()),
// with the result type produced by the type converter. If that conversion
// fails, the pattern reports "unable to convert result type".
48struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
52 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override {
54 Type spirvType = getTypeConverter()->convertType(createOp.getType());
// The null check guarding this early return is on an elided line.
56 return rewriter.notifyMatchFailure(createOp,
57 "unable to convert result type");
// Operand order becomes element order: presumably element 0 = real part,
// element 1 = imaginary part, matching how the arithmetic patterns extract.
59 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
60 createOp, spirvType, adaptor.getOperands());
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 66-68, 72, 74, and everything
// after 75), so this fragment is not compilable as-is.
//
// Lowers complex.re to spirv.CompositeExtract on the converted complex value.
// If the result type cannot be converted, the pattern reports
// "unable to convert result type".
65struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
69 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter)
const override {
71 Type spirvType = getTypeConverter()->convertType(reOp.getType());
73 return rewriter.notifyMatchFailure(reOp,
"unable to convert result type");
// The extract arguments are on elided lines; presumably index 0 (real part)
// of adaptor.getComplex() — TODO confirm.
75 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 82-84, 88, 90, and everything
// after 91), so this fragment is not compilable as-is.
//
// Lowers complex.im to spirv.CompositeExtract on the converted complex value.
// If the result type cannot be converted, the pattern reports
// "unable to convert result type".
81struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
85 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter)
const override {
87 Type spirvType = getTypeConverter()->convertType(imOp.getType());
89 return rewriter.notifyMatchFailure(imOp,
"unable to convert result type");
// The extract arguments are on elided lines; presumably index 1 (imaginary
// part) of adaptor.getComplex() — TODO confirm.
91 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 101-102, 105, 107, 109-113, 118,
// 121, and everything after 122), so this fragment is not compilable as-is.
//
// Lowers a complex binary op by applying the scalar SPIR-V op SPIRVOp to the
// two real parts and, separately, to the two imaginary parts, then rebuilding
// the composite. This is only mathematically valid for ops that act
// componentwise on complex numbers (add/sub). The registration code in this
// file instantiates it with spirv::FAddOp and spirv::FSubOp.
//
// ComplexOp: the complex-dialect op being matched.
// SPIRVOp:   a binary SPIR-V op providing create(rewriter, loc, lhs, rhs).
97template <
typename ComplexOp,
typename SPIRVOp>
98struct ElementwiseBinaryOpPattern final : OpConversionPattern<ComplexOp> {
99 using OpConversionPattern<ComplexOp>::OpConversionPattern;
// The base class depends on ComplexOp, so its nested OpAdaptor is not found
// by unqualified lookup; re-declare it from the op's own Adaptor.
100 using OpAdaptor =
typename ComplexOp::Adaptor;
103 matchAndRewrite(ComplexOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter)
const override {
// `this->` is required because getTypeConverter() lives in a dependent base.
106 this->getTypeConverter()->convertType(op.getResult().getType());
// The null check guarding this early return is on an elided line.
108 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
// `loc`, `lhs` and `rhs` are bound on elided lines (presumably op.getLoc()
// and the adaptor's converted operands — confirm). Element 0 of each
// composite is the real part, element 1 the imaginary part.
114 Value lhsRe = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {0});
115 Value lhsIm = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {1});
116 Value rhsRe = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {0});
117 Value rhsIm = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {1});
// re = lhsRe OP rhsRe, im = lhsIm OP rhsIm.
119 Value resultRe = SPIRVOp::create(rewriter, loc, lhsRe, rhsRe);
120 Value resultIm = SPIRVOp::create(rewriter, loc, lhsIm, rhsIm);
122 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 129-131, 135, 137-141, and
// everything after 154), so this fragment is not compilable as-is.
//
// Lowers complex.mul with the textbook product
//   (a + bi)(c + di) = (ac - bd) + (ad + bc)i
// using four FMul, one FSub and one FAdd. The visible code does no special
// handling of inf/NaN inputs, so results for those cases follow plain IEEE
// arithmetic rather than C Annex G-style recovery.
128struct MulOpPattern final : OpConversionPattern<complex::MulOp> {
132 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
134 Type spirvType = getTypeConverter()->convertType(op.getResult().getType());
// The null check guarding this early return is on an elided line.
136 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
// `loc`, `lhs` and `rhs` are bound on elided lines (presumably op.getLoc()
// and the adaptor's converted operands — confirm).
// a, b = real/imag of lhs; c, d = real/imag of rhs.
142 Value a = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {0});
143 Value b = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {1});
144 Value c = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {0});
145 Value d = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {1});
147 Value ac = spirv::FMulOp::create(rewriter, loc, a, c);
148 Value bd = spirv::FMulOp::create(rewriter, loc,
b, d);
149 Value ad = spirv::FMulOp::create(rewriter, loc, a, d);
150 Value bc = spirv::FMulOp::create(rewriter, loc,
b, c);
// real = ac - bd, imag = ad + bc.
151 Value resultRe = spirv::FSubOp::create(rewriter, loc, ac, bd);
152 Value resultIm = spirv::FAddOp::create(rewriter, loc, ad, bc);
154 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 163-164, 167, 171-172, 174-175,
// 177, 179, and everything after 184), so this fragment is not compilable.
//
// Lowers complex.abs to sqrt(re*re + im*im). SqrtOp selects the
// extended-instruction sqrt to emit; this file instantiates it with
// spirv::GLSqrtOp and spirv::CLSqrtOp.
//
// NOTE(review): squaring the components without scaling can overflow to inf
// (or underflow to 0) for large (or tiny) magnitudes where a hypot-style
// computation would stay finite — consider scaling if accuracy matters.
160template <
typename SqrtOp>
161struct AbsOpPattern final : OpConversionPattern<complex::AbsOp> {
162 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
165 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter)
const override {
168 this->getTypeConverter()->convertType(op.getResult().getType());
// The null check guarding this early return is on an elided line.
170 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
173 Value complexVal = adaptor.getComplex();
// These two extracts initialize `re` (element 0) and `im` (element 1); the
// declarations themselves are on elided lines.
176 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {0});
178 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {1});
180 Value reSq = spirv::FMulOp::create(rewriter, loc, re, re);
181 Value imSq = spirv::FMulOp::create(rewriter, loc, im, im);
182 Value sum = spirv::FAddOp::create(rewriter, loc, reSq, imSq);
// The result type of SqrtOp is not passed explicitly here; presumably it is
// inferred from `sum` — confirm.
184 rewriter.replaceOpWithNewOp<SqrtOp>(op, sum);
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 193-194, 197, 199, 201-202,
// 204-205, 207, 209-210, 213, and everything after 214), so this fragment
// is not compilable as-is.
//
// Lowers complex ops that negate the imaginary part and optionally the real
// part:
//   NegateReal = true  : -(a + bi) = -a - bi   (instantiated for complex.neg)
//   NegateReal = false : conj(a + bi) = a - bi (instantiated for complex.conj)
//
// ComplexOp:  the complex-dialect op being matched.
// NegateReal: whether the real part is negated as well.
189template <
typename ComplexOp,
bool NegateReal>
190struct NegationOpPattern final : OpConversionPattern<ComplexOp> {
191 using OpConversionPattern<ComplexOp>::OpConversionPattern;
// The base class depends on ComplexOp, so its nested OpAdaptor is not found
// by unqualified lookup; re-declare it from the op's own Adaptor.
192 using OpAdaptor =
typename ComplexOp::Adaptor;
195 matchAndRewrite(ComplexOp op, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter)
const override {
// `this->` is required because getTypeConverter() lives in a dependent base.
198 this->getTypeConverter()->convertType(op.getResult().getType());
// The null check guarding this early return is on an elided line.
200 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
203 Value complexVal = adaptor.getComplex();
// These two extracts initialize `re` (element 0) and `im` (element 1); the
// declarations themselves are on elided lines.
206 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {0});
208 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {1});
// Initializer of `resultRe` (declared on an elided line). Only the selected
// branch of ?: is evaluated, so no FNegate is created when NegateReal is
// false.
211 NegateReal ? spirv::FNegateOp::create(rewriter, loc, re) : re;
// The imaginary part is negated in both variants.
212 Value resultIm = spirv::FNegateOp::create(rewriter, loc, im);
214 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
// NOTE(review): extracted listing — the leading digits appear to be original
// line numbers, and lines are missing (e.g. 221-223, 227, 229-233, 238, 250,
// and everything after 251), so this fragment is not compilable as-is.
//
// Lowers complex.div with the textbook quotient
//   (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
//
// NOTE(review): this naive form computes c*c + d*d directly, which can
// overflow to inf (or underflow to 0) when |c| or |d| is large (or tiny),
// even though the true quotient is finite. A zero divisor yields inf/NaN per
// IEEE. Consider a scaled formulation (e.g. Smith's algorithm) if accuracy
// on such inputs matters.
220struct DivOpPattern final : OpConversionPattern<complex::DivOp> {
224 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
225 ConversionPatternRewriter &rewriter)
const override {
226 Type spirvType = getTypeConverter()->convertType(op.getResult().getType());
// The null check guarding this early return is on an elided line.
228 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
// `loc`, `lhs` and `rhs` are bound on elided lines (presumably op.getLoc()
// and the adaptor's converted operands — confirm).
// a, b = real/imag of the dividend; c, d = real/imag of the divisor.
234 Value a = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {0});
235 Value b = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {1});
236 Value c = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {0});
237 Value d = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {1});
239 Value ac = spirv::FMulOp::create(rewriter, loc, a, c);
240 Value bd = spirv::FMulOp::create(rewriter, loc,
b, d);
241 Value bc = spirv::FMulOp::create(rewriter, loc,
b, c);
242 Value ad = spirv::FMulOp::create(rewriter, loc, a, d);
// denom = |divisor|^2 = c*c + d*d (unscaled; see the note above).
243 Value cc = spirv::FMulOp::create(rewriter, loc, c, c);
244 Value dd = spirv::FMulOp::create(rewriter, loc, d, d);
245 Value denom = spirv::FAddOp::create(rewriter, loc, cc, dd);
// Numerator is (a + bi) * conj(c + di) = (ac + bd) + (bc - ad)i.
246 Value numRe = spirv::FAddOp::create(rewriter, loc, ac, bd);
247 Value numIm = spirv::FSubOp::create(rewriter, loc, bc, ad);
248 Value resultRe = spirv::FDivOp::create(rewriter, loc, numRe, denom);
249 Value resultIm = spirv::FDivOp::create(rewriter, loc, numIm, denom);
251 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
// NOTE(review): extracted listing — this appears to be the body of
// populateComplexToSPIRVPatterns. The function header, the lines before 267
// (including where `context` is bound), and the lines after 274 are missing.
//
// Registers the Complex -> SPIR-V conversion patterns. Each pattern is
// constructed with the SPIR-V type converter and the MLIR context.
//
// NOTE(review): AbsOpPattern<spirv::GLSqrtOp> and AbsOpPattern<spirv::CLSqrtOp>
// both match complex.abs, and neither checks the target environment.
// Choosing between them presumably relies on the conversion target marking
// the unsupported sqrt op illegal, so the driver rolls back and tries the
// other pattern — confirm. Otherwise the first-registered (GL) pattern would
// always be used.
267 patterns.
add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern,
268 ElementwiseBinaryOpPattern<complex::AddOp, spirv::FAddOp>,
269 ElementwiseBinaryOpPattern<complex::SubOp, spirv::FSubOp>,
270 MulOpPattern, DivOpPattern,
271 NegationOpPattern<complex::NegOp,
true>,
272 NegationOpPattern<complex::ConjOp,
false>,
273 AbsOpPattern<spirv::GLSqrtOp>, AbsOpPattern<spirv::CLSqrtOp>>(
274 typeConverter, context);
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
void populateComplexToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Complex ops to SPIR-V ops.