MLIR 22.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert SCF control-flow ops (scf.for,
// scf.if, scf.index_switch and scf.while) into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/MLIRContext.h"
24#include "llvm/Support/LogicalResult.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_SCFTOEMITC
28#include "mlir/Conversion/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::scf;
33
34namespace {
35
/// Implement the interface to convert SCF to EmitC.
///
/// This interface is attached to scf::SCFDialect through the DialectRegistry
/// extension installed by registerConvertSCFToEmitCInterface in this file.
struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  /// NOTE(review): presumably forwards `patterns`/`typeConverter` to
  /// populateSCFToEmitCConversionPatterns — confirm.
  void populateConvertToEmitCConversionPatterns(
      ConversionTarget &target, TypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
  }
};
49} // namespace
50
// Attach SCFToEmitCDialectInterface to scf::SCFDialect via a registry
// extension; the callback presumably runs once the SCF dialect is loaded into
// a context that uses this registry (see DialectRegistry::addExtension).
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    dialect->addInterfaces<SCFToEmitCDialectInterface>();
  });
}
56
57namespace {
58
59struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
60 void runOnOperation() override;
61};
62
63// Lower scf::for to emitc::for, implementing result values using
64// emitc::variable's updated within the loop body.
65struct ForLowering : public OpConversionPattern<ForOp> {
66 using OpConversionPattern<ForOp>::OpConversionPattern;
67
68 LogicalResult
69 matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter) const override;
71};
72
73// Create an uninitialized emitc::variable op for each result of the given op.
74template <typename T>
75static LogicalResult
76createVariablesForResults(T op, const TypeConverter *typeConverter,
77 ConversionPatternRewriter &rewriter,
78 SmallVector<Value> &resultVariables) {
79 if (!op.getNumResults())
80 return success();
81
82 Location loc = op->getLoc();
83 MLIRContext *context = op.getContext();
84
85 OpBuilder::InsertionGuard guard(rewriter);
86 rewriter.setInsertionPoint(op);
87
88 for (OpResult result : op.getResults()) {
89 Type resultType = typeConverter->convertType(result.getType());
90 if (!resultType)
91 return rewriter.notifyMatchFailure(op, "result type conversion failed");
92 Type varType = emitc::LValueType::get(resultType);
93 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
94 emitc::VariableOp var =
95 emitc::VariableOp::create(rewriter, loc, varType, noInit);
96 resultVariables.push_back(var);
97 }
98
99 return success();
100}
101
102// Create a series of assign ops assigning given values to given variables at
103// the current insertion point of given rewriter.
104static void assignValues(ValueRange values, ValueRange variables,
105 ConversionPatternRewriter &rewriter, Location loc) {
106 for (auto [value, var] : llvm::zip(values, variables))
107 emitc::AssignOp::create(rewriter, loc, var, value);
108}
109
110SmallVector<Value> loadValues(ArrayRef<Value> variables,
111 PatternRewriter &rewriter, Location loc) {
112 return llvm::map_to_vector<>(variables, [&](Value var) {
113 Type type = cast<emitc::LValueType>(var.getType()).getValueType();
114 return emitc::LoadOp::create(rewriter, loc, type, var).getResult();
115 });
116}
117
118static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
119 ConversionPatternRewriter &rewriter,
120 scf::YieldOp yield, bool createYield = true) {
121 Location loc = yield.getLoc();
122
123 OpBuilder::InsertionGuard guard(rewriter);
124 rewriter.setInsertionPoint(yield);
125
126 SmallVector<Value> yieldOperands;
127 if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
128 return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
129
130 assignValues(yieldOperands, resultVariables, rewriter, loc);
131
132 emitc::YieldOp::create(rewriter, loc);
133 rewriter.eraseOp(yield);
134
135 return success();
136}
137
138// Lower the contents of an scf::if/scf::index_switch regions to an
139// emitc::if/emitc::switch region. The contents of the lowering region is
140// moved into the respective lowered region, but the scf::yield is replaced not
141// only with an emitc::yield, but also with a sequence of emitc::assign ops that
142// set the yielded values into the result variables.
143static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
144 ConversionPatternRewriter &rewriter,
145 Region &region, Region &loweredRegion) {
146 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
147 Operation *terminator = loweredRegion.back().getTerminator();
148 return lowerYield(op, resultVariables, rewriter,
149 cast<scf::YieldOp>(terminator));
150}
151
152LogicalResult
153ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const {
155 Location loc = forOp.getLoc();
156
157 if (forOp.getUnsignedCmp())
158 return rewriter.notifyMatchFailure(forOp,
159 "unsigned loops are not supported");
160
161 // Create an emitc::variable op for each result. These variables will be
162 // assigned to by emitc::assign ops within the loop body.
163 SmallVector<Value> resultVariables;
164 if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
165 resultVariables)))
166 return rewriter.notifyMatchFailure(forOp,
167 "create variables for results failed");
168
169 assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
170
171 emitc::ForOp loweredFor =
172 emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(),
173 adaptor.getUpperBound(), adaptor.getStep());
174
175 Block *loweredBody = loweredFor.getBody();
176
177 // Erase the auto-generated terminator for the lowered for op.
178 rewriter.eraseOp(loweredBody->getTerminator());
179
180 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
181 rewriter.setInsertionPointToEnd(loweredBody);
182
183 SmallVector<Value> iterArgsValues =
184 loadValues(resultVariables, rewriter, loc);
185
186 rewriter.restoreInsertionPoint(ip);
187
188 // Convert the original region types into the new types by adding unrealized
189 // casts in the beginning of the loop. This performs the conversion in place.
190 if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
191 *getTypeConverter(), nullptr))) {
192 return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
193 }
194
195 // Register the replacements for the block arguments and inline the body of
196 // the scf.for loop into the body of the emitc::for loop.
197 Block *scfBody = &(forOp.getRegion().front());
198 SmallVector<Value> replacingValues;
199 replacingValues.push_back(loweredFor.getInductionVar());
200 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
201 rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
202
203 auto result = lowerYield(forOp, resultVariables, rewriter,
204 cast<scf::YieldOp>(loweredBody->getTerminator()));
205
206 if (failed(result)) {
207 return result;
208 }
209
210 // Load variables into SSA values after the for loop.
211 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
212
213 rewriter.replaceOp(forOp, resultValues);
214 return success();
215}
216
217// Lower scf::if to emitc::if, implementing result values as emitc::variable's
218// updated within the then and else regions.
219struct IfLowering : public OpConversionPattern<IfOp> {
220 using OpConversionPattern<IfOp>::OpConversionPattern;
221
222 LogicalResult
223 matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
224 ConversionPatternRewriter &rewriter) const override;
225};
226
227} // namespace
228
229LogicalResult
230IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter) const {
232 Location loc = ifOp.getLoc();
233
234 // Create an emitc::variable op for each result. These variables will be
235 // assigned to by emitc::assign ops within the then & else regions.
236 SmallVector<Value> resultVariables;
237 if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
238 resultVariables)))
239 return rewriter.notifyMatchFailure(ifOp,
240 "create variables for results failed");
241
242 // Utility function to lower the contents of an scf::if region to an emitc::if
243 // region. The contents of the scf::if regions is moved into the respective
244 // emitc::if regions, but the scf::yield is replaced not only with an
245 // emitc::yield, but also with a sequence of emitc::assign ops that set the
246 // yielded values into the result variables.
247 auto lowerRegion = [&resultVariables, &rewriter,
248 &ifOp](Region &region, Region &loweredRegion) {
249 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
250 Operation *terminator = loweredRegion.back().getTerminator();
251 auto result = lowerYield(ifOp, resultVariables, rewriter,
252 cast<scf::YieldOp>(terminator));
253 if (failed(result)) {
254 return result;
255 }
256 return success();
257 };
258
259 Region &thenRegion = adaptor.getThenRegion();
260 Region &elseRegion = adaptor.getElseRegion();
261
262 bool hasElseBlock = !elseRegion.empty();
263
264 auto loweredIf =
265 emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false);
266
267 Region &loweredThenRegion = loweredIf.getThenRegion();
268 auto result = lowerRegion(thenRegion, loweredThenRegion);
269 if (failed(result)) {
270 return result;
271 }
272
273 if (hasElseBlock) {
274 Region &loweredElseRegion = loweredIf.getElseRegion();
275 auto result = lowerRegion(elseRegion, loweredElseRegion);
276 if (failed(result)) {
277 return result;
278 }
279 }
280
281 rewriter.setInsertionPointAfter(ifOp);
282 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
283
284 rewriter.replaceOp(ifOp, results);
285 return success();
286}
287
288// Lower scf::index_switch to emitc::switch, implementing result values as
289// emitc::variable's updated within the case and default regions.
290struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
291 using OpConversionPattern::OpConversionPattern;
292
293 LogicalResult
294 matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override;
296};
297
IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = indexSwitchOp.getLoc();

  // Create an emitc::variable op for each result. These variables will be
  // assigned to by emitc::assign ops within the case and default regions.
  SmallVector<Value> resultVariables;
  if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
                                       rewriter, resultVariables))) {
    return rewriter.notifyMatchFailure(indexSwitchOp,
                                       "create variables for results failed");
  }

  // Build an emitc.switch on the converted switch argument with the same case
  // values. NOTE(review): the last argument presumably sizes the number of
  // case regions — confirm against emitc::SwitchOp's builder.
  auto loweredSwitch =
      emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(),
                              adaptor.getCases(), indexSwitchOp.getNumCases());

  // Lowering all case regions.
  // The i-th scf case region is moved into the i-th emitc case region; its
  // scf.yield becomes assigns into resultVariables followed by emitc.yield.
  for (auto pair :
       llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
    if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
                           *std::get<0>(pair), std::get<1>(pair)))) {
      return failure();
    }
  }

  // Lowering default region.
  if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
                         adaptor.getDefaultRegion(),
                         loweredSwitch.getDefaultRegion()))) {
    return failure();
  }

  // Replace the scf.index_switch results with loads of the result variables,
  // emitted right after the original op.
  rewriter.setInsertionPointAfter(indexSwitchOp);
  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);

  rewriter.replaceOp(indexSwitchOp, results);
  return success();
}
338
339// Lower scf::while to emitc::do using mutable variables to maintain loop state
340// across iterations. The do-while structure ensures the condition is evaluated
341// after each iteration, matching SCF while semantics.
342struct WhileLowering : public OpConversionPattern<WhileOp> {
343 using OpConversionPattern::OpConversionPattern;
344
345 LogicalResult
346 matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
347 ConversionPatternRewriter &rewriter) const override {
348 Location loc = whileOp.getLoc();
349 MLIRContext *context = loc.getContext();
350
351 // Create an emitc::variable op for each result. These variables will be
352 // assigned to by emitc::assign ops within the loop body.
353 SmallVector<Value> resultVariables;
354 if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
355 resultVariables)))
356 return rewriter.notifyMatchFailure(whileOp,
357 "Failed to create result variables");
358
359 // Create variable storage for loop-carried values to enable imperative
360 // updates while maintaining SSA semantics at conversion boundaries.
361 SmallVector<Value> loopVariables;
362 if (failed(createVariablesForLoopCarriedValues(
363 whileOp, rewriter, loopVariables, loc, context)))
364 return failure();
365
366 if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
367 rewriter, loc)))
368 return failure();
369
370 rewriter.setInsertionPointAfter(whileOp);
371
372 // Load the final result values from result variables.
373 SmallVector<Value> finalResults =
374 loadValues(resultVariables, rewriter, loc);
375 rewriter.replaceOp(whileOp, finalResults);
376
377 return success();
378 }
379
380private:
381 // Initialize variables for loop-carried values to enable state updates
382 // across iterations without SSA argument passing.
383 LogicalResult createVariablesForLoopCarriedValues(
384 WhileOp whileOp, ConversionPatternRewriter &rewriter,
385 SmallVectorImpl<Value> &loopVars, Location loc,
386 MLIRContext *context) const {
387 OpBuilder::InsertionGuard guard(rewriter);
388 rewriter.setInsertionPoint(whileOp);
389
390 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
391
392 for (Value init : whileOp.getInits()) {
393 Type convertedType = getTypeConverter()->convertType(init.getType());
394 if (!convertedType)
395 return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
396
397 auto var = emitc::VariableOp::create(
398 rewriter, loc, emitc::LValueType::get(convertedType), noInit);
399 emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
400 loopVars.push_back(var);
401 }
402
403 return success();
404 }
405
406 // Lower scf.while to emitc.do.
407 LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
408 ArrayRef<Value> resultVars, MLIRContext *context,
409 ConversionPatternRewriter &rewriter,
410 Location loc) const {
411 // Create a global boolean variable to store the loop condition state.
412 Type i1Type = IntegerType::get(context, 1);
413 auto globalCondition =
414 emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
415 emitc::OpaqueAttr::get(context, ""));
416 Value conditionVal = globalCondition.getResult();
417
418 auto loweredDo = emitc::DoOp::create(rewriter, loc);
419
420 // Convert region types to match the target dialect type system.
421 if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
422 *getTypeConverter(), nullptr)) ||
423 failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
424 *getTypeConverter(), nullptr))) {
425 return rewriter.notifyMatchFailure(whileOp,
426 "region types conversion failed");
427 }
428
429 // Prepare the before region (condition evaluation) for merging.
430 Block *beforeBlock = &whileOp.getBefore().front();
431 Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
432 rewriter.setInsertionPointToStart(bodyBlock);
433
434 // Load current variable values to use as initial arguments for the
435 // condition block.
436 SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
437 rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
438
439 Operation *condTerminator =
440 loweredDo.getBodyRegion().back().getTerminator();
441 scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
442 rewriter.setInsertionPoint(condOp);
443
444 // Update result variables with values from scf::condition.
445 SmallVector<Value> conditionArgs;
446 for (Value arg : condOp.getArgs()) {
447 conditionArgs.push_back(rewriter.getRemappedValue(arg));
448 }
449 assignValues(conditionArgs, resultVars, rewriter, loc);
450
451 // Convert scf.condition to condition variable assignment.
452 Value condition = rewriter.getRemappedValue(condOp.getCondition());
453 emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
454
455 // Wrap body region in conditional to preserve scf semantics. Only create
456 // ifOp if after-region is non-empty.
457 if (whileOp.getAfterBody()->getOperations().size() > 1) {
458 auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
459
460 // Prepare the after region (loop body) for merging.
461 Block *afterBlock = &whileOp.getAfter().front();
462 Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
463
464 // Replacement values for after block using condition op arguments.
465 SmallVector<Value> afterReplacingValues;
466 for (Value arg : condOp.getArgs())
467 afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
468
469 rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
470
471 if (failed(lowerYield(whileOp, loopVars, rewriter,
472 cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
473 return failure();
474 }
475
476 rewriter.eraseOp(condOp);
477
478 // Create condition region that loads from the flag variable.
479 Region &condRegion = loweredDo.getConditionRegion();
480 Block *condBlock = rewriter.createBlock(&condRegion);
481 rewriter.setInsertionPointToStart(condBlock);
482
483 auto exprOp = emitc::ExpressionOp::create(
484 rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
485 Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
486
487 // Set up the expression block to load the condition variable.
488 exprBlock->addArgument(conditionVal.getType(), loc);
489 rewriter.setInsertionPointToStart(exprBlock);
490
491 // Load the condition value and yield it as the expression result.
492 Value cond =
493 emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
494 emitc::YieldOp::create(rewriter, loc, cond);
495
496 // Yield the expression as the condition region result.
497 rewriter.setInsertionPointToEnd(condBlock);
498 emitc::YieldOp::create(rewriter, loc, exprOp);
499
500 return success();
501 }
502};
503
TypeConverter &typeConverter) {
  // Register the four SCF lowering patterns (scf.for, scf.if,
  // scf.index_switch, scf.while); each is constructed with `typeConverter`
  // and the pattern set's context.
  patterns.add<ForLowering>(typeConverter, patterns.getContext());
  patterns.add<IfLowering>(typeConverter, patterns.getContext());
  patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
  patterns.add<WhileLowering>(typeConverter, patterns.getContext());
}
511
void SCFToEmitCPass::runOnOperation() {
  TypeConverter typeConverter;
  // Fallback for other types.
  // NOTE(review): as shown, `return {}` runs unconditionally and makes
  // `return type` unreachable. Presumably a guard (cf.
  // emitc::isSupportedEmitCType) selects which types are rejected — confirm.
  typeConverter.addConversion([](Type type) -> std::optional<Type> {
    return {};
    return type;
  });

  // Configure conversion to lower out SCF operations.
  // Only scf.for / scf.if / scf.index_switch / scf.while are illegal; every
  // other op is marked legal, so this is a partial conversion.
  ConversionTarget target(getContext());
  target
      .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
return success()
b getContext())
BlockArgument getArgument(unsigned i)
Definition Block.h:129
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
iterator end()
Definition Region.h:56
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition EmitC.cpp:61
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
const FrozenRewritePatternSet & patterns
void registerConvertSCFToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override