MLIR  21.0.0git
ComplexToSPIRV.cpp
Go to the documentation of this file.
1 //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Complex dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 
19 #define DEBUG_TYPE "complex-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Operation conversion
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 
29 struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
31 
32  LogicalResult
33  matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
34  ConversionPatternRewriter &rewriter) const override {
35  auto spirvType =
36  getTypeConverter()->convertType<ShapedType>(constOp.getType());
37  if (!spirvType)
38  return rewriter.notifyMatchFailure(constOp,
39  "unable to convert result type");
40 
41  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
42  constOp, spirvType,
43  DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
44  return success();
45  }
46 };
47 
48 struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
50 
51  LogicalResult
52  matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  Type spirvType = getTypeConverter()->convertType(createOp.getType());
55  if (!spirvType)
56  return rewriter.notifyMatchFailure(createOp,
57  "unable to convert result type");
58 
59  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
60  createOp, spirvType, adaptor.getOperands());
61  return success();
62  }
63 };
64 
65 struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
67 
68  LogicalResult
69  matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
70  ConversionPatternRewriter &rewriter) const override {
71  Type spirvType = getTypeConverter()->convertType(reOp.getType());
72  if (!spirvType)
73  return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
74 
75  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
76  reOp, adaptor.getComplex(), llvm::ArrayRef(0));
77  return success();
78  }
79 };
80 
81 struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
83 
84  LogicalResult
85  matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
86  ConversionPatternRewriter &rewriter) const override {
87  Type spirvType = getTypeConverter()->convertType(imOp.getType());
88  if (!spirvType)
89  return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
90 
91  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
92  imOp, adaptor.getComplex(), llvm::ArrayRef(1));
93  return success();
94  }
95 };
96 
97 } // namespace
98 
99 //===----------------------------------------------------------------------===//
100 // Pattern population
101 //===----------------------------------------------------------------------===//
102 
104  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
105  MLIRContext *context = patterns.getContext();
106 
107  patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
108  typeConverter, context);
109 }
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:681
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateComplexToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Complex ops to SPIR-V ops.