MLIR 23.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert scf.if ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/MLIRContext.h"
25#include "llvm/Support/LogicalResult.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_SCFTOEMITC
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::scf;
34
35namespace {
36
37/// Implement the interface to convert SCF to EmitC.
38struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
39 SCFToEmitCDialectInterface(Dialect *dialect)
40 : ConvertToEmitCPatternInterface(dialect) {}
41
42 /// Hook for derived dialect interface to provide conversion patterns
43 /// and mark dialect legal for the conversion target.
44 void populateConvertToEmitCConversionPatterns(
45 ConversionTarget &target, TypeConverter &typeConverter,
46 RewritePatternSet &patterns, std::optional<bool> lowerToCpp) const final {
48 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
49 }
50};
51} // namespace
52
54 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
55 dialect->addInterfaces<SCFToEmitCDialectInterface>();
56 });
57}
58
59namespace {
60
61struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
62 void runOnOperation() override;
63};
64
65// Lower scf::for to emitc::for, implementing result values using
66// emitc::variable's updated within the loop body.
67struct ForLowering : public OpConversionPattern<ForOp> {
68 using OpConversionPattern<ForOp>::OpConversionPattern;
69
70 LogicalResult
71 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override;
73};
74
75// Create an uninitialized emitc::variable op for each result of the given op.
76template <typename T>
77static LogicalResult
78createVariablesForResults(T op, const TypeConverter *typeConverter,
79 ConversionPatternRewriter &rewriter,
80 SmallVector<Value> &resultVariables) {
81 if (!op.getNumResults())
82 return success();
83
84 Location loc = op->getLoc();
85 MLIRContext *context = op.getContext();
86
87 OpBuilder::InsertionGuard guard(rewriter);
88 rewriter.setInsertionPoint(op);
89
90 for (OpResult result : op.getResults()) {
91 Type resultType = typeConverter->convertType(result.getType());
92 if (!resultType)
93 return rewriter.notifyMatchFailure(op, "result type conversion failed");
94 if (isa<emitc::ArrayType>(resultType))
95 return rewriter.notifyMatchFailure(
96 op, "cannot create variable for result of array type");
97 Type varType = emitc::LValueType::get(resultType);
98 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
99 emitc::VariableOp var =
100 emitc::VariableOp::create(rewriter, loc, varType, noInit);
101 resultVariables.push_back(var);
102 }
103
104 return success();
105}
106
107// Create a series of assign ops assigning given values to given variables at
108// the current insertion point of given rewriter.
109static void assignValues(ValueRange values, ValueRange variables,
110 ConversionPatternRewriter &rewriter, Location loc) {
111 for (auto [value, var] : llvm::zip(values, variables))
112 emitc::AssignOp::create(rewriter, loc, var, value);
113}
114
115SmallVector<Value> loadValues(ArrayRef<Value> variables,
116 PatternRewriter &rewriter, Location loc) {
117 return llvm::map_to_vector<>(variables, [&](Value var) {
118 Type type = cast<emitc::LValueType>(var.getType()).getValueType();
119 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
120 });
121}
122
123static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
124 ConversionPatternRewriter &rewriter,
125 scf::YieldOp yield, bool createYield = true) {
126 Location loc = yield.getLoc();
127
128 OpBuilder::InsertionGuard guard(rewriter);
129 rewriter.setInsertionPoint(yield);
130
131 SmallVector<Value> yieldOperands;
132 if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
133 return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
134
135 assignValues(yieldOperands, resultVariables, rewriter, loc);
136
137 emitc::YieldOp::create(rewriter, loc);
138 rewriter.eraseOp(yield);
139
140 return success();
141}
142
143// Lower the contents of an scf::if/scf::index_switch regions to an
144// emitc::if/emitc::switch region. The contents of the lowering region is
145// moved into the respective lowered region, but the scf::yield is replaced not
146// only with an emitc::yield, but also with a sequence of emitc::assign ops that
147// set the yielded values into the result variables.
148static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
149 ConversionPatternRewriter &rewriter,
150 Region &region, Region &loweredRegion) {
151 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
152 Operation *terminator = loweredRegion.back().getTerminator();
153 return lowerYield(op, resultVariables, rewriter,
154 cast<scf::YieldOp>(terminator));
155}
156
157LogicalResult
158ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
159 ConversionPatternRewriter &rewriter) const {
160 Location loc = forOp.getLoc();
161
162 if (forOp.getUnsignedCmp())
163 return rewriter.notifyMatchFailure(forOp,
164 "unsigned loops are not supported");
165
166 // Create an emitc::variable op for each result. These variables will be
167 // assigned to by emitc::assign ops within the loop body.
168 SmallVector<Value> resultVariables;
169 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
170 resultVariables)))
171 return rewriter.notifyMatchFailure(forOp,
172 "create variables for results failed");
173
174 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
175
176 emitc::ForOp loweredFor =
177 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
178 adaptor.getUpperBound(), adaptor.getStep());
179
180 Block *loweredBody = loweredFor.getBody();
181
182 // Erase the auto-generated terminator for the lowered for op.
183 rewriter.eraseOp(loweredBody->getTerminator());
184
185 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
186 rewriter.setInsertionPointToEnd(loweredBody);
187
188 SmallVector<Value> iterArgsValues =
189 loadValues(resultVariables, rewriter, loc);
190
191 rewriter.restoreInsertionPoint(ip);
192
193 // Convert the original region types into the new types by adding unrealized
194 // casts in the beginning of the loop. This performs the conversion in place.
195 if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
196 *getTypeConverter(), nullptr))) {
197 return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
198 }
199
200 // Register the replacements for the block arguments and inline the body of
201 // the scf.for loop into the body of the emitc::for loop.
202 Block *scfBody = &(forOp.getRegion().front());
203 SmallVector<Value> replacingValues;
204 replacingValues.push_back(loweredFor.getInductionVar());
205 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
206 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
207
208 auto result = lowerYield(forOp, resultVariables, rewriter,
209 cast<scf::YieldOp>(loweredBody->getTerminator()));
210
211 if (failed(result)) {
212 return result;
213 }
214
215 // Load variables into SSA values after the for loop.
216 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
217
218 rewriter.replaceOp(forOp, resultValues);
219 return success();
220}
221
222// Lower scf::if to emitc::if, implementing result values as emitc::variable's
223// updated within the then and else regions.
224struct IfLowering : public OpConversionPattern<IfOp> {
225 using OpConversionPattern<IfOp>::OpConversionPattern;
226
227 LogicalResult
228 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override;
230};
231
232} // namespace
233
234LogicalResult
235IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter) const {
237 Location loc = ifOp.getLoc();
238
239 // Create an emitc::variable op for each result. These variables will be
240 // assigned to by emitc::assign ops within the then & else regions.
241 SmallVector<Value> resultVariables;
242 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
243 resultVariables)))
244 return rewriter.notifyMatchFailure(ifOp,
245 "create variables for results failed");
246
247 // Utility function to lower the contents of an scf::if region to an emitc::if
248 // region. The contents of the scf::if regions is moved into the respective
249 // emitc::if regions, but the scf::yield is replaced not only with an
250 // emitc::yield, but also with a sequence of emitc::assign ops that set the
251 // yielded values into the result variables.
252 auto lowerRegion = [&resultVariables, &rewriter,
253 &ifOp](Region &region, Region &loweredRegion) {
254 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
255 Operation *terminator = loweredRegion.back().getTerminator();
256 auto result = lowerYield(ifOp, resultVariables, rewriter,
257 cast<scf::YieldOp>(terminator));
258 if (failed(result)) {
259 return result;
260 }
261 return success();
262 };
263
264 Region &thenRegion = adaptor.getThenRegion();
265 Region &elseRegion = adaptor.getElseRegion();
266
267 bool hasElseBlock = !elseRegion.empty();
268
269 auto loweredIf =
270 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false);
271
272 Region &loweredThenRegion = loweredIf.getThenRegion();
273 auto result = lowerRegion(thenRegion, loweredThenRegion);
274 if (failed(result)) {
275 return result;
276 }
277
278 if (hasElseBlock) {
279 Region &loweredElseRegion = loweredIf.getElseRegion();
280 auto result = lowerRegion(elseRegion, loweredElseRegion);
281 if (failed(result)) {
282 return result;
283 }
284 }
285
286 rewriter.setInsertionPointAfter(ifOp);
287 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
288
289 rewriter.replaceOp(ifOp, results);
290 return success();
291}
292
293// Lower scf::index_switch to emitc::switch, implementing result values as
294// emitc::variable's updated within the case and default regions.
295struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
296 using OpConversionPattern::OpConversionPattern;
297
298 LogicalResult
299 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
300 ConversionPatternRewriter &rewriter) const override;
301};
302
304 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter) const {
306 Location loc = indexSwitchOp.getLoc();
307
308 // Create an emitc::variable op for each result. These variables will be
309 // assigned to by emitc::assign ops within the case and default regions.
310 SmallVector<Value> resultVariables;
311 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
312 rewriter, resultVariables))) {
313 return rewriter.notifyMatchFailure(indexSwitchOp,
314 "create variables for results failed");
315 }
316
317 auto loweredSwitch =
318 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
319 adaptor.getCases(), indexSwitchOp.getNumCases());
320
321 // Lowering all case regions.
322 for (auto pair :
323 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
324 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
325 *std::get<0>(pair), std::get<1>(pair)))) {
326 return failure();
327 }
328 }
329
330 // Lowering default region.
331 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
332 adaptor.getDefaultRegion(),
333 loweredSwitch.getDefaultRegion()))) {
334 return failure();
335 }
336
337 rewriter.setInsertionPointAfter(indexSwitchOp);
338 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
339
340 rewriter.replaceOp(indexSwitchOp, results);
341 return success();
342}
343
344// Lower scf::while to emitc::do using mutable variables to maintain loop state
345// across iterations. The do-while structure ensures the condition is evaluated
346// after each iteration, matching SCF while semantics.
347struct WhileLowering : public OpConversionPattern<WhileOp> {
348 using OpConversionPattern::OpConversionPattern;
349
350 LogicalResult
351 matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter) const override {
353 Location loc = whileOp.getLoc();
354 MLIRContext *context = loc.getContext();
355
356 // Create an emitc::variable op for each result. These variables will be
357 // assigned to by emitc::assign ops within the loop body.
358 SmallVector<Value> resultVariables;
359 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
360 resultVariables)))
361 return rewriter.notifyMatchFailure(whileOp,
362 "Failed to create result variables");
363
364 // Create variable storage for loop-carried values to enable imperative
365 // updates while maintaining SSA semantics at conversion boundaries.
366 SmallVector<Value> loopVariables;
367 if (failed(createVariablesForLoopCarriedValues(
368 whileOp, rewriter, loopVariables, loc, context)))
369 return failure();
370
371 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
372 rewriter, loc)))
373 return failure();
374
375 rewriter.setInsertionPointAfter(whileOp);
376
377 // Load the final result values from result variables.
378 SmallVector<Value> finalResults =
379 loadValues(resultVariables, rewriter, loc);
380 rewriter.replaceOp(whileOp, finalResults);
381
382 return success();
383 }
384
385private:
386 // Initialize variables for loop-carried values to enable state updates
387 // across iterations without SSA argument passing.
388 LogicalResult createVariablesForLoopCarriedValues(
389 WhileOp whileOp, ConversionPatternRewriter &rewriter,
390 SmallVectorImpl<Value> &loopVars, Location loc,
391 MLIRContext *context) const {
392 OpBuilder::InsertionGuard guard(rewriter);
393 rewriter.setInsertionPoint(whileOp);
394
395 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
396
397 for (Value init : whileOp.getInits()) {
398 Type convertedType = getTypeConverter()->convertType(init.getType());
399 if (!convertedType)
400 return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
401 if (isa<emitc::ArrayType>(convertedType))
402 return rewriter.notifyMatchFailure(
403 whileOp,
404 "cannot create variable for loop-carried value of array type");
405
406 auto var = emitc::VariableOp::create(
407 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
408 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
409 loopVars.push_back(var);
410 }
411
412 return success();
413 }
414
415 // Lower scf.while to emitc.do.
416 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
417 ArrayRef<Value> resultVars, MLIRContext *context,
418 ConversionPatternRewriter &rewriter,
419 Location loc) const {
420 // Create a global boolean variable to store the loop condition state.
421 Type i1Type = IntegerType::get(context, 1);
422 auto globalCondition =
423 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
424 emitc::OpaqueAttr::get(context, ""));
425 Value conditionVal = globalCondition.getResult();
426
427 auto loweredDo = emitc::DoOp::create(rewriter, loc);
428
429 // Convert region types to match the target dialect type system.
430 if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
431 *getTypeConverter(), nullptr)) ||
432 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
433 *getTypeConverter(), nullptr))) {
434 return rewriter.notifyMatchFailure(whileOp,
435 "region types conversion failed");
436 }
437
438 // Prepare the before region (condition evaluation) for merging.
439 Block *beforeBlock = &whileOp.getBefore().front();
440 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
441 rewriter.setInsertionPointToStart(bodyBlock);
442
443 // Load current variable values to use as initial arguments for the
444 // condition block.
445 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
446 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
447
448 Operation *condTerminator =
449 loweredDo.getBodyRegion().back().getTerminator();
450 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
451 rewriter.setInsertionPoint(condOp);
452
453 // Update result variables with values from scf::condition.
454 SmallVector<Value> conditionArgs;
455 for (Value arg : condOp.getArgs()) {
456 conditionArgs.push_back(rewriter.getRemappedValue(arg));
457 }
458 assignValues(conditionArgs, resultVars, rewriter, loc);
459
460 // Convert scf.condition to condition variable assignment.
461 Value condition = rewriter.getRemappedValue(condOp.getCondition());
462 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
463
464 // Wrap body region in conditional to preserve scf semantics. Only create
465 // ifOp if after-region is non-empty.
466 if (whileOp.getAfterBody()->getOperations().size() > 1) {
467 auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
468
469 // Prepare the after region (loop body) for merging.
470 Block *afterBlock = &whileOp.getAfter().front();
471 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
472
473 // Replacement values for after block using condition op arguments.
474 SmallVector<Value> afterReplacingValues;
475 for (Value arg : condOp.getArgs())
476 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
477
478 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
479
480 if (failed(lowerYield(whileOp, loopVars, rewriter,
481 cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
482 return failure();
483 }
484
485 rewriter.eraseOp(condOp);
486
487 // Create condition region that loads from the flag variable.
488 Region &condRegion = loweredDo.getConditionRegion();
489 Block *condBlock = rewriter.createBlock(&condRegion);
490 rewriter.setInsertionPointToStart(condBlock);
491
492 auto exprOp = emitc::ExpressionOp::create(
493 rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
494 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
495
496 // Set up the expression block to load the condition variable.
497 exprBlock->addArgument(conditionVal.getType(), loc);
498 rewriter.setInsertionPointToStart(exprBlock);
499
500 // Load the condition value and yield it as the expression result.
501 Value cond =
502 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
503 emitc::YieldOp::create(rewriter, loc, cond);
504
505 // Yield the expression as the condition region result.
506 rewriter.setInsertionPointToEnd(condBlock);
507 emitc::YieldOp::create(rewriter, loc, exprOp);
508
509 return success();
510 }
511};
512
514 TypeConverter &typeConverter) {
515 patterns.add<ForLowering>(typeConverter, patterns.getContext());
516 patterns.add<IfLowering>(typeConverter, patterns.getContext());
517 patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
518 patterns.add<WhileLowering>(typeConverter, patterns.getContext());
519}
520
521void SCFToEmitCPass::runOnOperation() {
522 RewritePatternSet patterns(&getContext());
523 EmitCTypeConverter typeConverter(&getContext());
525 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
526
527 // Configure conversion to lower out SCF operations.
529 target
530 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
531 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
532 if (failed(
533 applyPartialConversion(getOperation(), target, std::move(patterns))))
534 signalPassFailure();
535}
return success()
b getContext())
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from common builtin types to the EmitC dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
void registerConvertSCFToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override