MLIR 23.0.0git
ComplexToSPIRV.cpp
Go to the documentation of this file.
1//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Complex dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
18
19#define DEBUG_TYPE "complex-to-spirv-pattern"
20
21using namespace mlir;
22
23//===----------------------------------------------------------------------===//
24// Operation conversion
25//===----------------------------------------------------------------------===//
26
27namespace {
28
29struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
30 using Base::Base;
31
32 LogicalResult
33 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
34 ConversionPatternRewriter &rewriter) const override {
35 auto spirvType =
36 getTypeConverter()->convertType<ShapedType>(constOp.getType());
37 if (!spirvType)
38 return rewriter.notifyMatchFailure(constOp,
39 "unable to convert result type");
40
41 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
42 constOp, spirvType,
43 DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
44 return success();
45 }
46};
47
48struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
49 using Base::Base;
50
51 LogicalResult
52 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 Type spirvType = getTypeConverter()->convertType(createOp.getType());
55 if (!spirvType)
56 return rewriter.notifyMatchFailure(createOp,
57 "unable to convert result type");
58
59 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
60 createOp, spirvType, adaptor.getOperands());
61 return success();
62 }
63};
64
65struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
66 using Base::Base;
67
68 LogicalResult
69 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter) const override {
71 Type spirvType = getTypeConverter()->convertType(reOp.getType());
72 if (!spirvType)
73 return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
74
75 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
76 reOp, adaptor.getComplex(), llvm::ArrayRef(0));
77 return success();
78 }
79};
80
81struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
82 using Base::Base;
83
84 LogicalResult
85 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter) const override {
87 Type spirvType = getTypeConverter()->convertType(imOp.getType());
88 if (!spirvType)
89 return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
90
91 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
92 imOp, adaptor.getComplex(), llvm::ArrayRef(1));
93 return success();
94 }
95};
96
97template <typename ComplexOp, typename SPIRVOp>
98struct ElementwiseBinaryOpPattern final : OpConversionPattern<ComplexOp> {
99 using OpConversionPattern<ComplexOp>::OpConversionPattern;
100 using OpAdaptor = typename ComplexOp::Adaptor;
101
102 LogicalResult
103 matchAndRewrite(ComplexOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter) const override {
105 Type spirvType =
106 this->getTypeConverter()->convertType(op.getResult().getType());
107 if (!spirvType)
108 return rewriter.notifyMatchFailure(op, "unable to convert result type");
109
110 Location loc = op.getLoc();
111 Value lhs = adaptor.getLhs();
112 Value rhs = adaptor.getRhs();
113
114 Value lhsRe = spirv::CompositeExtractOp::create(rewriter, loc, lhs, {0});
115 Value lhsIm = spirv::CompositeExtractOp::create(rewriter, loc, lhs, {1});
116 Value rhsRe = spirv::CompositeExtractOp::create(rewriter, loc, rhs, {0});
117 Value rhsIm = spirv::CompositeExtractOp::create(rewriter, loc, rhs, {1});
118
119 Value resultRe = SPIRVOp::create(rewriter, loc, lhsRe, rhsRe);
120 Value resultIm = SPIRVOp::create(rewriter, loc, lhsIm, rhsIm);
121
122 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
123 op, spirvType, llvm::ArrayRef<Value>{resultRe, resultIm});
124 return success();
125 }
126};
127
128struct MulOpPattern final : OpConversionPattern<complex::MulOp> {
129 using Base::Base;
130
131 LogicalResult
132 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const override {
134 Type spirvType = getTypeConverter()->convertType(op.getResult().getType());
135 if (!spirvType)
136 return rewriter.notifyMatchFailure(op, "unable to convert result type");
137
138 Location loc = op.getLoc();
139 Value lhs = adaptor.getLhs();
140 Value rhs = adaptor.getRhs();
141
142 Value a = spirv::CompositeExtractOp::create(rewriter, loc, lhs, {0});
143 Value b = spirv::CompositeExtractOp::create(rewriter, loc, lhs, {1});
144 Value c = spirv::CompositeExtractOp::create(rewriter, loc, rhs, {0});
145 Value d = spirv::CompositeExtractOp::create(rewriter, loc, rhs, {1});
146
147 Value ac = spirv::FMulOp::create(rewriter, loc, a, c);
148 Value bd = spirv::FMulOp::create(rewriter, loc, b, d);
149 Value ad = spirv::FMulOp::create(rewriter, loc, a, d);
150 Value bc = spirv::FMulOp::create(rewriter, loc, b, c);
151 Value resultRe = spirv::FSubOp::create(rewriter, loc, ac, bd);
152 Value resultIm = spirv::FAddOp::create(rewriter, loc, ad, bc);
153
154 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
155 op, spirvType, llvm::ArrayRef<Value>{resultRe, resultIm});
156 return success();
157 }
158};
159
160template <typename SqrtOp>
161struct AbsOpPattern final : OpConversionPattern<complex::AbsOp> {
162 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
163
164 LogicalResult
165 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter) const override {
167 Type spirvType =
168 this->getTypeConverter()->convertType(op.getResult().getType());
169 if (!spirvType)
170 return rewriter.notifyMatchFailure(op, "unable to convert result type");
171
172 Location loc = op.getLoc();
173 Value complexVal = adaptor.getComplex();
174
175 Value re =
176 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {0});
177 Value im =
178 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {1});
179
180 Value reSq = spirv::FMulOp::create(rewriter, loc, re, re);
181 Value imSq = spirv::FMulOp::create(rewriter, loc, im, im);
182 Value sum = spirv::FAddOp::create(rewriter, loc, reSq, imSq);
183
184 rewriter.replaceOpWithNewOp<SqrtOp>(op, sum);
185 return success();
186 }
187};
188
189struct DivOpPattern final : OpConversionPattern<complex::DivOp> {
190 using Base::Base;
191
192 LogicalResult
193 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter) const override {
195 Type spirvType = getTypeConverter()->convertType(op.getResult().getType());
196 if (!spirvType)
197 return rewriter.notifyMatchFailure(op, "unable to convert result type");
198
199 Location loc = op.getLoc();
200 Value lhs = adaptor.getLhs();
201 Value rhs = adaptor.getRhs();
202
203 Value a = spirv::CompositeExtractOp::create(rewriter, loc, lhs, {0});
204 Value b = spirv::CompositeExtractOp::create(rewriter, loc, lhs, {1});
205 Value c = spirv::CompositeExtractOp::create(rewriter, loc, rhs, {0});
206 Value d = spirv::CompositeExtractOp::create(rewriter, loc, rhs, {1});
207
208 Value ac = spirv::FMulOp::create(rewriter, loc, a, c);
209 Value bd = spirv::FMulOp::create(rewriter, loc, b, d);
210 Value bc = spirv::FMulOp::create(rewriter, loc, b, c);
211 Value ad = spirv::FMulOp::create(rewriter, loc, a, d);
212 Value cc = spirv::FMulOp::create(rewriter, loc, c, c);
213 Value dd = spirv::FMulOp::create(rewriter, loc, d, d);
214 Value denom = spirv::FAddOp::create(rewriter, loc, cc, dd);
215 Value numRe = spirv::FAddOp::create(rewriter, loc, ac, bd);
216 Value numIm = spirv::FSubOp::create(rewriter, loc, bc, ad);
217 Value resultRe = spirv::FDivOp::create(rewriter, loc, numRe, denom);
218 Value resultIm = spirv::FDivOp::create(rewriter, loc, numIm, denom);
219
220 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
221 op, spirvType, llvm::ArrayRef<Value>{resultRe, resultIm});
222 return success();
223 }
224};
225
226} // namespace
227
228//===----------------------------------------------------------------------===//
229// Pattern population
230//===----------------------------------------------------------------------===//
231
233 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
234 MLIRContext *context = patterns.getContext();
235
236 patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern,
237 ElementwiseBinaryOpPattern<complex::AddOp, spirv::FAddOp>,
238 ElementwiseBinaryOpPattern<complex::SubOp, spirv::FSubOp>,
239 MulOpPattern, DivOpPattern, AbsOpPattern<spirv::GLSqrtOp>,
240 AbsOpPattern<spirv::CLSqrtOp>>(typeConverter, context);
241}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
void populateComplexToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Complex ops to SPIR-V ops.