MLIR 23.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert SCF control-flow ops (scf.for, scf.if,
// scf.index_switch and scf.while) into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/MLIRContext.h"
24#include "llvm/Support/LogicalResult.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::scf;
33
34namespace {
35
36/// Implement the interface to convert SCF to EmitC.
37struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
39
40 /// Hook for derived dialect interface to provide conversion patterns
41 /// and mark dialect legal for the conversion target.
42 void populateConvertToEmitCConversionPatterns(
43 ConversionTarget &target, TypeConverter &typeConverter,
44 RewritePatternSet &patterns) const final {
46 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
47 }
48};
49} // namespace
50
52 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
53 dialect->addInterfaces<SCFToEmitCDialectInterface>();
54 });
55}
56
57namespace {
58
59struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
60 void runOnOperation() override;
61};
62
63// Lower scf::for to emitc::for, implementing result values using
64// emitc::variable's updated within the loop body.
65struct ForLowering : public OpConversionPattern<ForOp> {
66 using OpConversionPattern<ForOp>::OpConversionPattern;
67
68 LogicalResult
69 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter) const override;
71};
72
73// Create an uninitialized emitc::variable op for each result of the given op.
74template <typename T>
75static LogicalResult
76createVariablesForResults(T op, const TypeConverter *typeConverter,
77 ConversionPatternRewriter &rewriter,
78 SmallVector<Value> &resultVariables) {
79 if (!op.getNumResults())
80 return success();
81
82 Location loc = op->getLoc();
83 MLIRContext *context = op.getContext();
84
85 OpBuilder::InsertionGuard guard(rewriter);
86 rewriter.setInsertionPoint(op);
87
88 for (OpResult result : op.getResults()) {
89 Type resultType = typeConverter->convertType(result.getType());
90 if (!resultType)
91 return rewriter.notifyMatchFailure(op, "result type conversion failed");
92 if (isa<emitc::ArrayType>(resultType))
93 return rewriter.notifyMatchFailure(
94 op, "cannot create variable for result of array type");
95 Type varType = emitc::LValueType::get(resultType);
96 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
97 emitc::VariableOp var =
98 emitc::VariableOp::create(rewriter, loc, varType, noInit);
99 resultVariables.push_back(var);
100 }
101
102 return success();
103}
104
105// Create a series of assign ops assigning given values to given variables at
106// the current insertion point of given rewriter.
107static void assignValues(ValueRange values, ValueRange variables,
108 ConversionPatternRewriter &rewriter, Location loc) {
109 for (auto [value, var] : llvm::zip(values, variables))
110 emitc::AssignOp::create(rewriter, loc, var, value);
111}
112
113SmallVector<Value> loadValues(ArrayRef<Value> variables,
114 PatternRewriter &rewriter, Location loc) {
115 return llvm::map_to_vector<>(variables, [&](Value var) {
116 Type type = cast<emitc::LValueType>(var.getType()).getValueType();
117 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
118 });
119}
120
121static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
122 ConversionPatternRewriter &rewriter,
123 scf::YieldOp yield, bool createYield = true) {
124 Location loc = yield.getLoc();
125
126 OpBuilder::InsertionGuard guard(rewriter);
127 rewriter.setInsertionPoint(yield);
128
129 SmallVector<Value> yieldOperands;
130 if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
131 return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
132
133 assignValues(yieldOperands, resultVariables, rewriter, loc);
134
135 emitc::YieldOp::create(rewriter, loc);
136 rewriter.eraseOp(yield);
137
138 return success();
139}
140
141// Lower the contents of an scf::if/scf::index_switch regions to an
142// emitc::if/emitc::switch region. The contents of the lowering region is
143// moved into the respective lowered region, but the scf::yield is replaced not
144// only with an emitc::yield, but also with a sequence of emitc::assign ops that
145// set the yielded values into the result variables.
146static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
147 ConversionPatternRewriter &rewriter,
148 Region &region, Region &loweredRegion) {
149 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
150 Operation *terminator = loweredRegion.back().getTerminator();
151 return lowerYield(op, resultVariables, rewriter,
152 cast<scf::YieldOp>(terminator));
153}
154
155LogicalResult
156ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const {
158 Location loc = forOp.getLoc();
159
160 if (forOp.getUnsignedCmp())
161 return rewriter.notifyMatchFailure(forOp,
162 "unsigned loops are not supported");
163
164 // Create an emitc::variable op for each result. These variables will be
165 // assigned to by emitc::assign ops within the loop body.
166 SmallVector<Value> resultVariables;
167 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
168 resultVariables)))
169 return rewriter.notifyMatchFailure(forOp,
170 "create variables for results failed");
171
172 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
173
174 emitc::ForOp loweredFor =
175 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
176 adaptor.getUpperBound(), adaptor.getStep());
177
178 Block *loweredBody = loweredFor.getBody();
179
180 // Erase the auto-generated terminator for the lowered for op.
181 rewriter.eraseOp(loweredBody->getTerminator());
182
183 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
184 rewriter.setInsertionPointToEnd(loweredBody);
185
186 SmallVector<Value> iterArgsValues =
187 loadValues(resultVariables, rewriter, loc);
188
189 rewriter.restoreInsertionPoint(ip);
190
191 // Convert the original region types into the new types by adding unrealized
192 // casts in the beginning of the loop. This performs the conversion in place.
193 if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
194 *getTypeConverter(), nullptr))) {
195 return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
196 }
197
198 // Register the replacements for the block arguments and inline the body of
199 // the scf.for loop into the body of the emitc::for loop.
200 Block *scfBody = &(forOp.getRegion().front());
201 SmallVector<Value> replacingValues;
202 replacingValues.push_back(loweredFor.getInductionVar());
203 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
204 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
205
206 auto result = lowerYield(forOp, resultVariables, rewriter,
207 cast<scf::YieldOp>(loweredBody->getTerminator()));
208
209 if (failed(result)) {
210 return result;
211 }
212
213 // Load variables into SSA values after the for loop.
214 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
215
216 rewriter.replaceOp(forOp, resultValues);
217 return success();
218}
219
220// Lower scf::if to emitc::if, implementing result values as emitc::variable's
221// updated within the then and else regions.
222struct IfLowering : public OpConversionPattern<IfOp> {
223 using OpConversionPattern<IfOp>::OpConversionPattern;
224
225 LogicalResult
226 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override;
228};
229
230} // namespace
231
232LogicalResult
233IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter) const {
235 Location loc = ifOp.getLoc();
236
237 // Create an emitc::variable op for each result. These variables will be
238 // assigned to by emitc::assign ops within the then & else regions.
239 SmallVector<Value> resultVariables;
240 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
241 resultVariables)))
242 return rewriter.notifyMatchFailure(ifOp,
243 "create variables for results failed");
244
245 // Utility function to lower the contents of an scf::if region to an emitc::if
246 // region. The contents of the scf::if regions is moved into the respective
247 // emitc::if regions, but the scf::yield is replaced not only with an
248 // emitc::yield, but also with a sequence of emitc::assign ops that set the
249 // yielded values into the result variables.
250 auto lowerRegion = [&resultVariables, &rewriter,
251 &ifOp](Region &region, Region &loweredRegion) {
252 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
253 Operation *terminator = loweredRegion.back().getTerminator();
254 auto result = lowerYield(ifOp, resultVariables, rewriter,
255 cast<scf::YieldOp>(terminator));
256 if (failed(result)) {
257 return result;
258 }
259 return success();
260 };
261
262 Region &thenRegion = adaptor.getThenRegion();
263 Region &elseRegion = adaptor.getElseRegion();
264
265 bool hasElseBlock = !elseRegion.empty();
266
267 auto loweredIf =
268 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false);
269
270 Region &loweredThenRegion = loweredIf.getThenRegion();
271 auto result = lowerRegion(thenRegion, loweredThenRegion);
272 if (failed(result)) {
273 return result;
274 }
275
276 if (hasElseBlock) {
277 Region &loweredElseRegion = loweredIf.getElseRegion();
278 auto result = lowerRegion(elseRegion, loweredElseRegion);
279 if (failed(result)) {
280 return result;
281 }
282 }
283
284 rewriter.setInsertionPointAfter(ifOp);
285 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
286
287 rewriter.replaceOp(ifOp, results);
288 return success();
289}
290
291// Lower scf::index_switch to emitc::switch, implementing result values as
292// emitc::variable's updated within the case and default regions.
293struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
294 using OpConversionPattern::OpConversionPattern;
295
296 LogicalResult
297 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter) const override;
299};
300
302 IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const {
304 Location loc = indexSwitchOp.getLoc();
305
306 // Create an emitc::variable op for each result. These variables will be
307 // assigned to by emitc::assign ops within the case and default regions.
308 SmallVector<Value> resultVariables;
309 if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
310 rewriter, resultVariables))) {
311 return rewriter.notifyMatchFailure(indexSwitchOp,
312 "create variables for results failed");
313 }
314
315 auto loweredSwitch =
316 emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
317 adaptor.getCases(), indexSwitchOp.getNumCases());
318
319 // Lowering all case regions.
320 for (auto pair :
321 llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
322 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
323 *std::get<0>(pair), std::get<1>(pair)))) {
324 return failure();
325 }
326 }
327
328 // Lowering default region.
329 if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
330 adaptor.getDefaultRegion(),
331 loweredSwitch.getDefaultRegion()))) {
332 return failure();
333 }
334
335 rewriter.setInsertionPointAfter(indexSwitchOp);
336 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
337
338 rewriter.replaceOp(indexSwitchOp, results);
339 return success();
340}
341
342// Lower scf::while to emitc::do using mutable variables to maintain loop state
343// across iterations. The do-while structure ensures the condition is evaluated
344// after each iteration, matching SCF while semantics.
345struct WhileLowering : public OpConversionPattern<WhileOp> {
346 using OpConversionPattern::OpConversionPattern;
347
348 LogicalResult
349 matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351 Location loc = whileOp.getLoc();
352 MLIRContext *context = loc.getContext();
353
354 // Create an emitc::variable op for each result. These variables will be
355 // assigned to by emitc::assign ops within the loop body.
356 SmallVector<Value> resultVariables;
357 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
358 resultVariables)))
359 return rewriter.notifyMatchFailure(whileOp,
360 "Failed to create result variables");
361
362 // Create variable storage for loop-carried values to enable imperative
363 // updates while maintaining SSA semantics at conversion boundaries.
364 SmallVector<Value> loopVariables;
365 if (failed(createVariablesForLoopCarriedValues(
366 whileOp, rewriter, loopVariables, loc, context)))
367 return failure();
368
369 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
370 rewriter, loc)))
371 return failure();
372
373 rewriter.setInsertionPointAfter(whileOp);
374
375 // Load the final result values from result variables.
376 SmallVector<Value> finalResults =
377 loadValues(resultVariables, rewriter, loc);
378 rewriter.replaceOp(whileOp, finalResults);
379
380 return success();
381 }
382
383private:
384 // Initialize variables for loop-carried values to enable state updates
385 // across iterations without SSA argument passing.
386 LogicalResult createVariablesForLoopCarriedValues(
387 WhileOp whileOp, ConversionPatternRewriter &rewriter,
388 SmallVectorImpl<Value> &loopVars, Location loc,
389 MLIRContext *context) const {
390 OpBuilder::InsertionGuard guard(rewriter);
391 rewriter.setInsertionPoint(whileOp);
392
393 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
394
395 for (Value init : whileOp.getInits()) {
396 Type convertedType = getTypeConverter()->convertType(init.getType());
397 if (!convertedType)
398 return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
399 if (isa<emitc::ArrayType>(convertedType))
400 return rewriter.notifyMatchFailure(
401 whileOp,
402 "cannot create variable for loop-carried value of array type");
403
404 auto var = emitc::VariableOp::create(
405 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
406 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
407 loopVars.push_back(var);
408 }
409
410 return success();
411 }
412
413 // Lower scf.while to emitc.do.
414 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
415 ArrayRef<Value> resultVars, MLIRContext *context,
416 ConversionPatternRewriter &rewriter,
417 Location loc) const {
418 // Create a global boolean variable to store the loop condition state.
419 Type i1Type = IntegerType::get(context, 1);
420 auto globalCondition =
421 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
422 emitc::OpaqueAttr::get(context, ""));
423 Value conditionVal = globalCondition.getResult();
424
425 auto loweredDo = emitc::DoOp::create(rewriter, loc);
426
427 // Convert region types to match the target dialect type system.
428 if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
429 *getTypeConverter(), nullptr)) ||
430 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
431 *getTypeConverter(), nullptr))) {
432 return rewriter.notifyMatchFailure(whileOp,
433 "region types conversion failed");
434 }
435
436 // Prepare the before region (condition evaluation) for merging.
437 Block *beforeBlock = &whileOp.getBefore().front();
438 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
439 rewriter.setInsertionPointToStart(bodyBlock);
440
441 // Load current variable values to use as initial arguments for the
442 // condition block.
443 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
444 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
445
446 Operation *condTerminator =
447 loweredDo.getBodyRegion().back().getTerminator();
448 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
449 rewriter.setInsertionPoint(condOp);
450
451 // Update result variables with values from scf::condition.
452 SmallVector<Value> conditionArgs;
453 for (Value arg : condOp.getArgs()) {
454 conditionArgs.push_back(rewriter.getRemappedValue(arg));
455 }
456 assignValues(conditionArgs, resultVars, rewriter, loc);
457
458 // Convert scf.condition to condition variable assignment.
459 Value condition = rewriter.getRemappedValue(condOp.getCondition());
460 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
461
462 // Wrap body region in conditional to preserve scf semantics. Only create
463 // ifOp if after-region is non-empty.
464 if (whileOp.getAfterBody()->getOperations().size() > 1) {
465 auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
466
467 // Prepare the after region (loop body) for merging.
468 Block *afterBlock = &whileOp.getAfter().front();
469 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
470
471 // Replacement values for after block using condition op arguments.
472 SmallVector<Value> afterReplacingValues;
473 for (Value arg : condOp.getArgs())
474 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
475
476 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
477
478 if (failed(lowerYield(whileOp, loopVars, rewriter,
479 cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
480 return failure();
481 }
482
483 rewriter.eraseOp(condOp);
484
485 // Create condition region that loads from the flag variable.
486 Region &condRegion = loweredDo.getConditionRegion();
487 Block *condBlock = rewriter.createBlock(&condRegion);
488 rewriter.setInsertionPointToStart(condBlock);
489
490 auto exprOp = emitc::ExpressionOp::create(
491 rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
492 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
493
494 // Set up the expression block to load the condition variable.
495 exprBlock->addArgument(conditionVal.getType(), loc);
496 rewriter.setInsertionPointToStart(exprBlock);
497
498 // Load the condition value and yield it as the expression result.
499 Value cond =
500 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
501 emitc::YieldOp::create(rewriter, loc, cond);
502
503 // Yield the expression as the condition region result.
504 rewriter.setInsertionPointToEnd(condBlock);
505 emitc::YieldOp::create(rewriter, loc, exprOp);
506
507 return success();
508 }
509};
510
512 TypeConverter &typeConverter) {
513 patterns.add<ForLowering>(typeConverter, patterns.getContext());
514 patterns.add<IfLowering>(typeConverter, patterns.getContext());
515 patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
516 patterns.add<WhileLowering>(typeConverter, patterns.getContext());
517}
518
519void SCFToEmitCPass::runOnOperation() {
520 RewritePatternSet patterns(&getContext());
521 TypeConverter typeConverter;
522 // Fallback for other types.
523 typeConverter.addConversion([](Type type) -> std::optional<Type> {
525 return {};
526 return type;
527 });
529 populateSCFToEmitCConversionPatterns(patterns, typeConverter);
530
531 // Configure conversion to lower out SCF operations.
532 ConversionTarget target(getContext());
533 target
534 .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
535 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
536 if (failed(
537 applyPartialConversion(getOperation(), target, std::move(patterns))))
538 signalPassFailure();
539}
return success()
b getContext())
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition EmitC.cpp:62
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
void registerConvertSCFToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override