24#include "llvm/Support/LogicalResult.h"
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
42 void populateConvertToEmitCConversionPatterns(
43 ConversionTarget &
target, TypeConverter &typeConverter,
44 RewritePatternSet &patterns)
const final {
53 dialect->addInterfaces<SCFToEmitCDialectInterface>();
59struct SCFToEmitCPass :
public impl::SCFToEmitCBase<SCFToEmitCPass> {
60 void runOnOperation()
override;
65struct ForLowering :
public OpConversionPattern<ForOp> {
66 using OpConversionPattern<ForOp>::OpConversionPattern;
69 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter)
const override;
76createVariablesForResults(T op,
const TypeConverter *typeConverter,
77 ConversionPatternRewriter &rewriter,
79 if (!op.getNumResults())
86 rewriter.setInsertionPoint(op);
89 Type resultType = typeConverter->convertType(
result.getType());
91 return rewriter.notifyMatchFailure(op,
"result type conversion failed");
92 if (isa<emitc::ArrayType>(resultType))
93 return rewriter.notifyMatchFailure(
94 op,
"cannot create variable for result of array type");
95 Type varType = emitc::LValueType::get(resultType);
96 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
97 emitc::VariableOp var =
98 emitc::VariableOp::create(rewriter, loc, varType, noInit);
99 resultVariables.push_back(var);
108 ConversionPatternRewriter &rewriter,
Location loc) {
109 for (
auto [value, var] : llvm::zip(values, variables))
110 emitc::AssignOp::create(rewriter, loc, var, value);
115 return llvm::map_to_vector<>(variables, [&](
Value var) {
116 Type type = cast<emitc::LValueType>(var.
getType()).getValueType();
117 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
122 ConversionPatternRewriter &rewriter,
123 scf::YieldOp yield,
bool createYield =
true) {
127 rewriter.setInsertionPoint(yield);
130 if (
failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
131 return rewriter.notifyMatchFailure(op,
"failed to lower yield operands");
133 assignValues(yieldOperands, resultVariables, rewriter, loc);
135 emitc::YieldOp::create(rewriter, loc);
136 rewriter.eraseOp(yield);
147 ConversionPatternRewriter &rewriter,
149 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
151 return lowerYield(op, resultVariables, rewriter,
152 cast<scf::YieldOp>(terminator));
156ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter)
const {
158 Location loc = forOp.getLoc();
160 if (forOp.getUnsignedCmp())
161 return rewriter.notifyMatchFailure(forOp,
162 "unsigned loops are not supported");
166 SmallVector<Value> resultVariables;
167 if (
failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
169 return rewriter.notifyMatchFailure(forOp,
170 "create variables for results failed");
172 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
174 emitc::ForOp loweredFor =
175 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
176 adaptor.getUpperBound(), adaptor.getStep());
178 Block *loweredBody = loweredFor.getBody();
183 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
184 rewriter.setInsertionPointToEnd(loweredBody);
186 SmallVector<Value> iterArgsValues =
187 loadValues(resultVariables, rewriter, loc);
189 rewriter.restoreInsertionPoint(ip);
193 if (
failed(rewriter.convertRegionTypes(&forOp.getRegion(),
194 *getTypeConverter(),
nullptr))) {
195 return rewriter.notifyMatchFailure(forOp,
"region types conversion failed");
200 Block *scfBody = &(forOp.getRegion().front());
201 SmallVector<Value> replacingValues;
202 replacingValues.push_back(loweredFor.getInductionVar());
203 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
204 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
206 auto result = lowerYield(forOp, resultVariables, rewriter,
214 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
216 rewriter.replaceOp(forOp, resultValues);
222struct IfLowering :
public OpConversionPattern<IfOp> {
223 using OpConversionPattern<IfOp>::OpConversionPattern;
226 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter)
const override;
233IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter)
const {
235 Location loc = ifOp.getLoc();
239 SmallVector<Value> resultVariables;
240 if (
failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
242 return rewriter.notifyMatchFailure(ifOp,
243 "create variables for results failed");
250 auto lowerRegion = [&resultVariables, &rewriter,
251 &ifOp](Region ®ion, Region &loweredRegion) {
252 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
254 auto result = lowerYield(ifOp, resultVariables, rewriter,
255 cast<scf::YieldOp>(terminator));
262 Region &thenRegion = adaptor.getThenRegion();
263 Region &elseRegion = adaptor.getElseRegion();
265 bool hasElseBlock = !elseRegion.
empty();
268 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(),
false,
false);
270 Region &loweredThenRegion = loweredIf.getThenRegion();
271 auto result = lowerRegion(thenRegion, loweredThenRegion);
277 Region &loweredElseRegion = loweredIf.getElseRegion();
278 auto result = lowerRegion(elseRegion, loweredElseRegion);
284 rewriter.setInsertionPointAfter(ifOp);
285 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
287 rewriter.replaceOp(ifOp, results);
294 using OpConversionPattern::OpConversionPattern;
298 ConversionPatternRewriter &rewriter)
const override;
302 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const {
304 Location loc = indexSwitchOp.getLoc();
309 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
310 rewriter, resultVariables))) {
311 return rewriter.notifyMatchFailure(indexSwitchOp,
312 "create variables for results failed");
316 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
317 adaptor.getCases(), indexSwitchOp.getNumCases());
321 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
322 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
323 *std::get<0>(pair), std::get<1>(pair)))) {
329 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
330 adaptor.getDefaultRegion(),
331 loweredSwitch.getDefaultRegion()))) {
335 rewriter.setInsertionPointAfter(indexSwitchOp);
338 rewriter.replaceOp(indexSwitchOp, results);
346 using OpConversionPattern::OpConversionPattern;
350 ConversionPatternRewriter &rewriter)
const override {
357 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
359 return rewriter.notifyMatchFailure(whileOp,
360 "Failed to create result variables");
365 if (failed(createVariablesForLoopCarriedValues(
366 whileOp, rewriter, loopVariables, loc, context)))
369 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
373 rewriter.setInsertionPointAfter(whileOp);
377 loadValues(resultVariables, rewriter, loc);
378 rewriter.replaceOp(whileOp, finalResults);
386 LogicalResult createVariablesForLoopCarriedValues(
387 WhileOp whileOp, ConversionPatternRewriter &rewriter,
391 rewriter.setInsertionPoint(whileOp);
393 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
395 for (
Value init : whileOp.getInits()) {
396 Type convertedType = getTypeConverter()->convertType(init.getType());
398 return rewriter.notifyMatchFailure(whileOp,
"type conversion failed");
399 if (isa<emitc::ArrayType>(convertedType))
400 return rewriter.notifyMatchFailure(
402 "cannot create variable for loop-carried value of array type");
404 auto var = emitc::VariableOp::create(
405 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
406 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
407 loopVars.push_back(var);
414 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
415 ArrayRef<Value> resultVars, MLIRContext *context,
416 ConversionPatternRewriter &rewriter,
417 Location loc)
const {
419 Type i1Type = IntegerType::get(context, 1);
420 auto globalCondition =
421 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
422 emitc::OpaqueAttr::get(context,
""));
423 Value conditionVal = globalCondition.getResult();
425 auto loweredDo = emitc::DoOp::create(rewriter, loc);
428 if (
failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
429 *getTypeConverter(),
nullptr)) ||
430 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
431 *getTypeConverter(),
nullptr))) {
432 return rewriter.notifyMatchFailure(whileOp,
433 "region types conversion failed");
437 Block *beforeBlock = &whileOp.getBefore().front();
438 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
439 rewriter.setInsertionPointToStart(bodyBlock);
443 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
444 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
446 Operation *condTerminator =
447 loweredDo.getBodyRegion().back().getTerminator();
448 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
449 rewriter.setInsertionPoint(condOp);
452 SmallVector<Value> conditionArgs;
453 for (Value arg : condOp.getArgs()) {
454 conditionArgs.push_back(rewriter.getRemappedValue(arg));
456 assignValues(conditionArgs, resultVars, rewriter, loc);
459 Value condition = rewriter.getRemappedValue(condOp.getCondition());
460 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
464 if (whileOp.getAfterBody()->getOperations().size() > 1) {
465 auto ifOp = emitc::IfOp::create(rewriter, loc, condition,
false,
false);
468 Block *afterBlock = &whileOp.getAfter().front();
469 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
472 SmallVector<Value> afterReplacingValues;
473 for (Value arg : condOp.getArgs())
474 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
476 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
478 if (
failed(lowerYield(whileOp, loopVars, rewriter,
483 rewriter.eraseOp(condOp);
486 Region &condRegion = loweredDo.getConditionRegion();
487 Block *condBlock = rewriter.createBlock(&condRegion);
488 rewriter.setInsertionPointToStart(condBlock);
490 auto exprOp = emitc::ExpressionOp::create(
491 rewriter, loc, i1Type, conditionVal,
false);
492 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
496 rewriter.setInsertionPointToStart(exprBlock);
500 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->
getArgument(0));
501 emitc::YieldOp::create(rewriter, loc, cond);
504 rewriter.setInsertionPointToEnd(condBlock);
505 emitc::YieldOp::create(rewriter, loc, exprOp);
513 patterns.
add<ForLowering>(typeConverter, patterns.
getContext());
514 patterns.
add<IfLowering>(typeConverter, patterns.
getContext());
519void SCFToEmitCPass::runOnOperation() {
523 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
534 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
535 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
537 applyPartialConversion(getOperation(),
target, std::move(patterns))))
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
void registerConvertSCFToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override