MLIR 23.0.0git
FuncToEmitC.cpp
Go to the documentation of this file.
1//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Func dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
20#include "mlir/IR/SymbolTable.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/Support/LogicalResult.h"
24
25using namespace mlir;
26
27namespace {
28
29//===----------------------------------------------------------------------===//
30// Multi-return struct helpers
31//===----------------------------------------------------------------------===//
32
33// Looks up or creates an `emitc.class` named after `types` in the nearest
34// enclosing symbol table of `op`, suitable for packing those types as plain
35// struct fields (field0, field1, ...). If the class already exists it is
36// verified to have exactly the right fields and no methods. Returns the
37// corresponding !emitc.opaque<"struct ..."> type on success.
38static FailureOr<emitc::OpaqueType>
39getOrCreateMultiReturnType(ConversionPatternRewriter &rewriter, Location loc,
40 Operation *op, TypeRange types) {
41 // Build the struct name from the types, e.g. "return_i32_i32". Each type is
42 // printed and non-alphanumeric characters are replaced with '_'.
43 std::string structName = "return";
44 for (Type type : types) {
45 std::string typeName;
46 llvm::raw_string_ostream os(typeName);
47 type.print(os);
48 std::replace_if(
49 typeName.begin(), typeName.end(),
50 [](char c) { return !llvm::isAlnum(c); }, '_');
51 structName += "_" + typeName;
52 }
53
54 // Find the enclosing symbol table and the direct child op within it that
55 // contains `op`; the class will be inserted immediately before that child.
57 Operation *insertBefore = op;
58 while (insertBefore->getParentOp() != symbolTableOp)
59 insertBefore = insertBefore->getParentOp();
60
61 if (Operation *sym = SymbolTable::lookupSymbolIn(symbolTableOp, structName)) {
62 auto classOp = dyn_cast<emitc::ClassOp>(sym);
63 if (!classOp)
64 return emitError(loc) << "symbol '" << structName
65 << "' exists but is not an emitc.class";
66
67 if (classOp.getClassType() != emitc::ClassType::struct_)
68 return emitError(loc)
69 << "existing class '" << structName << "' is not a struct";
70
72 for (Operation &bodyOp : classOp.getBody().front()) {
73 if (isa<emitc::FuncOp>(bodyOp))
74 return emitError(loc) << "existing class '" << structName
75 << "' has methods; expected a plain struct";
76 if (auto fieldOp = dyn_cast<emitc::FieldOp>(bodyOp))
77 fields.push_back(fieldOp);
78 }
79 if (fields.size() != types.size())
80 return emitError(loc) << "existing class '" << structName
81 << "' has wrong number of fields";
82 for (auto [i, fieldOp] : llvm::enumerate(fields)) {
83 if (fieldOp.getSymName() != "field" + std::to_string(i))
84 return emitError(loc) << "existing class '" << structName
85 << "': unexpected field name at index " << i;
86 if (fieldOp.getTypeAttr().getValue() != types[i])
87 return emitError(loc) << "existing class '" << structName
88 << "': wrong type for field " << i;
89 }
90 } else {
91 // Create the ClassOp before `insertBefore`, then restore the insertion
92 // point.
93 auto savedIP = rewriter.saveInsertionPoint();
94 rewriter.setInsertionPoint(insertBefore);
95
96 emitc::ClassOp classOp = emitc::ClassOp::create(rewriter, loc, structName,
97 /*final_specifier=*/false,
98 emitc::ClassType::struct_);
99 rewriter.createBlock(&classOp.getBody());
100 rewriter.setInsertionPointToStart(&classOp.getBody().front());
101
102 for (auto [i, type] : llvm::enumerate(types)) {
103 auto fieldName = rewriter.getStringAttr("field" + std::to_string(i));
104 emitc::FieldOp::create(rewriter, loc, fieldName, TypeAttr::get(type),
105 nullptr);
106 }
107
108 rewriter.restoreInsertionPoint(savedIP);
109 }
110 return emitc::OpaqueType::get(rewriter.getContext(), "struct " + structName);
111}
112
113// Packs multiple SSA values into an emitc.class struct variable and loads the
114// result as a single SSA value of the opaque struct type.
115static Value packValuesIntoStruct(ConversionPatternRewriter &rewriter,
116 Location loc, ValueRange values,
117 emitc::OpaqueType structType) {
118 MLIRContext *ctx = rewriter.getContext();
119 auto noInit = emitc::OpaqueAttr::get(ctx, "");
120 Value structLv =
121 emitc::VariableOp::create(rewriter, loc,
122 emitc::LValueType::get(structType), noInit)
123 .getResult();
124 for (auto [i, val] : llvm::enumerate(values)) {
125 Value fieldLv =
126 emitc::MemberOp::create(
127 rewriter, loc, emitc::LValueType::get(val.getType()),
128 rewriter.getStringAttr("field" + std::to_string(i)), structLv)
129 .getResult();
130 emitc::AssignOp::create(rewriter, loc, fieldLv, val);
131 }
132 return emitc::LoadOp::create(rewriter, loc, structType, structLv).getResult();
133}
134
135/// Implement the interface to convert Func to EmitC.
136struct FuncToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
137 FuncToEmitCDialectInterface(Dialect *dialect)
138 : ConvertToEmitCPatternInterface(dialect) {}
139
140 /// Hook for derived dialect interface to provide conversion patterns
141 /// and mark dialect legal for the conversion target.
142 void populateConvertToEmitCConversionPatterns(
143 ConversionTarget &target, TypeConverter &typeConverter,
144 RewritePatternSet &patterns, std::optional<bool> lowerToCpp) const final {
145 populateFuncToEmitCPatterns(typeConverter, patterns,
146 lowerToCpp.value_or(true));
147 }
148};
149} // namespace
150
152 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
153 dialect->addInterfaces<FuncToEmitCDialectInterface>();
154 });
155}
156
157//===----------------------------------------------------------------------===//
158// Conversion Patterns
159//===----------------------------------------------------------------------===//
160
161namespace {
162class CallOpConversion final : public OpConversionPattern<func::CallOp> {
163public:
164 CallOpConversion(const TypeConverter &typeConverter, MLIRContext *ctx,
165 bool lowerToCpp)
166 : OpConversionPattern<func::CallOp>(typeConverter, ctx),
167 lowerToCpp(lowerToCpp) {}
168
169 LogicalResult
170 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter) const override {
172 // Do not convert multiple-return functions if lowering target is Cpp.
173 // The translator will emit the return values as an std::tuple.
174 if (callOp.getNumResults() > 1 && lowerToCpp)
175 return rewriter.notifyMatchFailure(
176 callOp, "only functions with zero or one result can be converted");
177
178 SmallVector<Type> convertedResultTypes;
179 for (Type t : callOp.getResultTypes()) {
180 Type resultType = getTypeConverter()->convertType(t);
181 if (!resultType)
182 return rewriter.notifyMatchFailure(callOp,
183 "result type conversion failed");
184 if (isa<emitc::ArrayType>(resultType))
185 return rewriter.notifyMatchFailure(
186 callOp, "function calls returning arrays are not supported");
187 convertedResultTypes.push_back(resultType);
188 }
189
190 if (callOp.getNumResults() <= 1) {
191 rewriter.replaceOpWithNewOp<emitc::CallOp>(
192 callOp, callOp.getResultTypes(), adaptor.getOperands(),
193 callOp->getAttrs());
194 return success();
195 }
196
197 // Multi-result call: determine the struct type.
198 Location loc = callOp.getLoc();
199
200 auto structType =
201 getOrCreateMultiReturnType(rewriter, loc, callOp, convertedResultTypes);
202 if (failed(structType))
203 return rewriter.notifyMatchFailure(callOp,
204 "incompatible multi-return struct");
205
206 // Emit a call returning the packed struct.
207 Value structVal =
208 emitc::CallOp::create(rewriter, loc, callOp.getCalleeAttr(),
209 TypeRange{*structType}, adaptor.getOperands())
210 .getResult(0);
211
212 // Unpack struct fields to replace the original multiple results.
213 SmallVector<Value> results;
214 for (auto [i, result] : llvm::enumerate(callOp.getResults())) {
215 if (result.use_empty()) {
216 results.push_back(Value()); // No replacement needed.
217 continue;
218 }
219 Type fieldType = convertedResultTypes[i];
220 StringAttr fieldName =
221 rewriter.getStringAttr("field" + std::to_string(i));
222 Value fieldValue = emitc::MemberOp::create(rewriter, loc, fieldType,
223 fieldName, structVal)
224 .getResult();
225 results.push_back(fieldValue);
226 }
227
228 rewriter.replaceOp(callOp, results);
229 return success();
230 }
231
232private:
233 bool lowerToCpp;
234};
235
236class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
237public:
238 FuncOpConversion(const TypeConverter &typeConverter, MLIRContext *ctx,
239 bool lowerToCpp)
240 : OpConversionPattern<func::FuncOp>(typeConverter, ctx),
241 lowerToCpp(lowerToCpp) {}
242
243 LogicalResult
244 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 FunctionType fnType = funcOp.getFunctionType();
247
248 // Do not convert multiple-return functions if lowering target is Cpp.
249 // The translator will emit the return values as an std::tuple.
250 if (fnType.getNumResults() > 1 && lowerToCpp)
251 return rewriter.notifyMatchFailure(
252 funcOp, "only functions with zero or one result can be converted");
253
254 TypeConverter::SignatureConversion signatureConverter(
255 fnType.getNumInputs());
256 for (const auto &argType : enumerate(fnType.getInputs())) {
257 auto convertedType = getTypeConverter()->convertType(argType.value());
258 if (!convertedType)
259 return rewriter.notifyMatchFailure(funcOp,
260 "argument type conversion failed");
261 signatureConverter.addInputs(argType.index(), convertedType);
262 }
263
264 SmallVector<Type> convertedResultTypes;
265 for (Type t : fnType.getResults()) {
266 Type resultType = getTypeConverter()->convertType(t);
267 if (!resultType)
268 return rewriter.notifyMatchFailure(funcOp,
269 "result type conversion failed");
270 if (isa<emitc::ArrayType>(resultType))
271 return rewriter.notifyMatchFailure(
272 funcOp, "functions returning arrays are not supported");
273 convertedResultTypes.push_back(resultType);
274 }
275
276 Type resultType;
277 if (fnType.getNumResults() == 1) {
278 resultType = convertedResultTypes[0];
279 } else if (fnType.getNumResults() > 1) {
280 auto structTypeOrErr = getOrCreateMultiReturnType(
281 rewriter, funcOp.getLoc(), funcOp, convertedResultTypes);
282 if (failed(structTypeOrErr))
283 return rewriter.notifyMatchFailure(funcOp,
284 "incompatible multi-return struct");
285 resultType = *structTypeOrErr;
286 }
287
288 // Create the converted `emitc.func` op.
289 emitc::FuncOp newFuncOp = emitc::FuncOp::create(
290 rewriter, funcOp.getLoc(), funcOp.getName(),
291 FunctionType::get(rewriter.getContext(),
292 signatureConverter.getConvertedTypes(),
293 resultType ? TypeRange(resultType) : TypeRange()));
294
295 // Copy over all attributes other than the function name and type.
296 for (const auto &namedAttr : funcOp->getAttrs()) {
297 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
298 namedAttr.getName() != SymbolTable::getSymbolAttrName())
299 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
300 }
301
302 // Add `extern` to specifiers if `func.func` is declaration only.
303 if (funcOp.isDeclaration()) {
304 ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
305 newFuncOp.setSpecifiersAttr(specifiers);
306 }
307
308 // Add `static` to specifiers if `func.func` is private but not a
309 // declaration.
310 if (funcOp.isPrivate() && !funcOp.isDeclaration()) {
311 ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
312 newFuncOp.setSpecifiersAttr(specifiers);
313 }
314
315 if (!funcOp.isDeclaration()) {
316 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
317 newFuncOp.end());
318 if (failed(rewriter.convertRegionTypes(
319 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
320 return failure();
321 }
322 rewriter.eraseOp(funcOp);
323
324 return success();
325 }
326
327private:
328 bool lowerToCpp;
329};
330
331class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
332public:
333 ReturnOpConversion(const TypeConverter &typeConverter, MLIRContext *ctx,
334 bool lowerToCpp)
335 : OpConversionPattern<func::ReturnOp>(typeConverter, ctx),
336 lowerToCpp(lowerToCpp) {}
337
338 LogicalResult
339 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 // Do not convert multiple-return functions if lowering target is Cpp.
342 // The translator will emit the return values as an std::tuple.
343 if (returnOp.getNumOperands() > 1 && lowerToCpp)
344 return rewriter.notifyMatchFailure(
345 returnOp, "only zero or one operand is supported");
346
347 if (llvm::any_of(adaptor.getOperands(), [](Value operand) {
348 return isa<emitc::ArrayType>(operand.getType());
349 }))
350 return rewriter.notifyMatchFailure(returnOp,
351 "returning arrays is not supported");
352
353 if (returnOp.getNumOperands() <= 1) {
354 rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
355 returnOp,
356 returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr);
357 return success();
358 }
359
360 // Multi-operand return: pack values into a struct.
361 Location loc = returnOp.getLoc();
362
363 auto structType = getOrCreateMultiReturnType(
364 rewriter, loc, returnOp, adaptor.getOperands().getTypes());
365 if (failed(structType))
366 return rewriter.notifyMatchFailure(returnOp,
367 "incompatible multi-return struct");
368
369 Value structVal =
370 packValuesIntoStruct(rewriter, loc, adaptor.getOperands(), *structType);
371 rewriter.replaceOpWithNewOp<emitc::ReturnOp>(returnOp, structVal);
372 return success();
373 }
374
375private:
376 bool lowerToCpp;
377};
378} // namespace
379
380//===----------------------------------------------------------------------===//
381// Pattern population
382//===----------------------------------------------------------------------===//
383
385 RewritePatternSet &patterns,
386 bool lowerToCpp) {
387 MLIRContext *ctx = patterns.getContext();
388
389 patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(
390 typeConverter, ctx, lowerToCpp);
391}
return success()
ArrayAttr()
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateFuncToEmitCPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool lowerToCpp=true)
void registerConvertFuncToEmitCInterface(DialectRegistry &registry)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.