28 #define GEN_PASS_DEF_SCFTOEMITC
29 #include "mlir/Conversion/Passes.h.inc"
37 struct SCFToEmitCPass :
public impl::SCFToEmitCBase<SCFToEmitCPass> {
38 void runOnOperation()
override;
47 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
54 createVariablesForResults(T op,
const TypeConverter *typeConverter,
57 if (!op.getNumResults())
66 for (
OpResult result : op.getResults()) {
72 emitc::VariableOp var =
73 rewriter.
create<emitc::VariableOp>(loc, varType, noInit);
74 resultVariables.push_back(var);
84 for (
auto [value, var] : llvm::zip(values, variables))
85 rewriter.
create<emitc::AssignOp>(loc, var, value);
90 return llvm::map_to_vector<>(variables, [&](
Value var) {
91 Type type = cast<emitc::LValueType>(var.
getType()).getValueType();
92 return rewriter.
create<emitc::LoadOp>(loc, type, var).getResult();
109 assignValues(yieldOperands, resultVariables, rewriter, loc);
111 rewriter.
create<emitc::YieldOp>(loc);
127 return lowerYield(op, resultVariables, rewriter,
128 cast<scf::YieldOp>(terminator));
132 ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
139 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
142 "create variables for results failed");
144 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
146 emitc::ForOp loweredFor = rewriter.
create<emitc::ForOp>(
147 loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
149 Block *loweredBody = loweredFor.getBody();
152 rewriter.
eraseOp(loweredBody->getTerminator());
158 loadValues(resultVariables, rewriter, loc);
165 *getTypeConverter(),
nullptr))) {
171 Block *scfBody = &(forOp.getRegion().front());
173 replacingValues.push_back(loweredFor.getInductionVar());
174 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
175 rewriter.
mergeBlocks(scfBody, loweredBody, replacingValues);
177 auto result = lowerYield(forOp, resultVariables, rewriter,
178 cast<scf::YieldOp>(loweredBody->getTerminator()));
180 if (failed(result)) {
197 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
204 IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
211 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
214 "create variables for results failed");
221 auto lowerRegion = [&resultVariables, &rewriter,
223 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.
end());
225 auto result = lowerYield(ifOp, resultVariables, rewriter,
226 cast<scf::YieldOp>(terminator));
227 if (failed(result)) {
233 Region &thenRegion = adaptor.getThenRegion();
234 Region &elseRegion = adaptor.getElseRegion();
236 bool hasElseBlock = !elseRegion.
empty();
239 rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(),
false,
false);
241 Region &loweredThenRegion = loweredIf.getThenRegion();
242 auto result = lowerRegion(thenRegion, loweredThenRegion);
243 if (failed(result)) {
248 Region &loweredElseRegion = loweredIf.getElseRegion();
249 auto result = lowerRegion(elseRegion, loweredElseRegion);
250 if (failed(result)) {
255 rewriter.setInsertionPointAfter(ifOp);
258 rewriter.replaceOp(ifOp, results);
268 matchAndRewrite(IndexSwitchOp indexSwitchOp,
OpAdaptor adaptor,
273 IndexSwitchOp indexSwitchOp,
OpAdaptor adaptor,
275 Location loc = indexSwitchOp.getLoc();
280 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
281 rewriter, resultVariables))) {
283 "create variables for results failed");
286 auto loweredSwitch = rewriter.
create<emitc::SwitchOp>(
287 loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
291 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
292 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
293 *std::get<0>(pair), std::get<1>(pair)))) {
299 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
300 adaptor.getDefaultRegion(),
301 loweredSwitch.getDefaultRegion()))) {
308 rewriter.
replaceOp(indexSwitchOp, results);
319 void SCFToEmitCPass::runOnOperation() {
331 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
332 target.markUnknownOpDynamicallyLegal([](
Operation *) {
return true; });
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the currently executing pattern.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.