MLIR 23.0.0git
FuncToEmitC.cpp
Go to the documentation of this file.
//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Func dialect to the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//
13
15
20
21using namespace mlir;
22
23namespace {
24
25/// Implement the interface to convert Func to EmitC.
26struct FuncToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
28
29 /// Hook for derived dialect interface to provide conversion patterns
30 /// and mark dialect legal for the conversion target.
31 void populateConvertToEmitCConversionPatterns(
32 ConversionTarget &target, TypeConverter &typeConverter,
33 RewritePatternSet &patterns) const final {
34 populateFuncToEmitCPatterns(typeConverter, patterns);
35 }
36};
37} // namespace
38
40 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
41 dialect->addInterfaces<FuncToEmitCDialectInterface>();
42 });
43}
44
//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//
48
49namespace {
50class CallOpConversion final : public OpConversionPattern<func::CallOp> {
51public:
52 using OpConversionPattern<func::CallOp>::OpConversionPattern;
53
54 LogicalResult
55 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 // Multiple results func cannot be converted to `emitc.func`.
58 if (callOp.getNumResults() > 1)
59 return rewriter.notifyMatchFailure(
60 callOp, "only functions with zero or one result can be converted");
61
62 rewriter.replaceOpWithNewOp<emitc::CallOp>(callOp, callOp.getResultTypes(),
63 adaptor.getOperands(),
64 callOp->getAttrs());
65
66 return success();
67 }
68};
69
70class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
71public:
72 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
73
74 LogicalResult
75 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
76 ConversionPatternRewriter &rewriter) const override {
77 FunctionType fnType = funcOp.getFunctionType();
78
79 if (fnType.getNumResults() > 1)
80 return rewriter.notifyMatchFailure(
81 funcOp, "only functions with zero or one result can be converted");
82
83 TypeConverter::SignatureConversion signatureConverter(
84 fnType.getNumInputs());
85 for (const auto &argType : enumerate(fnType.getInputs())) {
86 auto convertedType = getTypeConverter()->convertType(argType.value());
87 if (!convertedType)
88 return rewriter.notifyMatchFailure(funcOp,
89 "argument type conversion failed");
90 signatureConverter.addInputs(argType.index(), convertedType);
91 }
92
93 Type resultType;
94 if (fnType.getNumResults() == 1) {
95 resultType = getTypeConverter()->convertType(fnType.getResult(0));
96 if (!resultType)
97 return rewriter.notifyMatchFailure(funcOp,
98 "result type conversion failed");
99 }
100
101 // Create the converted `emitc.func` op.
102 emitc::FuncOp newFuncOp = emitc::FuncOp::create(
103 rewriter, funcOp.getLoc(), funcOp.getName(),
104 FunctionType::get(rewriter.getContext(),
105 signatureConverter.getConvertedTypes(),
106 resultType ? TypeRange(resultType) : TypeRange()));
107
108 // Copy over all attributes other than the function name and type.
109 for (const auto &namedAttr : funcOp->getAttrs()) {
110 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
111 namedAttr.getName() != SymbolTable::getSymbolAttrName())
112 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
113 }
114
115 // Add `extern` to specifiers if `func.func` is declaration only.
116 if (funcOp.isDeclaration()) {
117 ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
118 newFuncOp.setSpecifiersAttr(specifiers);
119 }
120
121 // Add `static` to specifiers if `func.func` is private but not a
122 // declaration.
123 if (funcOp.isPrivate() && !funcOp.isDeclaration()) {
124 ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
125 newFuncOp.setSpecifiersAttr(specifiers);
126 }
127
128 if (!funcOp.isDeclaration()) {
129 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
130 newFuncOp.end());
131 if (failed(rewriter.convertRegionTypes(
132 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
133 return failure();
134 }
135 rewriter.eraseOp(funcOp);
136
137 return success();
138 }
139};
140
141class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
142public:
143 using OpConversionPattern<func::ReturnOp>::OpConversionPattern;
144
145 LogicalResult
146 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 if (returnOp.getNumOperands() > 1)
149 return rewriter.notifyMatchFailure(
150 returnOp, "only zero or one operand is supported");
151
152 rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
153 returnOp,
154 returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr);
155 return success();
156 }
157};
158} // namespace
159
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
163
165 RewritePatternSet &patterns) {
166 MLIRContext *ctx = patterns.getContext();
167
168 patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(
169 typeConverter, ctx);
170}
return success()
ArrayAttr()
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void registerConvertFuncToEmitCInterface(DialectRegistry &registry)
void populateFuncToEmitCPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)