MLIR  19.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.for and scf.if ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/Passes.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_SCFTOEMITC
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::scf;
33 
34 namespace {
35 
36 struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
37  void runOnOperation() override;
38 };
39 
40 // Lower scf::for to emitc::for, implementing result values using
41 // emitc::variable's updated within the loop body.
42 struct ForLowering : public OpRewritePattern<ForOp> {
44 
45  LogicalResult matchAndRewrite(ForOp forOp,
46  PatternRewriter &rewriter) const override;
47 };
48 
49 // Create an uninitialized emitc::variable op for each result of the given op.
50 template <typename T>
51 static SmallVector<Value> createVariablesForResults(T op,
52  PatternRewriter &rewriter) {
53  SmallVector<Value> resultVariables;
54 
55  if (!op.getNumResults())
56  return resultVariables;
57 
58  Location loc = op->getLoc();
59  MLIRContext *context = op.getContext();
60 
61  OpBuilder::InsertionGuard guard(rewriter);
62  rewriter.setInsertionPoint(op);
63 
64  for (OpResult result : op.getResults()) {
65  Type resultType = result.getType();
66  emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
67  emitc::VariableOp var =
68  rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69  resultVariables.push_back(var);
70  }
71 
72  return resultVariables;
73 }
74 
75 // Create a series of assign ops assigning given values to given variables at
76 // the current insertion point of given rewriter.
77 static void assignValues(ValueRange values, SmallVector<Value> &variables,
78  PatternRewriter &rewriter, Location loc) {
79  for (auto [value, var] : llvm::zip(values, variables))
80  rewriter.create<emitc::AssignOp>(loc, var, value);
81 }
82 
83 static void lowerYield(SmallVector<Value> &resultVariables,
84  PatternRewriter &rewriter, scf::YieldOp yield) {
85  Location loc = yield.getLoc();
86  ValueRange operands = yield.getOperands();
87 
88  OpBuilder::InsertionGuard guard(rewriter);
89  rewriter.setInsertionPoint(yield);
90 
91  assignValues(operands, resultVariables, rewriter, loc);
92 
93  rewriter.create<emitc::YieldOp>(loc);
94  rewriter.eraseOp(yield);
95 }
96 
97 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
98  PatternRewriter &rewriter) const {
99  Location loc = forOp.getLoc();
100 
101  // Create an emitc::variable op for each result. These variables will be
102  // assigned to by emitc::assign ops within the loop body.
103  SmallVector<Value> resultVariables =
104  createVariablesForResults(forOp, rewriter);
105  SmallVector<Value> iterArgsVariables =
106  createVariablesForResults(forOp, rewriter);
107 
108  assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
109 
110  emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
111  loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
112 
113  Block *loweredBody = loweredFor.getBody();
114 
115  // Erase the auto-generated terminator for the lowered for op.
116  rewriter.eraseOp(loweredBody->getTerminator());
117 
118  SmallVector<Value> replacingValues;
119  replacingValues.push_back(loweredFor.getInductionVar());
120  replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
121 
122  rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
123  lowerYield(iterArgsVariables, rewriter,
124  cast<scf::YieldOp>(loweredBody->getTerminator()));
125 
126  // Copy iterArgs into results after the for loop.
127  assignValues(iterArgsVariables, resultVariables, rewriter, loc);
128 
129  rewriter.replaceOp(forOp, resultVariables);
130  return success();
131 }
132 
133 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
134 // updated within the then and else regions.
135 struct IfLowering : public OpRewritePattern<IfOp> {
137 
138  LogicalResult matchAndRewrite(IfOp ifOp,
139  PatternRewriter &rewriter) const override;
140 };
141 
142 } // namespace
143 
144 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
145  PatternRewriter &rewriter) const {
146  Location loc = ifOp.getLoc();
147 
148  // Create an emitc::variable op for each result. These variables will be
149  // assigned to by emitc::assign ops within the then & else regions.
150  SmallVector<Value> resultVariables =
151  createVariablesForResults(ifOp, rewriter);
152 
153  // Utility function to lower the contents of an scf::if region to an emitc::if
154  // region. The contents of the scf::if regions is moved into the respective
155  // emitc::if regions, but the scf::yield is replaced not only with an
156  // emitc::yield, but also with a sequence of emitc::assign ops that set the
157  // yielded values into the result variables.
158  auto lowerRegion = [&resultVariables, &rewriter](Region &region,
159  Region &loweredRegion) {
160  rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
161  Operation *terminator = loweredRegion.back().getTerminator();
162  lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
163  };
164 
165  Region &thenRegion = ifOp.getThenRegion();
166  Region &elseRegion = ifOp.getElseRegion();
167 
168  bool hasElseBlock = !elseRegion.empty();
169 
170  auto loweredIf =
171  rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
172 
173  Region &loweredThenRegion = loweredIf.getThenRegion();
174  lowerRegion(thenRegion, loweredThenRegion);
175 
176  if (hasElseBlock) {
177  Region &loweredElseRegion = loweredIf.getElseRegion();
178  lowerRegion(elseRegion, loweredElseRegion);
179  }
180 
181  rewriter.replaceOp(ifOp, resultVariables);
182  return success();
183 }
184 
186  patterns.add<ForLowering>(patterns.getContext());
187  patterns.add<IfLowering>(patterns.getContext());
188 }
189 
190 void SCFToEmitCPass::runOnOperation() {
191  RewritePatternSet patterns(&getContext());
193 
194  // Configure conversion to lower out SCF operations.
195  ConversionTarget target(getContext());
196  target.addIllegalOp<scf::ForOp, scf::IfOp>();
197  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
198  if (failed(
199  applyPartialConversion(getOperation(), target, std::move(patterns))))
200  signalPassFailure();
201 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:30
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Include the generated interface declarations.
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
Definition: SCFToEmitC.cpp:185
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358