// Debug-output tag for this file's patterns (used by LLVM_DEBUG-style macros).
#define DEBUG_TYPE "complex-to-spirv-pattern"
29struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
33 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
34 ConversionPatternRewriter &rewriter)
const override {
36 getTypeConverter()->convertType<ShapedType>(constOp.getType());
38 return rewriter.notifyMatchFailure(constOp,
39 "unable to convert result type");
41 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
48struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
52 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override {
54 Type spirvType = getTypeConverter()->convertType(createOp.getType());
56 return rewriter.notifyMatchFailure(createOp,
57 "unable to convert result type");
59 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
60 createOp, spirvType, adaptor.getOperands());
65struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
69 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter)
const override {
71 Type spirvType = getTypeConverter()->convertType(reOp.getType());
73 return rewriter.notifyMatchFailure(reOp,
"unable to convert result type");
75 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
81struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
85 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter)
const override {
87 Type spirvType = getTypeConverter()->convertType(imOp.getType());
89 return rewriter.notifyMatchFailure(imOp,
"unable to convert result type");
91 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
97template <
typename ComplexOp,
typename SPIRVOp>
98struct ElementwiseBinaryOpPattern final : OpConversionPattern<ComplexOp> {
99 using OpConversionPattern<ComplexOp>::OpConversionPattern;
100 using OpAdaptor =
typename ComplexOp::Adaptor;
103 matchAndRewrite(ComplexOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter)
const override {
106 this->getTypeConverter()->convertType(op.getResult().getType());
108 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
114 Value lhsRe = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {0});
115 Value lhsIm = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {1});
116 Value rhsRe = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {0});
117 Value rhsIm = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {1});
119 Value resultRe = SPIRVOp::create(rewriter, loc, lhsRe, rhsRe);
120 Value resultIm = SPIRVOp::create(rewriter, loc, lhsIm, rhsIm);
122 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
128struct MulOpPattern final : OpConversionPattern<complex::MulOp> {
132 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
134 Type spirvType = getTypeConverter()->convertType(op.getResult().getType());
136 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
142 Value a = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {0});
143 Value b = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {1});
144 Value c = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {0});
145 Value d = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {1});
147 Value ac = spirv::FMulOp::create(rewriter, loc, a, c);
148 Value bd = spirv::FMulOp::create(rewriter, loc,
b, d);
149 Value ad = spirv::FMulOp::create(rewriter, loc, a, d);
150 Value bc = spirv::FMulOp::create(rewriter, loc,
b, c);
151 Value resultRe = spirv::FSubOp::create(rewriter, loc, ac, bd);
152 Value resultIm = spirv::FAddOp::create(rewriter, loc, ad, bc);
154 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
160template <
typename SqrtOp>
161struct AbsOpPattern final : OpConversionPattern<complex::AbsOp> {
162 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
165 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter)
const override {
168 this->getTypeConverter()->convertType(op.getResult().getType());
170 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
173 Value complexVal = adaptor.getComplex();
176 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {0});
178 spirv::CompositeExtractOp::create(rewriter, loc, complexVal, {1});
180 Value reSq = spirv::FMulOp::create(rewriter, loc, re, re);
181 Value imSq = spirv::FMulOp::create(rewriter, loc, im, im);
182 Value sum = spirv::FAddOp::create(rewriter, loc, reSq, imSq);
184 rewriter.replaceOpWithNewOp<SqrtOp>(op, sum);
189struct DivOpPattern final : OpConversionPattern<complex::DivOp> {
193 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter)
const override {
195 Type spirvType = getTypeConverter()->convertType(op.getResult().getType());
197 return rewriter.notifyMatchFailure(op,
"unable to convert result type");
203 Value a = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {0});
204 Value b = spirv::CompositeExtractOp::create(rewriter, loc,
lhs, {1});
205 Value c = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {0});
206 Value d = spirv::CompositeExtractOp::create(rewriter, loc,
rhs, {1});
208 Value ac = spirv::FMulOp::create(rewriter, loc, a, c);
209 Value bd = spirv::FMulOp::create(rewriter, loc,
b, d);
210 Value bc = spirv::FMulOp::create(rewriter, loc,
b, c);
211 Value ad = spirv::FMulOp::create(rewriter, loc, a, d);
212 Value cc = spirv::FMulOp::create(rewriter, loc, c, c);
213 Value dd = spirv::FMulOp::create(rewriter, loc, d, d);
214 Value denom = spirv::FAddOp::create(rewriter, loc, cc, dd);
215 Value numRe = spirv::FAddOp::create(rewriter, loc, ac, bd);
216 Value numIm = spirv::FSubOp::create(rewriter, loc, bc, ad);
217 Value resultRe = spirv::FDivOp::create(rewriter, loc, numRe, denom);
218 Value resultIm = spirv::FDivOp::create(rewriter, loc, numIm, denom);
220 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
236 patterns.
add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern,
237 ElementwiseBinaryOpPattern<complex::AddOp, spirv::FAddOp>,
238 ElementwiseBinaryOpPattern<complex::SubOp, spirv::FSubOp>,
239 MulOpPattern, DivOpPattern, AbsOpPattern<spirv::GLSqrtOp>,
240 AbsOpPattern<spirv::CLSqrtOp>>(typeConverter, context);
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
void populateComplexToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Complex ops to SPIR-V ops.