22#include "llvm/ADT/StringExtras.h"
23#include "llvm/Support/LogicalResult.h"
/// Returns the `!emitc.opaque<"struct NAME">` type used to carry multiple
/// results, creating the backing `emitc.class` struct if no symbol with that
/// name exists yet.
///
/// NAME is "return" followed by "_<type>" for each entry of `types`, where
/// each type is printed to a string and has every non-alphanumeric character
/// replaced by '_' (presumably via std::replace_if -- the call itself is
/// elided in this listing). The struct has one `emitc.field` per type, named
/// "field0", "field1", ..., in order.
///
/// Parameters: `rewriter` builds the IR, `loc` is used for new ops and for
/// diagnostics; the third (elided) parameter is the op the callers pass in
/// (a call/func/return op) and appears to anchor the symbol-table lookup;
/// `types` are the already-converted result types.
///
/// If a symbol with that name already exists it is reused only when it is an
/// `emitc.class` of kind `struct_` with no `emitc.func` members and exactly
/// matching field names and types; otherwise an error is emitted at `loc` and
/// failure is returned.
///
/// NOTE(review): the sanitized name is not injective (e.g. one type printing
/// as "a.b" and the pair "a", "b" both yield "return_a_b"), so such a
/// collision surfaces as one of the errors below rather than a second struct
/// -- confirm this is acceptable.
38static FailureOr<emitc::OpaqueType>
39getOrCreateMultiReturnType(ConversionPatternRewriter &rewriter,
Location loc,
// Build the struct name from the (sanitized) printed form of each type.
43 std::string structName =
"return";
44 for (
Type type : types) {
46 llvm::raw_string_ostream os(typeName);
49 typeName.begin(), typeName.end(),
50 [](
char c) { return !llvm::isAlnum(c); },
'_');
51 structName +=
"_" + typeName;
// Walk up from the anchor op to the ancestor that sits directly inside
// symbolTableOp (presumably the nearest symbol table -- its computation is
// elided); a newly created struct is inserted just before that ancestor.
58 while (insertBefore->
getParentOp() != symbolTableOp)
// A symbol with this name already exists (the lookup producing `sym` is
// elided): validate it instead of creating a new one.
62 auto classOp = dyn_cast<emitc::ClassOp>(sym);
64 return emitError(loc) <<
"symbol '" << structName
65 <<
"' exists but is not an emitc.class";
// Only plain structs are reusable.
67 if (classOp.getClassType() != emitc::ClassType::struct_)
69 <<
"existing class '" << structName <<
"' is not a struct";
// Methods disqualify the class; otherwise collect its emitc.field ops in
// body order.
72 for (
Operation &bodyOp : classOp.getBody().front()) {
73 if (isa<emitc::FuncOp>(bodyOp))
74 return emitError(loc) <<
"existing class '" << structName
75 <<
"' has methods; expected a plain struct";
76 if (
auto fieldOp = dyn_cast<emitc::FieldOp>(bodyOp))
77 fields.push_back(fieldOp);
// The field count must equal the number of requested types.
79 if (fields.size() != types.size())
80 return emitError(loc) <<
"existing class '" << structName
81 <<
"' has wrong number of fields";
// Each field must be named "field<i>" and have exactly types[i].
82 for (
auto [i, fieldOp] : llvm::enumerate(fields)) {
83 if (fieldOp.getSymName() !=
"field" + std::to_string(i))
84 return emitError(loc) <<
"existing class '" << structName
85 <<
"': unexpected field name at index " << i;
86 if (fieldOp.getTypeAttr().getValue() != types[i])
87 return emitError(loc) <<
"existing class '" << structName
88 <<
"': wrong type for field " << i;
// No usable existing symbol: create the struct right before insertBefore,
// remembering the caller's insertion point so it can be restored.
93 auto savedIP = rewriter.saveInsertionPoint();
94 rewriter.setInsertionPoint(insertBefore);
96 emitc::ClassOp classOp = emitc::ClassOp::create(rewriter, loc, structName,
98 emitc::ClassType::struct_);
99 rewriter.createBlock(&classOp.getBody());
100 rewriter.setInsertionPointToStart(&classOp.getBody().front());
// One field per type, named "field<i>" to match the validation above.
102 for (
auto [i, type] : llvm::enumerate(types)) {
103 auto fieldName = rewriter.getStringAttr(
"field" + std::to_string(i));
104 emitc::FieldOp::create(rewriter, loc, fieldName, TypeAttr::get(type),
// Resume building IR where the caller left off.
108 rewriter.restoreInsertionPoint(savedIP);
// Refer to the struct through an opaque type spelled "struct <name>".
110 return emitc::OpaqueType::get(rewriter.getContext(),
"struct " + structName);
/// Materializes a struct value of type `structType` from `values`.
///
/// Declares an emitc.variable of lvalue struct type with an empty opaque
/// initializer, assigns values[i] to member "field<i>" through emitc.member +
/// emitc.assign, then loads and returns the whole struct.
///
/// Assumes `structType` was produced by getOrCreateMultiReturnType for the
/// same value types, so member "field<i>" has values[i]'s type -- TODO
/// confirm at call sites.
115static Value packValuesIntoStruct(ConversionPatternRewriter &rewriter,
117 emitc::OpaqueType structType) {
// Empty opaque initializer: presumably emitted as a plain declaration with
// no initializer -- verify against the EmitC emitter.
119 auto noInit = emitc::OpaqueAttr::get(ctx,
"");
121 emitc::VariableOp::create(rewriter, loc,
122 emitc::LValueType::get(structType), noInit)
// Store each value into its member, in order.
124 for (
auto [i, val] : llvm::enumerate(values)) {
126 emitc::MemberOp::create(
127 rewriter, loc, emitc::LValueType::get(val.getType()),
128 rewriter.getStringAttr(
"field" + std::to_string(i)), structLv)
130 emitc::AssignOp::create(rewriter, loc, fieldLv, val);
// Read the filled-in struct back as a value.
132 return emitc::LoadOp::create(rewriter, loc, structType, structLv).getResult();
/// Exposes the func-to-EmitC patterns through the generic convert-to-emitc
/// pattern interface. When `lowerToCpp` is not provided it is treated as
/// true.
136struct FuncToEmitCDialectInterface :
public ConvertToEmitCPatternInterface {
137 FuncToEmitCDialectInterface(Dialect *dialect)
138 : ConvertToEmitCPatternInterface(dialect) {}
142 void populateConvertToEmitCConversionPatterns(
143 ConversionTarget &
target, TypeConverter &typeConverter,
144 RewritePatternSet &patterns, std::optional<bool> lowerToCpp)
const final {
// An unset lowerToCpp defaults to true.
146 lowerToCpp.value_or(
true));
// Registration (enclosing function elided): attach the interface to the
// dialect -- presumably the func dialect.
153 dialect->addInterfaces<FuncToEmitCDialectInterface>();
/// Lowers func.call to emitc.call.
///
/// Zero or one result: a direct emitc.call. Multiple results (rejected when
/// lowerToCpp is true): the call returns the multi-return struct from
/// getOrCreateMultiReturnType and each result is read back from member
/// "field<i>".
162class CallOpConversion final :
public OpConversionPattern<func::CallOp> {
166 : OpConversionPattern<
func::CallOp>(typeConverter, ctx),
167 lowerToCpp(lowerToCpp) {}
170 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter)
const override {
// In C++ mode, multi-result calls are rejected.
174 if (callOp.getNumResults() > 1 && lowerToCpp)
175 return rewriter.notifyMatchFailure(
176 callOp,
"only functions with zero or one result can be converted");
// Convert every result type up front; conversion failures and array results
// are rejected.
178 SmallVector<Type> convertedResultTypes;
179 for (Type t : callOp.getResultTypes()) {
180 Type resultType = getTypeConverter()->convertType(t);
182 return rewriter.notifyMatchFailure(callOp,
183 "result type conversion failed");
184 if (isa<emitc::ArrayType>(resultType))
185 return rewriter.notifyMatchFailure(
186 callOp,
"function calls returning arrays are not supported");
187 convertedResultTypes.push_back(resultType);
// NOTE(review): this path passes the *unconverted* callOp.getResultTypes()
// although convertedResultTypes was computed above. If the type converter
// changes a result type, this call's result type may disagree with the
// converted emitc.func signature. Consider passing convertedResultTypes.
190 if (callOp.getNumResults() <= 1) {
191 rewriter.replaceOpWithNewOp<emitc::CallOp>(
192 callOp, callOp.getResultTypes(), adaptor.getOperands(),
// Multiple results: the call returns the struct type shared with the callee.
198 Location loc = callOp.getLoc();
201 getOrCreateMultiReturnType(rewriter, loc, callOp, convertedResultTypes);
203 return rewriter.notifyMatchFailure(callOp,
204 "incompatible multi-return struct");
208 emitc::CallOp::create(rewriter, loc, callOp.getCalleeAttr(),
209 TypeRange{*structType}, adaptor.getOperands())
// Unpack each original result from member "field<i>" of the returned struct.
213 SmallVector<Value> results;
214 for (
auto [i,
result] : llvm::enumerate(callOp.getResults())) {
// Null placeholder for a result that is not extracted (guarding condition
// elided -- presumably results without uses; TODO confirm).
216 results.push_back(Value());
219 Type fieldType = convertedResultTypes[i];
220 StringAttr fieldName =
221 rewriter.getStringAttr(
"field" + std::to_string(i));
222 Value fieldValue = emitc::MemberOp::create(rewriter, loc, fieldType,
223 fieldName, structVal)
225 results.push_back(fieldValue);
// Replace all original results at once.
228 rewriter.replaceOp(callOp, results);
/// Lowers func.func to emitc.func.
///
/// Argument and result types are converted with the type converter; array
/// results are rejected. A single result keeps its converted type. Multiple
/// results (rejected when lowerToCpp is true) are returned as the struct from
/// getOrCreateMultiReturnType. Declarations get the "extern" specifier and
/// private definitions get "static". For definitions, the body is moved into
/// the new function and its region types are converted. The original op is
/// then erased.
236class FuncOpConversion final :
public OpConversionPattern<func::FuncOp> {
238 FuncOpConversion(
const TypeConverter &typeConverter, MLIRContext *ctx,
240 : OpConversionPattern<func::FuncOp>(typeConverter, ctx),
241 lowerToCpp(lowerToCpp) {}
244 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter)
const override {
246 FunctionType fnType = funcOp.getFunctionType();
// In C++ mode, multi-result functions are rejected.
250 if (fnType.getNumResults() > 1 && lowerToCpp)
251 return rewriter.notifyMatchFailure(
252 funcOp,
"only functions with zero or one result can be converted");
// Convert argument types; the signature conversion is reused below when
// converting the body's region types.
254 TypeConverter::SignatureConversion signatureConverter(
255 fnType.getNumInputs());
256 for (
const auto &argType :
enumerate(fnType.getInputs())) {
257 auto convertedType = getTypeConverter()->convertType(argType.value());
259 return rewriter.notifyMatchFailure(funcOp,
260 "argument type conversion failed");
261 signatureConverter.addInputs(argType.index(), convertedType);
// Convert result types; arrays cannot be returned by value.
264 SmallVector<Type> convertedResultTypes;
265 for (Type t : fnType.getResults()) {
266 Type resultType = getTypeConverter()->convertType(t);
268 return rewriter.notifyMatchFailure(funcOp,
269 "result type conversion failed");
270 if (isa<emitc::ArrayType>(resultType))
271 return rewriter.notifyMatchFailure(
272 funcOp,"functions returning arrays are not supported");
273 convertedResultTypes.push_back(resultType);
// Pick the emitted return type: the single converted type, or the
// multi-return struct when there are several results.
277 if (fnType.getNumResults() == 1) {
278 resultType = convertedResultTypes[0];
279 }
else if (fnType.getNumResults() > 1) {
280 auto structTypeOrErr = getOrCreateMultiReturnType(
281 rewriter, funcOp.getLoc(), funcOp, convertedResultTypes);
282 if (
failed(structTypeOrErr))
283 return rewriter.notifyMatchFailure(funcOp,
284 "incompatible multi-return struct");
285 resultType = *structTypeOrErr;
289 emitc::FuncOp newFuncOp = emitc::FuncOp::create(
290 rewriter, funcOp.getLoc(), funcOp.getName(),
291 FunctionType::get(rewriter.getContext(),
292 signatureConverter.getConvertedTypes(),
// Copy the remaining attributes, skipping the function type attribute (a
// second excluded attribute is tested on an elided line).
296 for (
const auto &namedAttr : funcOp->getAttrs()) {
297 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
299 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
// Declarations (no body) are marked "extern".
303 if (funcOp.isDeclaration()) {
304 ArrayAttr specifiers = rewriter.getStrArrayAttr({
"extern"});
305 newFuncOp.setSpecifiersAttr(specifiers);
// Private definitions are marked "static".
310 if (funcOp.isPrivate() && !funcOp.isDeclaration()) {
311 ArrayAttr specifiers = rewriter.getStrArrayAttr({
"static"});
312 newFuncOp.setSpecifiersAttr(specifiers);
// Move the body over and convert its region types using the signature
// conversion built above.
315 if (!funcOp.isDeclaration()) {
316 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
318 if (
failed(rewriter.convertRegionTypes(
319 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
// The emitc.func fully replaces the original.
322 rewriter.eraseOp(funcOp);
/// Lowers func.return to emitc.return.
///
/// Zero or one operand is returned directly. Multiple operands (rejected when
/// lowerToCpp is true) are packed into the multi-return struct via
/// packValuesIntoStruct. Array-typed operands are rejected.
331class ReturnOpConversion final :
public OpConversionPattern<func::ReturnOp> {
333 ReturnOpConversion(
const TypeConverter &typeConverter, MLIRContext *ctx,
335 : OpConversionPattern<func::ReturnOp>(typeConverter, ctx),
336 lowerToCpp(lowerToCpp) {}
339 matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter)
const override {
// In C++ mode, multi-operand returns are rejected.
343 if (returnOp.getNumOperands() > 1 && lowerToCpp)
344 return rewriter.notifyMatchFailure(
345 returnOp,"only zero or one operand is supported");
// Arrays cannot be returned by value; check the converted operands.
347 if (llvm::any_of(adaptor.getOperands(), [](Value operand) {
348 return isa<emitc::ArrayType>(operand.getType());
350 return rewriter.notifyMatchFailure(returnOp,
351 "returning arrays is not supported");
// Zero or one operand: return it directly, or return nothing (null value).
353 if (returnOp.getNumOperands() <= 1) {
354 rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
356 returnOp.getNumOperands() ? adaptor.getOperands()[0] :nullptr);
// Multiple operands: pack them into the struct and return the struct.
// NOTE(review): the struct is resolved from the converted operand types. For
// the lookup to hit the struct FuncOpConversion created for the enclosing
// function, these must equal that function's converted result types.
361 Location loc = returnOp.getLoc();
363 auto structType = getOrCreateMultiReturnType(
364 rewriter, loc, returnOp, adaptor.getOperands().getTypes());
366 return rewriter.notifyMatchFailure(returnOp,
367 "incompatible multi-return struct");
370 packValuesIntoStruct(rewriter, loc, adaptor.getOperands(), *structType);
371 rewriter.replaceOpWithNewOp<emitc::ReturnOp>(returnOp, structVal);
// Register the call, func and return conversions. They share the same type
// converter and the same lowerToCpp flag, so all three agree on whether
// multi-result functions are allowed.
389 patterns.
add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(
390 typeConverter, ctx, lowerToCpp);
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void populateFuncToEmitCPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool lowerToCpp=true)
void registerConvertFuncToEmitCInterface(DialectRegistry &registry)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.