MLIR 23.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert SCF control-flow ops into EmitC ops.
10//
11//===----------------------------------------------------------------------===//
12
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/MLIRContext.h"
24#include "llvm/Support/LogicalResult.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::scf;
33
34namespace {
35
36/// Implement the interface to convert SCF to EmitC.
37struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
38 SCFToEmitCDialectInterface(Dialect *dialect)
39 : ConvertToEmitCPatternInterface(dialect) {}
40
41 /// Hook for derived dialect interface to provide conversion patterns
42 /// and mark dialect legal for the conversion target.
43 void populateConvertToEmitCConversionPatterns(
44 ConversionTarget &target, TypeConverter &typeConverter,
45 RewritePatternSet &patterns) const final {
47 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
48 }
49};
50} // namespace
51
53 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
54 dialect->addInterfaces<SCFToEmitCDialectInterface>();
55 });
56}
57
58namespace {
59
60struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
61 void runOnOperation() override;
62};
63
64// Lower scf::for to emitc::for, implementing result values using
65// emitc::variable's updated within the loop body.
66struct ForLowering : public OpConversionPattern<ForOp> {
67 using OpConversionPattern<ForOp>::OpConversionPattern;
68
69 LogicalResult
70 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter) const override;
72};
73
74// Create an uninitialized emitc::variable op for each result of the given op.
75template <typename T>
76static LogicalResult
77createVariablesForResults(T op, const TypeConverter *typeConverter,
78 ConversionPatternRewriter &rewriter,
79 SmallVector<Value> &resultVariables) {
80 if (!op.getNumResults())
81 return success();
82
83 Location loc = op->getLoc();
84 MLIRContext *context = op.getContext();
85
86 OpBuilder::InsertionGuard guard(rewriter);
87 rewriter.setInsertionPoint(op);
88
89 for (OpResult result : op.getResults()) {
90 Type resultType = typeConverter->convertType(result.getType());
91 if (!resultType)
92 return rewriter.notifyMatchFailure(op, "result type conversion failed");
93 if (isa<emitc::ArrayType>(resultType))
94 return rewriter.notifyMatchFailure(
95 op, "cannot create variable for result of array type");
96 Type varType = emitc::LValueType::get(resultType);
97 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
98 emitc::VariableOp var =
99 emitc::VariableOp::create(rewriter, loc, varType, noInit);
100 resultVariables.push_back(var);
101 }
102
103 return success();
104}
105
106// Create a series of assign ops assigning given values to given variables at
107// the current insertion point of given rewriter.
108static void assignValues(ValueRange values, ValueRange variables,
109 ConversionPatternRewriter &rewriter, Location loc) {
110 for (auto [value, var] : llvm::zip(values, variables))
111 emitc::AssignOp::create(rewriter, loc, var, value);
112}
113
114SmallVector<Value> loadValues(ArrayRef<Value> variables,
115 PatternRewriter &rewriter, Location loc) {
116 return llvm::map_to_vector<>(variables, [&](Value var) {
117 Type type = cast<emitc::LValueType>(var.getType()).getValueType();
118 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
119 });
120}
121
122static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
123 ConversionPatternRewriter &rewriter,
124 scf::YieldOp yield, bool createYield = true) {
125 Location loc = yield.getLoc();
126
127 OpBuilder::InsertionGuard guard(rewriter);
128 rewriter.setInsertionPoint(yield);
129
130 SmallVector<Value> yieldOperands;
131 if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
132 return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
133
134 assignValues(yieldOperands, resultVariables, rewriter, loc);
135
136 emitc::YieldOp::create(rewriter, loc);
137 rewriter.eraseOp(yield);
138
139 return success();
140}
141
142// Lower the contents of an scf::if/scf::index_switch regions to an
143// emitc::if/emitc::switch region. The contents of the lowering region is
144// moved into the respective lowered region, but the scf::yield is replaced not
145// only with an emitc::yield, but also with a sequence of emitc::assign ops that
146// set the yielded values into the result variables.
147static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
148 ConversionPatternRewriter &rewriter,
149 Region &region, Region &loweredRegion) {
150 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
151 Operation *terminator = loweredRegion.back().getTerminator();
152 return lowerYield(op, resultVariables, rewriter,
153 cast<scf::YieldOp>(terminator));
154}
155
156LogicalResult
157ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter) const {
159 Location loc = forOp.getLoc();
160
161 if (forOp.getUnsignedCmp())
162 return rewriter.notifyMatchFailure(forOp,
163 "unsigned loops are not supported");
164
165 // Create an emitc::variable op for each result. These variables will be
166 // assigned to by emitc::assign ops within the loop body.
167 SmallVector<Value> resultVariables;
168 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
169 resultVariables)))
170 return rewriter.notifyMatchFailure(forOp,
171 "create variables for results failed");
172
173 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
174
175 emitc::ForOp loweredFor =
176 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
177 adaptor.getUpperBound(), adaptor.getStep());
178
179 Block *loweredBody = loweredFor.getBody();
180
181 // Erase the auto-generated terminator for the lowered for op.
182 rewriter.eraseOp(loweredBody->getTerminator());
183
184 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
185 rewriter.setInsertionPointToEnd(loweredBody);
186
187 SmallVector<Value> iterArgsValues =
188 loadValues(resultVariables, rewriter, loc);
189
190 rewriter.restoreInsertionPoint(ip);
191
192 // Convert the original region types into the new types by adding unrealized
193 // casts in the beginning of the loop. This performs the conversion in place.
194 if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
195 *getTypeConverter(), nullptr))) {
196 return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
197 }
198
199 // Register the replacements for the block arguments and inline the body of
200 // the scf.for loop into the body of the emitc::for loop.
201 Block *scfBody = &(forOp.getRegion().front());
202 SmallVector<Value> replacingValues;
203 replacingValues.push_back(loweredFor.getInductionVar());
204 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
205 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
206
207 auto result = lowerYield(forOp, resultVariables, rewriter,
208 cast<scf::YieldOp>(loweredBody->getTerminator()));
209
210 if (failed(result)) {
211 return result;
212 }
213
214 // Load variables into SSA values after the for loop.
215 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
216
217 rewriter.replaceOp(forOp, resultValues);
218 return success();
219}
220
221// Lower scf::if to emitc::if, implementing result values as emitc::variable's
222// updated within the then and else regions.
223struct IfLowering : public OpConversionPattern<IfOp> {
224 using OpConversionPattern<IfOp>::OpConversionPattern;
225
226 LogicalResult
227 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override;
229};
230
231} // namespace
232
233LogicalResult
234IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter) const {
236 Location loc = ifOp.getLoc();
237
238 // Create an emitc::variable op for each result. These variables will be
239 // assigned to by emitc::assign ops within the then & else regions.
240 SmallVector<Value> resultVariables;
241 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
242 resultVariables)))
243 return rewriter.notifyMatchFailure(ifOp,
244 "create variables for results failed");
245
246 // Utility function to lower the contents of an scf::if region to an emitc::if
247 // region. The contents of the scf::if regions is moved into the respective
248 // emitc::if regions, but the scf::yield is replaced not only with an
249 // emitc::yield, but also with a sequence of emitc::assign ops that set the
250 // yielded values into the result variables.
251 auto lowerRegion = [&resultVariables, &rewriter,
252 &ifOp](Region &region, Region &loweredRegion) {
253 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
254 Operation *terminator = loweredRegion.back().getTerminator();
255 auto result = lowerYield(ifOp, resultVariables, rewriter,
256 cast<scf::YieldOp>(terminator));
257 if (failed(result)) {
258 return result;
259 }
260 return success();
261 };
262
263 Region &thenRegion = adaptor.getThenRegion();
264 Region &elseRegion = adaptor.getElseRegion();
265
266 bool hasElseBlock = !elseRegion.empty();
267
268 auto loweredIf =
269 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false);
270
271 Region &loweredThenRegion = loweredIf.getThenRegion();
272 auto result = lowerRegion(thenRegion, loweredThenRegion);
273 if (failed(result)) {
274 return result;
275 }
276
277 if (hasElseBlock) {
278 Region &loweredElseRegion = loweredIf.getElseRegion();
279 auto result = lowerRegion(elseRegion, loweredElseRegion);
280 if (failed(result)) {
281 return result;
282 }
283 }
284
285 rewriter.setInsertionPointAfter(ifOp);
286 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
287
288 rewriter.replaceOp(ifOp, results);
289 return success();
290}
291
292// Lower scf::index_switch to emitc::switch, implementing result values as
293// emitc::variable's updated within the case and default regions.
294struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
295 using OpConversionPattern::OpConversionPattern;
296
297 LogicalResult
298 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter) const override;
300};
301
303 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter) const {
305 Location loc = indexSwitchOp.getLoc();
306
307 // Create an emitc::variable op for each result. These variables will be
308 // assigned to by emitc::assign ops within the case and default regions.
309 SmallVector<Value> resultVariables;
310 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
311 rewriter, resultVariables))) {
312 return rewriter.notifyMatchFailure(indexSwitchOp,
313 "create variables for results failed");
314 }
315
316 auto loweredSwitch =
317 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
318 adaptor.getCases(), indexSwitchOp.getNumCases());
319
320 // Lowering all case regions.
321 for (auto pair :
322 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
323 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
324 *std::get<0>(pair), std::get<1>(pair)))) {
325 return failure();
326 }
327 }
328
329 // Lowering default region.
330 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
331 adaptor.getDefaultRegion(),
332 loweredSwitch.getDefaultRegion()))) {
333 return failure();
334 }
335
336 rewriter.setInsertionPointAfter(indexSwitchOp);
337 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
338
339 rewriter.replaceOp(indexSwitchOp, results);
340 return success();
341}
342
343// Lower scf::while to emitc::do using mutable variables to maintain loop state
344// across iterations. The do-while structure ensures the condition is evaluated
345// after each iteration, matching SCF while semantics.
346struct WhileLowering : public OpConversionPattern<WhileOp> {
347 using OpConversionPattern::OpConversionPattern;
348
349 LogicalResult
350 matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
351 ConversionPatternRewriter &rewriter) const override {
352 Location loc = whileOp.getLoc();
353 MLIRContext *context = loc.getContext();
354
355 // Create an emitc::variable op for each result. These variables will be
356 // assigned to by emitc::assign ops within the loop body.
357 SmallVector<Value> resultVariables;
358 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
359 resultVariables)))
360 return rewriter.notifyMatchFailure(whileOp,
361 "Failed to create result variables");
362
363 // Create variable storage for loop-carried values to enable imperative
364 // updates while maintaining SSA semantics at conversion boundaries.
365 SmallVector<Value> loopVariables;
366 if (failed(createVariablesForLoopCarriedValues(
367 whileOp, rewriter, loopVariables, loc, context)))
368 return failure();
369
370 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
371 rewriter, loc)))
372 return failure();
373
374 rewriter.setInsertionPointAfter(whileOp);
375
376 // Load the final result values from result variables.
377 SmallVector<Value> finalResults =
378 loadValues(resultVariables, rewriter, loc);
379 rewriter.replaceOp(whileOp, finalResults);
380
381 return success();
382 }
383
384private:
385 // Initialize variables for loop-carried values to enable state updates
386 // across iterations without SSA argument passing.
387 LogicalResult createVariablesForLoopCarriedValues(
388 WhileOp whileOp, ConversionPatternRewriter &rewriter,
389 SmallVectorImpl<Value> &loopVars, Location loc,
390 MLIRContext *context) const {
391 OpBuilder::InsertionGuard guard(rewriter);
392 rewriter.setInsertionPoint(whileOp);
393
394 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
395
396 for (Value init : whileOp.getInits()) {
397 Type convertedType = getTypeConverter()->convertType(init.getType());
398 if (!convertedType)
399 return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
400 if (isa<emitc::ArrayType>(convertedType))
401 return rewriter.notifyMatchFailure(
402 whileOp,
403 "cannot create variable for loop-carried value of array type");
404
405 auto var = emitc::VariableOp::create(
406 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
407 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
408 loopVars.push_back(var);
409 }
410
411 return success();
412 }
413
414 // Lower scf.while to emitc.do.
415 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
416 ArrayRef<Value> resultVars, MLIRContext *context,
417 ConversionPatternRewriter &rewriter,
418 Location loc) const {
419 // Create a global boolean variable to store the loop condition state.
420 Type i1Type = IntegerType::get(context, 1);
421 auto globalCondition =
422 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
423 emitc::OpaqueAttr::get(context, ""));
424 Value conditionVal = globalCondition.getResult();
425
426 auto loweredDo = emitc::DoOp::create(rewriter, loc);
427
428 // Convert region types to match the target dialect type system.
429 if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
430 *getTypeConverter(), nullptr)) ||
431 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
432 *getTypeConverter(), nullptr))) {
433 return rewriter.notifyMatchFailure(whileOp,
434 "region types conversion failed");
435 }
436
437 // Prepare the before region (condition evaluation) for merging.
438 Block *beforeBlock = &whileOp.getBefore().front();
439 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
440 rewriter.setInsertionPointToStart(bodyBlock);
441
442 // Load current variable values to use as initial arguments for the
443 // condition block.
444 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
445 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
446
447 Operation *condTerminator =
448 loweredDo.getBodyRegion().back().getTerminator();
449 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
450 rewriter.setInsertionPoint(condOp);
451
452 // Update result variables with values from scf::condition.
453 SmallVector<Value> conditionArgs;
454 for (Value arg : condOp.getArgs()) {
455 conditionArgs.push_back(rewriter.getRemappedValue(arg));
456 }
457 assignValues(conditionArgs, resultVars, rewriter, loc);
458
459 // Convert scf.condition to condition variable assignment.
460 Value condition = rewriter.getRemappedValue(condOp.getCondition());
461 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
462
463 // Wrap body region in conditional to preserve scf semantics. Only create
464 // ifOp if after-region is non-empty.
465 if (whileOp.getAfterBody()->getOperations().size() > 1) {
466 auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
467
468 // Prepare the after region (loop body) for merging.
469 Block *afterBlock = &whileOp.getAfter().front();
470 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
471
472 // Replacement values for after block using condition op arguments.
473 SmallVector<Value> afterReplacingValues;
474 for (Value arg : condOp.getArgs())
475 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
476
477 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
478
479 if (failed(lowerYield(whileOp, loopVars, rewriter,
480 cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
481 return failure();
482 }
483
484 rewriter.eraseOp(condOp);
485
486 // Create condition region that loads from the flag variable.
487 Region &condRegion = loweredDo.getConditionRegion();
488 Block *condBlock = rewriter.createBlock(&condRegion);
489 rewriter.setInsertionPointToStart(condBlock);
490
491 auto exprOp = emitc::ExpressionOp::create(
492 rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
493 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
494
495 // Set up the expression block to load the condition variable.
496 exprBlock->addArgument(conditionVal.getType(), loc);
497 rewriter.setInsertionPointToStart(exprBlock);
498
499 // Load the condition value and yield it as the expression result.
500 Value cond =
501 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
502 emitc::YieldOp::create(rewriter, loc, cond);
503
504 // Yield the expression as the condition region result.
505 rewriter.setInsertionPointToEnd(condBlock);
506 emitc::YieldOp::create(rewriter, loc, exprOp);
507
508 return success();
509 }
510};
511
513 TypeConverter &typeConverter) {
514 patterns.add<ForLowering>(typeConverter, patterns.getContext());
515 patterns.add<IfLowering>(typeConverter, patterns.getContext());
516 patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
517 patterns.add<WhileLowering>(typeConverter, patterns.getContext());
518}
519
520void SCFToEmitCPass::runOnOperation() {
521 RewritePatternSet patterns(&getContext());
522 TypeConverter typeConverter;
523 // Fallback for other types.
524 typeConverter.addConversion([](Type type) -> std::optional<Type> {
526 return {};
527 return type;
528 });
530 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
531
532 // Configure conversion to lower out SCF operations.
533 ConversionTarget target(getContext());
534 target
535 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
536 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
537 if (failed(
538 applyPartialConversion(getOperation(), target, std::move(patterns))))
539 signalPassFailure();
540}
return success()
b getContext())
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition EmitC.cpp:62
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
void registerConvertSCFToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override