MLIR 23.0.0git
FuncToEmitC.cpp
Go to the documentation of this file.
1//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Func dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
20
21using namespace mlir;
22
23namespace {
24
25/// Implement the interface to convert Func to EmitC.
26struct FuncToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
27 FuncToEmitCDialectInterface(Dialect *dialect)
28 : ConvertToEmitCPatternInterface(dialect) {}
29
30 /// Hook for derived dialect interface to provide conversion patterns
31 /// and mark dialect legal for the conversion target.
32 void populateConvertToEmitCConversionPatterns(
33 ConversionTarget &target, TypeConverter &typeConverter,
34 RewritePatternSet &patterns) const final {
35 populateFuncToEmitCPatterns(typeConverter, patterns);
36 }
37};
38} // namespace
39
41 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
42 dialect->addInterfaces<FuncToEmitCDialectInterface>();
43 });
44}
45
46//===----------------------------------------------------------------------===//
47// Conversion Patterns
48//===----------------------------------------------------------------------===//
49
50namespace {
51class CallOpConversion final : public OpConversionPattern<func::CallOp> {
52public:
53 using OpConversionPattern<func::CallOp>::OpConversionPattern;
54
55 LogicalResult
56 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
57 ConversionPatternRewriter &rewriter) const override {
58 // Multiple results func cannot be converted to `emitc.func`.
59 if (callOp.getNumResults() > 1)
60 return rewriter.notifyMatchFailure(
61 callOp, "only functions with zero or one result can be converted");
62
63 rewriter.replaceOpWithNewOp<emitc::CallOp>(callOp, callOp.getResultTypes(),
64 adaptor.getOperands(),
65 callOp->getAttrs());
66
67 return success();
68 }
69};
70
71class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
72public:
73 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
74
75 LogicalResult
76 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const override {
78 FunctionType fnType = funcOp.getFunctionType();
79
80 if (fnType.getNumResults() > 1)
81 return rewriter.notifyMatchFailure(
82 funcOp, "only functions with zero or one result can be converted");
83
84 TypeConverter::SignatureConversion signatureConverter(
85 fnType.getNumInputs());
86 for (const auto &argType : enumerate(fnType.getInputs())) {
87 auto convertedType = getTypeConverter()->convertType(argType.value());
88 if (!convertedType)
89 return rewriter.notifyMatchFailure(funcOp,
90 "argument type conversion failed");
91 signatureConverter.addInputs(argType.index(), convertedType);
92 }
93
94 Type resultType;
95 if (fnType.getNumResults() == 1) {
96 resultType = getTypeConverter()->convertType(fnType.getResult(0));
97 if (!resultType)
98 return rewriter.notifyMatchFailure(funcOp,
99 "result type conversion failed");
100 }
101
102 // Create the converted `emitc.func` op.
103 emitc::FuncOp newFuncOp = emitc::FuncOp::create(
104 rewriter, funcOp.getLoc(), funcOp.getName(),
105 FunctionType::get(rewriter.getContext(),
106 signatureConverter.getConvertedTypes(),
107 resultType ? TypeRange(resultType) : TypeRange()));
108
109 // Copy over all attributes other than the function name and type.
110 for (const auto &namedAttr : funcOp->getAttrs()) {
111 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
112 namedAttr.getName() != SymbolTable::getSymbolAttrName())
113 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
114 }
115
116 // Add `extern` to specifiers if `func.func` is declaration only.
117 if (funcOp.isDeclaration()) {
118 ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
119 newFuncOp.setSpecifiersAttr(specifiers);
120 }
121
122 // Add `static` to specifiers if `func.func` is private but not a
123 // declaration.
124 if (funcOp.isPrivate() && !funcOp.isDeclaration()) {
125 ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
126 newFuncOp.setSpecifiersAttr(specifiers);
127 }
128
129 if (!funcOp.isDeclaration()) {
130 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
131 newFuncOp.end());
132 if (failed(rewriter.convertRegionTypes(
133 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
134 return failure();
135 }
136 rewriter.eraseOp(funcOp);
137
138 return success();
139 }
140};
141
142class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
143public:
144 using OpConversionPattern<func::ReturnOp>::OpConversionPattern;
145
146 LogicalResult
147 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
148 ConversionPatternRewriter &rewriter) const override {
149 if (returnOp.getNumOperands() > 1)
150 return rewriter.notifyMatchFailure(
151 returnOp, "only zero or one operand is supported");
152
153 rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
154 returnOp,
155 returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr);
156 return success();
157 }
158};
159} // namespace
160
161//===----------------------------------------------------------------------===//
162// Pattern population
163//===----------------------------------------------------------------------===//
164
166 RewritePatternSet &patterns) {
167 MLIRContext *ctx = patterns.getContext();
168
169 patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(
170 typeConverter, ctx);
171}
return success()
ArrayAttr()
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void registerConvertFuncToEmitCInterface(DialectRegistry &registry)
void populateFuncToEmitCPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)