25#include "llvm/Support/LogicalResult.h"
28#define GEN_PASS_DEF_SCFTOEMITC
29#include "mlir/Conversion/Passes.h.inc"
38struct SCFToEmitCDialectInterface :
public ConvertToEmitCPatternInterface {
39 SCFToEmitCDialectInterface(Dialect *dialect)
40 : ConvertToEmitCPatternInterface(dialect) {}
44 void populateConvertToEmitCConversionPatterns(
45 ConversionTarget &
target, TypeConverter &typeConverter,
46 RewritePatternSet &patterns, std::optional<bool> lowerToCpp)
const final {
55 dialect->addInterfaces<SCFToEmitCDialectInterface>();
61struct SCFToEmitCPass :
public impl::SCFToEmitCBase<SCFToEmitCPass> {
62 void runOnOperation()
override;
67struct ForLowering :
public OpConversionPattern<ForOp> {
68 using OpConversionPattern<ForOp>::OpConversionPattern;
71 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter)
const override;
78createVariablesForResults(T op,
const TypeConverter *typeConverter,
79 ConversionPatternRewriter &rewriter,
81 if (!op.getNumResults())
88 rewriter.setInsertionPoint(op);
91 Type resultType = typeConverter->convertType(
result.getType());
93 return rewriter.notifyMatchFailure(op,
"result type conversion failed");
94 if (isa<emitc::ArrayType>(resultType))
95 return rewriter.notifyMatchFailure(
96 op,
"cannot create variable for result of array type");
97 Type varType = emitc::LValueType::get(resultType);
98 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
99 emitc::VariableOp var =
100 emitc::VariableOp::create(rewriter, loc, varType, noInit);
101 resultVariables.push_back(var);
110 ConversionPatternRewriter &rewriter,
Location loc) {
111 for (
auto [value, var] : llvm::zip(values, variables))
112 emitc::AssignOp::create(rewriter, loc, var, value);
117 return llvm::map_to_vector<>(variables, [&](
Value var) {
118 Type type = cast<emitc::LValueType>(var.
getType()).getValueType();
119 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
124 ConversionPatternRewriter &rewriter,
125 scf::YieldOp yield,
bool createYield =
true) {
129 rewriter.setInsertionPoint(yield);
132 if (
failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
133 return rewriter.notifyMatchFailure(op,
"failed to lower yield operands");
135 assignValues(yieldOperands, resultVariables, rewriter, loc);
137 emitc::YieldOp::create(rewriter, loc);
138 rewriter.eraseOp(yield);
149 ConversionPatternRewriter &rewriter,
151 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
153 return lowerYield(op, resultVariables, rewriter,
154 cast<scf::YieldOp>(terminator));
158ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
159 ConversionPatternRewriter &rewriter)
const {
160 Location loc = forOp.getLoc();
162 if (forOp.getUnsignedCmp())
163 return rewriter.notifyMatchFailure(forOp,
164 "unsigned loops are not supported");
168 SmallVector<Value> resultVariables;
169 if (
failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
171 return rewriter.notifyMatchFailure(forOp,
172 "create variables for results failed");
174 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
176 emitc::ForOp loweredFor =
177 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
178 adaptor.getUpperBound(), adaptor.getStep());
180 Block *loweredBody = loweredFor.getBody();
185 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
186 rewriter.setInsertionPointToEnd(loweredBody);
188 SmallVector<Value> iterArgsValues =
189 loadValues(resultVariables, rewriter, loc);
191 rewriter.restoreInsertionPoint(ip);
195 if (
failed(rewriter.convertRegionTypes(&forOp.getRegion(),
196 *getTypeConverter(),
nullptr))) {
197 return rewriter.notifyMatchFailure(forOp,
"region types conversion failed");
202 Block *scfBody = &(forOp.getRegion().front());
203 SmallVector<Value> replacingValues;
204 replacingValues.push_back(loweredFor.getInductionVar());
205 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
206 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
208 auto result = lowerYield(forOp, resultVariables, rewriter,
216 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
218 rewriter.replaceOp(forOp, resultValues);
224struct IfLowering :
public OpConversionPattern<IfOp> {
225 using OpConversionPattern<IfOp>::OpConversionPattern;
228 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter)
const override;
235IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const {
237 Location loc = ifOp.getLoc();
241 SmallVector<Value> resultVariables;
242 if (
failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
244 return rewriter.notifyMatchFailure(ifOp,
245 "create variables for results failed");
252 auto lowerRegion = [&resultVariables, &rewriter,
253 &ifOp](Region ®ion, Region &loweredRegion) {
254 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
256 auto result = lowerYield(ifOp, resultVariables, rewriter,
257 cast<scf::YieldOp>(terminator));
264 Region &thenRegion = adaptor.getThenRegion();
265 Region &elseRegion = adaptor.getElseRegion();
267 bool hasElseBlock = !elseRegion.
empty();
270 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(),
false,
false);
272 Region &loweredThenRegion = loweredIf.getThenRegion();
273 auto result = lowerRegion(thenRegion, loweredThenRegion);
279 Region &loweredElseRegion = loweredIf.getElseRegion();
280 auto result = lowerRegion(elseRegion, loweredElseRegion);
286 rewriter.setInsertionPointAfter(ifOp);
287 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
289 rewriter.replaceOp(ifOp, results);
296 using OpConversionPattern::OpConversionPattern;
300 ConversionPatternRewriter &rewriter)
const override;
304 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const {
306 Location loc = indexSwitchOp.getLoc();
311 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
312 rewriter, resultVariables))) {
313 return rewriter.notifyMatchFailure(indexSwitchOp,
314 "create variables for results failed");
318 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
319 adaptor.getCases(), indexSwitchOp.getNumCases());
323 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
324 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
325 *std::get<0>(pair), std::get<1>(pair)))) {
331 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
332 adaptor.getDefaultRegion(),
333 loweredSwitch.getDefaultRegion()))) {
337 rewriter.setInsertionPointAfter(indexSwitchOp);
340 rewriter.replaceOp(indexSwitchOp, results);
348 using OpConversionPattern::OpConversionPattern;
352 ConversionPatternRewriter &rewriter)
const override {
359 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
361 return rewriter.notifyMatchFailure(whileOp,
362 "Failed to create result variables");
367 if (failed(createVariablesForLoopCarriedValues(
368 whileOp, rewriter, loopVariables, loc, context)))
371 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
375 rewriter.setInsertionPointAfter(whileOp);
379 loadValues(resultVariables, rewriter, loc);
380 rewriter.replaceOp(whileOp, finalResults);
388 LogicalResult createVariablesForLoopCarriedValues(
389 WhileOp whileOp, ConversionPatternRewriter &rewriter,
393 rewriter.setInsertionPoint(whileOp);
395 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
397 for (
Value init : whileOp.getInits()) {
398 Type convertedType = getTypeConverter()->convertType(init.getType());
400 return rewriter.notifyMatchFailure(whileOp,
"type conversion failed");
401 if (isa<emitc::ArrayType>(convertedType))
402 return rewriter.notifyMatchFailure(
404 "cannot create variable for loop-carried value of array type");
406 auto var = emitc::VariableOp::create(
407 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
408 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
409 loopVars.push_back(var);
416 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
417 ArrayRef<Value> resultVars, MLIRContext *context,
418 ConversionPatternRewriter &rewriter,
419 Location loc)
const {
421 Type i1Type = IntegerType::get(context, 1);
422 auto globalCondition =
423 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
424 emitc::OpaqueAttr::get(context,
""));
425 Value conditionVal = globalCondition.getResult();
427 auto loweredDo = emitc::DoOp::create(rewriter, loc);
430 if (
failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
431 *getTypeConverter(),
nullptr)) ||
432 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
433 *getTypeConverter(),
nullptr))) {
434 return rewriter.notifyMatchFailure(whileOp,
435 "region types conversion failed");
439 Block *beforeBlock = &whileOp.getBefore().front();
440 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
441 rewriter.setInsertionPointToStart(bodyBlock);
445 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
446 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
448 Operation *condTerminator =
449 loweredDo.getBodyRegion().back().getTerminator();
450 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
451 rewriter.setInsertionPoint(condOp);
454 SmallVector<Value> conditionArgs;
455 for (Value arg : condOp.getArgs()) {
456 conditionArgs.push_back(rewriter.getRemappedValue(arg));
458 assignValues(conditionArgs, resultVars, rewriter, loc);
461 Value condition = rewriter.getRemappedValue(condOp.getCondition());
462 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
466 if (whileOp.getAfterBody()->getOperations().size() > 1) {
467 auto ifOp = emitc::IfOp::create(rewriter, loc, condition,
false,
false);
470 Block *afterBlock = &whileOp.getAfter().front();
471 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
474 SmallVector<Value> afterReplacingValues;
475 for (Value arg : condOp.getArgs())
476 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
478 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
480 if (
failed(lowerYield(whileOp, loopVars, rewriter,
485 rewriter.eraseOp(condOp);
488 Region &condRegion = loweredDo.getConditionRegion();
489 Block *condBlock = rewriter.createBlock(&condRegion);
490 rewriter.setInsertionPointToStart(condBlock);
492 auto exprOp = emitc::ExpressionOp::create(
493 rewriter, loc, i1Type, conditionVal,
false);
494 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
498 rewriter.setInsertionPointToStart(exprBlock);
502 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->
getArgument(0));
503 emitc::YieldOp::create(rewriter, loc, cond);
506 rewriter.setInsertionPointToEnd(condBlock);
507 emitc::YieldOp::create(rewriter, loc, exprOp);
515 patterns.
add<ForLowering>(typeConverter, patterns.
getContext());
516 patterns.
add<IfLowering>(typeConverter, patterns.
getContext());
521void SCFToEmitCPass::runOnOperation() {
530 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
531 target.markUnknownOpDynamicallyLegal([](
Operation *) {
return true; });
533 applyPartialConversion(getOperation(),
target, std::move(patterns))))
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from common builtin types to the EmitC dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
void registerConvertSCFToEmitCInterface(DialectRegistry ®istry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override