24#include "llvm/Support/LogicalResult.h"
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
37struct SCFToEmitCDialectInterface :
public ConvertToEmitCPatternInterface {
38 SCFToEmitCDialectInterface(Dialect *dialect)
39 : ConvertToEmitCPatternInterface(dialect) {}
43 void populateConvertToEmitCConversionPatterns(
44 ConversionTarget &
target, TypeConverter &typeConverter,
45 RewritePatternSet &patterns)
const final {
54 dialect->addInterfaces<SCFToEmitCDialectInterface>();
60struct SCFToEmitCPass :
public impl::SCFToEmitCBase<SCFToEmitCPass> {
61 void runOnOperation()
override;
66struct ForLowering :
public OpConversionPattern<ForOp> {
67 using OpConversionPattern<ForOp>::OpConversionPattern;
70 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter)
const override;
77createVariablesForResults(T op,
const TypeConverter *typeConverter,
78 ConversionPatternRewriter &rewriter,
80 if (!op.getNumResults())
87 rewriter.setInsertionPoint(op);
90 Type resultType = typeConverter->convertType(
result.getType());
92 return rewriter.notifyMatchFailure(op,
"result type conversion failed");
93 if (isa<emitc::ArrayType>(resultType))
94 return rewriter.notifyMatchFailure(
95 op,
"cannot create variable for result of array type");
96 Type varType = emitc::LValueType::get(resultType);
97 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
98 emitc::VariableOp var =
99 emitc::VariableOp::create(rewriter, loc, varType, noInit);
100 resultVariables.push_back(var);
109 ConversionPatternRewriter &rewriter,
Location loc) {
110 for (
auto [value, var] : llvm::zip(values, variables))
111 emitc::AssignOp::create(rewriter, loc, var, value);
116 return llvm::map_to_vector<>(variables, [&](
Value var) {
117 Type type = cast<emitc::LValueType>(var.
getType()).getValueType();
118 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
123 ConversionPatternRewriter &rewriter,
124 scf::YieldOp yield,
bool createYield =
true) {
128 rewriter.setInsertionPoint(yield);
131 if (
failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
132 return rewriter.notifyMatchFailure(op,
"failed to lower yield operands");
134 assignValues(yieldOperands, resultVariables, rewriter, loc);
136 emitc::YieldOp::create(rewriter, loc);
137 rewriter.eraseOp(yield);
148 ConversionPatternRewriter &rewriter,
150 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
152 return lowerYield(op, resultVariables, rewriter,
153 cast<scf::YieldOp>(terminator));
157ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter)
const {
159 Location loc = forOp.getLoc();
161 if (forOp.getUnsignedCmp())
162 return rewriter.notifyMatchFailure(forOp,
163 "unsigned loops are not supported");
167 SmallVector<Value> resultVariables;
168 if (
failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
170 return rewriter.notifyMatchFailure(forOp,
171 "create variables for results failed");
173 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
175 emitc::ForOp loweredFor =
176 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
177 adaptor.getUpperBound(), adaptor.getStep());
179 Block *loweredBody = loweredFor.getBody();
184 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
185 rewriter.setInsertionPointToEnd(loweredBody);
187 SmallVector<Value> iterArgsValues =
188 loadValues(resultVariables, rewriter, loc);
190 rewriter.restoreInsertionPoint(ip);
194 if (
failed(rewriter.convertRegionTypes(&forOp.getRegion(),
195 *getTypeConverter(),
nullptr))) {
196 return rewriter.notifyMatchFailure(forOp,
"region types conversion failed");
201 Block *scfBody = &(forOp.getRegion().front());
202 SmallVector<Value> replacingValues;
203 replacingValues.push_back(loweredFor.getInductionVar());
204 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
205 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
207 auto result = lowerYield(forOp, resultVariables, rewriter,
215 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
217 rewriter.replaceOp(forOp, resultValues);
223struct IfLowering :
public OpConversionPattern<IfOp> {
224 using OpConversionPattern<IfOp>::OpConversionPattern;
227 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter)
const override;
234IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter)
const {
236 Location loc = ifOp.getLoc();
240 SmallVector<Value> resultVariables;
241 if (
failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
243 return rewriter.notifyMatchFailure(ifOp,
244 "create variables for results failed");
251 auto lowerRegion = [&resultVariables, &rewriter,
252 &ifOp](Region ®ion, Region &loweredRegion) {
253 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
255 auto result = lowerYield(ifOp, resultVariables, rewriter,
256 cast<scf::YieldOp>(terminator));
263 Region &thenRegion = adaptor.getThenRegion();
264 Region &elseRegion = adaptor.getElseRegion();
266 bool hasElseBlock = !elseRegion.
empty();
269 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(),
false,
false);
271 Region &loweredThenRegion = loweredIf.getThenRegion();
272 auto result = lowerRegion(thenRegion, loweredThenRegion);
278 Region &loweredElseRegion = loweredIf.getElseRegion();
279 auto result = lowerRegion(elseRegion, loweredElseRegion);
285 rewriter.setInsertionPointAfter(ifOp);
286 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
288 rewriter.replaceOp(ifOp, results);
295 using OpConversionPattern::OpConversionPattern;
299 ConversionPatternRewriter &rewriter)
const override;
303 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter)
const {
305 Location loc = indexSwitchOp.getLoc();
310 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
311 rewriter, resultVariables))) {
312 return rewriter.notifyMatchFailure(indexSwitchOp,
313 "create variables for results failed");
317 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
318 adaptor.getCases(), indexSwitchOp.getNumCases());
322 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
323 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
324 *std::get<0>(pair), std::get<1>(pair)))) {
330 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
331 adaptor.getDefaultRegion(),
332 loweredSwitch.getDefaultRegion()))) {
336 rewriter.setInsertionPointAfter(indexSwitchOp);
339 rewriter.replaceOp(indexSwitchOp, results);
347 using OpConversionPattern::OpConversionPattern;
351 ConversionPatternRewriter &rewriter)
const override {
358 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
360 return rewriter.notifyMatchFailure(whileOp,
361 "Failed to create result variables");
366 if (failed(createVariablesForLoopCarriedValues(
367 whileOp, rewriter, loopVariables, loc, context)))
370 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
374 rewriter.setInsertionPointAfter(whileOp);
378 loadValues(resultVariables, rewriter, loc);
379 rewriter.replaceOp(whileOp, finalResults);
387 LogicalResult createVariablesForLoopCarriedValues(
388 WhileOp whileOp, ConversionPatternRewriter &rewriter,
392 rewriter.setInsertionPoint(whileOp);
394 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
396 for (
Value init : whileOp.getInits()) {
397 Type convertedType = getTypeConverter()->convertType(init.getType());
399 return rewriter.notifyMatchFailure(whileOp,
"type conversion failed");
400 if (isa<emitc::ArrayType>(convertedType))
401 return rewriter.notifyMatchFailure(
403 "cannot create variable for loop-carried value of array type");
405 auto var = emitc::VariableOp::create(
406 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
407 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
408 loopVars.push_back(var);
415 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
416 ArrayRef<Value> resultVars, MLIRContext *context,
417 ConversionPatternRewriter &rewriter,
418 Location loc)
const {
420 Type i1Type = IntegerType::get(context, 1);
421 auto globalCondition =
422 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
423 emitc::OpaqueAttr::get(context,
""));
424 Value conditionVal = globalCondition.getResult();
426 auto loweredDo = emitc::DoOp::create(rewriter, loc);
429 if (
failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
430 *getTypeConverter(),
nullptr)) ||
431 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
432 *getTypeConverter(),
nullptr))) {
433 return rewriter.notifyMatchFailure(whileOp,
434 "region types conversion failed");
438 Block *beforeBlock = &whileOp.getBefore().front();
439 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
440 rewriter.setInsertionPointToStart(bodyBlock);
444 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
445 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
447 Operation *condTerminator =
448 loweredDo.getBodyRegion().back().getTerminator();
449 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
450 rewriter.setInsertionPoint(condOp);
453 SmallVector<Value> conditionArgs;
454 for (Value arg : condOp.getArgs()) {
455 conditionArgs.push_back(rewriter.getRemappedValue(arg));
457 assignValues(conditionArgs, resultVars, rewriter, loc);
460 Value condition = rewriter.getRemappedValue(condOp.getCondition());
461 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
465 if (whileOp.getAfterBody()->getOperations().size() > 1) {
466 auto ifOp = emitc::IfOp::create(rewriter, loc, condition,
false,
false);
469 Block *afterBlock = &whileOp.getAfter().front();
470 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
473 SmallVector<Value> afterReplacingValues;
474 for (Value arg : condOp.getArgs())
475 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
477 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
479 if (
failed(lowerYield(whileOp, loopVars, rewriter,
484 rewriter.eraseOp(condOp);
487 Region &condRegion = loweredDo.getConditionRegion();
488 Block *condBlock = rewriter.createBlock(&condRegion);
489 rewriter.setInsertionPointToStart(condBlock);
491 auto exprOp = emitc::ExpressionOp::create(
492 rewriter, loc, i1Type, conditionVal,
false);
493 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
497 rewriter.setInsertionPointToStart(exprBlock);
501 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->
getArgument(0));
502 emitc::YieldOp::create(rewriter, loc, cond);
505 rewriter.setInsertionPointToEnd(condBlock);
506 emitc::YieldOp::create(rewriter, loc, exprOp);
514 patterns.
add<ForLowering>(typeConverter, patterns.
getContext());
515 patterns.
add<IfLowering>(typeConverter, patterns.
getContext());
520void SCFToEmitCPass::runOnOperation() {
524 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
535 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
536 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
538 applyPartialConversion(getOperation(),
target, std::move(patterns))))
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
void registerConvertSCFToEmitCInterface(DialectRegistry ®istry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override