MLIR  20.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_SCFTOEMITC
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::scf;
33 
34 namespace {
35 
36 struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
37  void runOnOperation() override;
38 };
39 
40 // Lower scf::for to emitc::for, implementing result values using
41 // emitc::variable's updated within the loop body.
42 struct ForLowering : public OpRewritePattern<ForOp> {
44 
45  LogicalResult matchAndRewrite(ForOp forOp,
46  PatternRewriter &rewriter) const override;
47 };
48 
49 // Create an uninitialized emitc::variable op for each result of the given op.
50 template <typename T>
51 static SmallVector<Value> createVariablesForResults(T op,
52  PatternRewriter &rewriter) {
53  SmallVector<Value> resultVariables;
54 
55  if (!op.getNumResults())
56  return resultVariables;
57 
58  Location loc = op->getLoc();
59  MLIRContext *context = op.getContext();
60 
61  OpBuilder::InsertionGuard guard(rewriter);
62  rewriter.setInsertionPoint(op);
63 
64  for (OpResult result : op.getResults()) {
65  Type resultType = result.getType();
66  Type varType = emitc::LValueType::get(resultType);
67  emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
68  emitc::VariableOp var =
69  rewriter.create<emitc::VariableOp>(loc, varType, noInit);
70  resultVariables.push_back(var);
71  }
72 
73  return resultVariables;
74 }
75 
76 // Create a series of assign ops assigning given values to given variables at
77 // the current insertion point of given rewriter.
78 static void assignValues(ValueRange values, SmallVector<Value> &variables,
79  PatternRewriter &rewriter, Location loc) {
80  for (auto [value, var] : llvm::zip(values, variables))
81  rewriter.create<emitc::AssignOp>(loc, var, value);
82 }
83 
84 SmallVector<Value> loadValues(const SmallVector<Value> &variables,
85  PatternRewriter &rewriter, Location loc) {
86  return llvm::map_to_vector<>(variables, [&](Value var) {
87  Type type = cast<emitc::LValueType>(var.getType()).getValueType();
88  return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
89  });
90 }
91 
92 static void lowerYield(SmallVector<Value> &resultVariables,
93  PatternRewriter &rewriter, scf::YieldOp yield) {
94  Location loc = yield.getLoc();
95  ValueRange operands = yield.getOperands();
96 
97  OpBuilder::InsertionGuard guard(rewriter);
98  rewriter.setInsertionPoint(yield);
99 
100  assignValues(operands, resultVariables, rewriter, loc);
101 
102  rewriter.create<emitc::YieldOp>(loc);
103  rewriter.eraseOp(yield);
104 }
105 
106 // Lower the contents of an scf::if/scf::index_switch regions to an
107 // emitc::if/emitc::switch region. The contents of the lowering region is
108 // moved into the respective lowered region, but the scf::yield is replaced not
109 // only with an emitc::yield, but also with a sequence of emitc::assign ops that
110 // set the yielded values into the result variables.
111 static void lowerRegion(SmallVector<Value> &resultVariables,
112  PatternRewriter &rewriter, Region &region,
113  Region &loweredRegion) {
114  rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
115  Operation *terminator = loweredRegion.back().getTerminator();
116  lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
117 }
118 
119 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
120  PatternRewriter &rewriter) const {
121  Location loc = forOp.getLoc();
122 
123  // Create an emitc::variable op for each result. These variables will be
124  // assigned to by emitc::assign ops within the loop body.
125  SmallVector<Value> resultVariables =
126  createVariablesForResults(forOp, rewriter);
127 
128  assignValues(forOp.getInits(), resultVariables, rewriter, loc);
129 
130  emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
131  loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
132 
133  Block *loweredBody = loweredFor.getBody();
134 
135  // Erase the auto-generated terminator for the lowered for op.
136  rewriter.eraseOp(loweredBody->getTerminator());
137 
138  IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
139  rewriter.setInsertionPointToEnd(loweredBody);
140 
141  SmallVector<Value> iterArgsValues =
142  loadValues(resultVariables, rewriter, loc);
143 
144  rewriter.restoreInsertionPoint(ip);
145 
146  SmallVector<Value> replacingValues;
147  replacingValues.push_back(loweredFor.getInductionVar());
148  replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
149 
150  rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
151  lowerYield(resultVariables, rewriter,
152  cast<scf::YieldOp>(loweredBody->getTerminator()));
153 
154  // Load variables into SSA values after the for loop.
155  SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
156 
157  rewriter.replaceOp(forOp, resultValues);
158  return success();
159 }
160 
161 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
162 // updated within the then and else regions.
163 struct IfLowering : public OpRewritePattern<IfOp> {
165 
166  LogicalResult matchAndRewrite(IfOp ifOp,
167  PatternRewriter &rewriter) const override;
168 };
169 
170 } // namespace
171 
172 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
173  PatternRewriter &rewriter) const {
174  Location loc = ifOp.getLoc();
175 
176  // Create an emitc::variable op for each result. These variables will be
177  // assigned to by emitc::assign ops within the then & else regions.
178  SmallVector<Value> resultVariables =
179  createVariablesForResults(ifOp, rewriter);
180 
181  Region &thenRegion = ifOp.getThenRegion();
182  Region &elseRegion = ifOp.getElseRegion();
183 
184  bool hasElseBlock = !elseRegion.empty();
185 
186  auto loweredIf =
187  rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
188 
189  Region &loweredThenRegion = loweredIf.getThenRegion();
190  lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
191 
192  if (hasElseBlock) {
193  Region &loweredElseRegion = loweredIf.getElseRegion();
194  lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
195  }
196 
197  rewriter.setInsertionPointAfter(ifOp);
198  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
199 
200  rewriter.replaceOp(ifOp, results);
201  return success();
202 }
203 
204 // Lower scf::index_switch to emitc::switch, implementing result values as
205 // emitc::variable's updated within the case and default regions.
206 struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
208 
209  LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
210  PatternRewriter &rewriter) const override;
211 };
212 
213 LogicalResult
214 IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
215  PatternRewriter &rewriter) const {
216  Location loc = indexSwitchOp.getLoc();
217 
218  // Create an emitc::variable op for each result. These variables will be
219  // assigned to by emitc::assign ops within the case and default regions.
220  SmallVector<Value> resultVariables =
221  createVariablesForResults(indexSwitchOp, rewriter);
222 
223  auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
224  loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
225  indexSwitchOp.getNumCases());
226 
227  // Lowering all case regions.
228  for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
229  loweredSwitch.getCaseRegions())) {
230  lowerRegion(resultVariables, rewriter, std::get<0>(pair),
231  std::get<1>(pair));
232  }
233 
234  // Lowering default region.
235  lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
236  loweredSwitch.getDefaultRegion());
237 
238  rewriter.setInsertionPointAfter(indexSwitchOp);
239  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
240 
241  rewriter.replaceOp(indexSwitchOp, results);
242  return success();
243 }
244 
246  patterns.add<ForLowering>(patterns.getContext());
247  patterns.add<IfLowering>(patterns.getContext());
248  patterns.add<IndexSwitchOpLowering>(patterns.getContext());
249 }
250 
251 void SCFToEmitCPass::runOnOperation() {
252  RewritePatternSet patterns(&getContext());
254 
255  // Configure conversion to lower out SCF operations.
256  ConversionTarget target(getContext());
257  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
258  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
259  if (failed(
260  applyPartialConversion(getOperation(), target, std::move(patterns))))
261  signalPassFailure();
262 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:393
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:398
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Include the generated interface declarations.
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
Definition: SCFToEmitC.cpp:245
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, PatternRewriter &rewriter) const override
Definition: SCFToEmitC.cpp:214
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358