24#include "llvm/Support/LogicalResult.h"
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
42 void populateConvertToEmitCConversionPatterns(
43 ConversionTarget &
target, TypeConverter &typeConverter,
44 RewritePatternSet &
patterns)
const final {
53 dialect->addInterfaces<SCFToEmitCDialectInterface>();
60 void runOnOperation()
override;
65struct ForLowering :
public OpConversionPattern<ForOp> {
66 using OpConversionPattern<ForOp>::OpConversionPattern;
69 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter)
const override;
76createVariablesForResults(T op,
const TypeConverter *typeConverter,
77 ConversionPatternRewriter &rewriter,
79 if (!op.getNumResults())
86 rewriter.setInsertionPoint(op);
89 Type resultType = typeConverter->convertType(
result.getType());
91 return rewriter.notifyMatchFailure(op,
"result type conversion failed");
92 Type varType = emitc::LValueType::get(resultType);
93 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
94 emitc::VariableOp var =
95 emitc::VariableOp::create(rewriter, loc, varType, noInit);
96 resultVariables.push_back(var);
105 ConversionPatternRewriter &rewriter,
Location loc) {
106 for (
auto [value, var] : llvm::zip(values, variables))
107 emitc::AssignOp::create(rewriter, loc, var, value);
112 return llvm::map_to_vector<>(variables, [&](
Value var) {
113 Type type = cast<emitc::LValueType>(var.
getType()).getValueType();
114 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
119 ConversionPatternRewriter &rewriter,
120 scf::YieldOp yield,
bool createYield =
true) {
124 rewriter.setInsertionPoint(yield);
127 if (
failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
128 return rewriter.notifyMatchFailure(op,
"failed to lower yield operands");
130 assignValues(yieldOperands, resultVariables, rewriter, loc);
132 emitc::YieldOp::create(rewriter, loc);
133 rewriter.eraseOp(yield);
144 ConversionPatternRewriter &rewriter,
146 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
148 return lowerYield(op, resultVariables, rewriter,
149 cast<scf::YieldOp>(terminator));
153ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter)
const {
155 Location loc = forOp.getLoc();
157 if (forOp.getUnsignedCmp())
158 return rewriter.notifyMatchFailure(forOp,
159 "unsigned loops are not supported");
163 SmallVector<Value> resultVariables;
164 if (
failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
166 return rewriter.notifyMatchFailure(forOp,
167 "create variables for results failed");
169 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
171 emitc::ForOp loweredFor =
172 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
173 adaptor.getUpperBound(), adaptor.getStep());
175 Block *loweredBody = loweredFor.getBody();
180 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
181 rewriter.setInsertionPointToEnd(loweredBody);
183 SmallVector<Value> iterArgsValues =
184 loadValues(resultVariables, rewriter, loc);
186 rewriter.restoreInsertionPoint(ip);
190 if (
failed(rewriter.convertRegionTypes(&forOp.getRegion(),
191 *getTypeConverter(),
nullptr))) {
192 return rewriter.notifyMatchFailure(forOp,
"region types conversion failed");
197 Block *scfBody = &(forOp.getRegion().front());
198 SmallVector<Value> replacingValues;
199 replacingValues.push_back(loweredFor.getInductionVar());
200 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
201 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
203 auto result = lowerYield(forOp, resultVariables, rewriter,
211 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
213 rewriter.replaceOp(forOp, resultValues);
219struct IfLowering :
public OpConversionPattern<IfOp> {
220 using OpConversionPattern<IfOp>::OpConversionPattern;
223 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
224 ConversionPatternRewriter &rewriter)
const override;
230IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter)
const {
232 Location loc = ifOp.getLoc();
236 SmallVector<Value> resultVariables;
237 if (
failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
239 return rewriter.notifyMatchFailure(ifOp,
240 "create variables for results failed");
247 auto lowerRegion = [&resultVariables, &rewriter,
248 &ifOp](Region ®ion, Region &loweredRegion) {
249 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
251 auto result = lowerYield(ifOp, resultVariables, rewriter,
252 cast<scf::YieldOp>(terminator));
259 Region &thenRegion = adaptor.getThenRegion();
260 Region &elseRegion = adaptor.getElseRegion();
262 bool hasElseBlock = !elseRegion.
empty();
265 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(),
false,
false);
267 Region &loweredThenRegion = loweredIf.getThenRegion();
268 auto result = lowerRegion(thenRegion, loweredThenRegion);
274 Region &loweredElseRegion = loweredIf.getElseRegion();
275 auto result = lowerRegion(elseRegion, loweredElseRegion);
281 rewriter.setInsertionPointAfter(ifOp);
282 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
284 rewriter.replaceOp(ifOp, results);
291 using OpConversionPattern::OpConversionPattern;
295 ConversionPatternRewriter &rewriter)
const override;
299 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
300 ConversionPatternRewriter &rewriter)
const {
301 Location loc = indexSwitchOp.getLoc();
306 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
307 rewriter, resultVariables))) {
308 return rewriter.notifyMatchFailure(indexSwitchOp,
309 "create variables for results failed");
313 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
314 adaptor.getCases(), indexSwitchOp.getNumCases());
318 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
319 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
320 *std::get<0>(pair), std::get<1>(pair)))) {
326 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
327 adaptor.getDefaultRegion(),
328 loweredSwitch.getDefaultRegion()))) {
332 rewriter.setInsertionPointAfter(indexSwitchOp);
335 rewriter.replaceOp(indexSwitchOp, results);
343 using OpConversionPattern::OpConversionPattern;
347 ConversionPatternRewriter &rewriter)
const override {
354 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
356 return rewriter.notifyMatchFailure(whileOp,
357 "Failed to create result variables");
362 if (failed(createVariablesForLoopCarriedValues(
363 whileOp, rewriter, loopVariables, loc, context)))
366 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
370 rewriter.setInsertionPointAfter(whileOp);
374 loadValues(resultVariables, rewriter, loc);
375 rewriter.replaceOp(whileOp, finalResults);
383 LogicalResult createVariablesForLoopCarriedValues(
384 WhileOp whileOp, ConversionPatternRewriter &rewriter,
388 rewriter.setInsertionPoint(whileOp);
390 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context,
"");
392 for (
Value init : whileOp.getInits()) {
393 Type convertedType = getTypeConverter()->convertType(init.getType());
395 return rewriter.notifyMatchFailure(whileOp,
"type conversion failed");
397 auto var = emitc::VariableOp::create(
398 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
399 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
400 loopVars.push_back(var);
407 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
408 ArrayRef<Value> resultVars, MLIRContext *context,
409 ConversionPatternRewriter &rewriter,
410 Location loc)
const {
412 Type i1Type = IntegerType::get(context, 1);
413 auto globalCondition =
414 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
415 emitc::OpaqueAttr::get(context,
""));
416 Value conditionVal = globalCondition.getResult();
418 auto loweredDo = emitc::DoOp::create(rewriter, loc);
421 if (
failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
422 *getTypeConverter(),
nullptr)) ||
423 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
424 *getTypeConverter(),
nullptr))) {
425 return rewriter.notifyMatchFailure(whileOp,
426 "region types conversion failed");
430 Block *beforeBlock = &whileOp.getBefore().front();
431 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
432 rewriter.setInsertionPointToStart(bodyBlock);
436 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
437 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
439 Operation *condTerminator =
440 loweredDo.getBodyRegion().back().getTerminator();
441 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
442 rewriter.setInsertionPoint(condOp);
445 SmallVector<Value> conditionArgs;
446 for (Value arg : condOp.getArgs()) {
447 conditionArgs.push_back(rewriter.getRemappedValue(arg));
449 assignValues(conditionArgs, resultVars, rewriter, loc);
452 Value condition = rewriter.getRemappedValue(condOp.getCondition());
453 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
457 if (whileOp.getAfterBody()->getOperations().size() > 1) {
458 auto ifOp = emitc::IfOp::create(rewriter, loc, condition,
false,
false);
461 Block *afterBlock = &whileOp.getAfter().front();
462 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
465 SmallVector<Value> afterReplacingValues;
466 for (Value arg : condOp.getArgs())
467 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
469 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
471 if (
failed(lowerYield(whileOp, loopVars, rewriter,
476 rewriter.eraseOp(condOp);
479 Region &condRegion = loweredDo.getConditionRegion();
480 Block *condBlock = rewriter.createBlock(&condRegion);
481 rewriter.setInsertionPointToStart(condBlock);
483 auto exprOp = emitc::ExpressionOp::create(
484 rewriter, loc, i1Type, conditionVal,
false);
485 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
489 rewriter.setInsertionPointToStart(exprBlock);
493 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->
getArgument(0));
494 emitc::YieldOp::create(rewriter, loc, cond);
497 rewriter.setInsertionPointToEnd(condBlock);
498 emitc::YieldOp::create(rewriter, loc, exprOp);
512void SCFToEmitCPass::runOnOperation() {
516 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
527 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
528 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
530 applyPartialConversion(getOperation(),
target, std::move(
patterns))))
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
const FrozenRewritePatternSet & patterns
void registerConvertSCFToEmitCInterface(DialectRegistry ®istry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override