MLIR  21.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.for, scf.if and scf.index_switch ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCFTOEMITC
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::scf;
34 
35 namespace {
36 
37 struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
38  void runOnOperation() override;
39 };
40 
41 // Lower scf::for to emitc::for, implementing result values using
42 // emitc::variable's updated within the loop body.
43 struct ForLowering : public OpConversionPattern<ForOp> {
45 
46  LogicalResult
47  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
48  ConversionPatternRewriter &rewriter) const override;
49 };
50 
51 // Create an uninitialized emitc::variable op for each result of the given op.
52 template <typename T>
53 static LogicalResult
54 createVariablesForResults(T op, const TypeConverter *typeConverter,
55  ConversionPatternRewriter &rewriter,
56  SmallVector<Value> &resultVariables) {
57  if (!op.getNumResults())
58  return success();
59 
60  Location loc = op->getLoc();
61  MLIRContext *context = op.getContext();
62 
63  OpBuilder::InsertionGuard guard(rewriter);
64  rewriter.setInsertionPoint(op);
65 
66  for (OpResult result : op.getResults()) {
67  Type resultType = typeConverter->convertType(result.getType());
68  if (!resultType)
69  return rewriter.notifyMatchFailure(op, "result type conversion failed");
70  Type varType = emitc::LValueType::get(resultType);
71  emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
72  emitc::VariableOp var =
73  rewriter.create<emitc::VariableOp>(loc, varType, noInit);
74  resultVariables.push_back(var);
75  }
76 
77  return success();
78 }
79 
80 // Create a series of assign ops assigning given values to given variables at
81 // the current insertion point of given rewriter.
82 static void assignValues(ValueRange values, ValueRange variables,
83  ConversionPatternRewriter &rewriter, Location loc) {
84  for (auto [value, var] : llvm::zip(values, variables))
85  rewriter.create<emitc::AssignOp>(loc, var, value);
86 }
87 
88 SmallVector<Value> loadValues(const SmallVector<Value> &variables,
89  PatternRewriter &rewriter, Location loc) {
90  return llvm::map_to_vector<>(variables, [&](Value var) {
91  Type type = cast<emitc::LValueType>(var.getType()).getValueType();
92  return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
93  });
94 }
95 
96 static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
97  ConversionPatternRewriter &rewriter,
98  scf::YieldOp yield) {
99  Location loc = yield.getLoc();
100 
101  OpBuilder::InsertionGuard guard(rewriter);
102  rewriter.setInsertionPoint(yield);
103 
104  SmallVector<Value> yieldOperands;
105  if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
106  return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
107  }
108 
109  assignValues(yieldOperands, resultVariables, rewriter, loc);
110 
111  rewriter.create<emitc::YieldOp>(loc);
112  rewriter.eraseOp(yield);
113 
114  return success();
115 }
116 
117 // Lower the contents of an scf::if/scf::index_switch regions to an
118 // emitc::if/emitc::switch region. The contents of the lowering region is
119 // moved into the respective lowered region, but the scf::yield is replaced not
120 // only with an emitc::yield, but also with a sequence of emitc::assign ops that
121 // set the yielded values into the result variables.
122 static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
123  ConversionPatternRewriter &rewriter,
124  Region &region, Region &loweredRegion) {
125  rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
126  Operation *terminator = loweredRegion.back().getTerminator();
127  return lowerYield(op, resultVariables, rewriter,
128  cast<scf::YieldOp>(terminator));
129 }
130 
131 LogicalResult
132 ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
133  ConversionPatternRewriter &rewriter) const {
134  Location loc = forOp.getLoc();
135 
136  // Create an emitc::variable op for each result. These variables will be
137  // assigned to by emitc::assign ops within the loop body.
138  SmallVector<Value> resultVariables;
139  if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
140  resultVariables)))
141  return rewriter.notifyMatchFailure(forOp,
142  "create variables for results failed");
143 
144  assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
145 
146  emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
147  loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
148 
149  Block *loweredBody = loweredFor.getBody();
150 
151  // Erase the auto-generated terminator for the lowered for op.
152  rewriter.eraseOp(loweredBody->getTerminator());
153 
154  IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
155  rewriter.setInsertionPointToEnd(loweredBody);
156 
157  SmallVector<Value> iterArgsValues =
158  loadValues(resultVariables, rewriter, loc);
159 
160  rewriter.restoreInsertionPoint(ip);
161 
162  // Convert the original region types into the new types by adding unrealized
163  // casts in the beginning of the loop. This performs the conversion in place.
164  if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
165  *getTypeConverter(), nullptr))) {
166  return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
167  }
168 
169  // Register the replacements for the block arguments and inline the body of
170  // the scf.for loop into the body of the emitc::for loop.
171  Block *scfBody = &(forOp.getRegion().front());
172  SmallVector<Value> replacingValues;
173  replacingValues.push_back(loweredFor.getInductionVar());
174  replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
175  rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
176 
177  auto result = lowerYield(forOp, resultVariables, rewriter,
178  cast<scf::YieldOp>(loweredBody->getTerminator()));
179 
180  if (failed(result)) {
181  return result;
182  }
183 
184  // Load variables into SSA values after the for loop.
185  SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
186 
187  rewriter.replaceOp(forOp, resultValues);
188  return success();
189 }
190 
191 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
192 // updated within the then and else regions.
193 struct IfLowering : public OpConversionPattern<IfOp> {
195 
196  LogicalResult
197  matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
198  ConversionPatternRewriter &rewriter) const override;
199 };
200 
201 } // namespace
202 
203 LogicalResult
204 IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
205  ConversionPatternRewriter &rewriter) const {
206  Location loc = ifOp.getLoc();
207 
208  // Create an emitc::variable op for each result. These variables will be
209  // assigned to by emitc::assign ops within the then & else regions.
210  SmallVector<Value> resultVariables;
211  if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
212  resultVariables)))
213  return rewriter.notifyMatchFailure(ifOp,
214  "create variables for results failed");
215 
216  // Utility function to lower the contents of an scf::if region to an emitc::if
217  // region. The contents of the scf::if regions is moved into the respective
218  // emitc::if regions, but the scf::yield is replaced not only with an
219  // emitc::yield, but also with a sequence of emitc::assign ops that set the
220  // yielded values into the result variables.
221  auto lowerRegion = [&resultVariables, &rewriter,
222  &ifOp](Region &region, Region &loweredRegion) {
223  rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
224  Operation *terminator = loweredRegion.back().getTerminator();
225  auto result = lowerYield(ifOp, resultVariables, rewriter,
226  cast<scf::YieldOp>(terminator));
227  if (failed(result)) {
228  return result;
229  }
230  return success();
231  };
232 
233  Region &thenRegion = adaptor.getThenRegion();
234  Region &elseRegion = adaptor.getElseRegion();
235 
236  bool hasElseBlock = !elseRegion.empty();
237 
238  auto loweredIf =
239  rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
240 
241  Region &loweredThenRegion = loweredIf.getThenRegion();
242  auto result = lowerRegion(thenRegion, loweredThenRegion);
243  if (failed(result)) {
244  return result;
245  }
246 
247  if (hasElseBlock) {
248  Region &loweredElseRegion = loweredIf.getElseRegion();
249  auto result = lowerRegion(elseRegion, loweredElseRegion);
250  if (failed(result)) {
251  return result;
252  }
253  }
254 
255  rewriter.setInsertionPointAfter(ifOp);
256  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
257 
258  rewriter.replaceOp(ifOp, results);
259  return success();
260 }
261 
262 // Lower scf::index_switch to emitc::switch, implementing result values as
263 // emitc::variable's updated within the case and default regions.
264 struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
266 
267  LogicalResult
268  matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269  ConversionPatternRewriter &rewriter) const override;
270 };
271 
273  IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
274  ConversionPatternRewriter &rewriter) const {
275  Location loc = indexSwitchOp.getLoc();
276 
277  // Create an emitc::variable op for each result. These variables will be
278  // assigned to by emitc::assign ops within the case and default regions.
279  SmallVector<Value> resultVariables;
280  if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
281  rewriter, resultVariables))) {
282  return rewriter.notifyMatchFailure(indexSwitchOp,
283  "create variables for results failed");
284  }
285 
286  auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
287  loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
288 
289  // Lowering all case regions.
290  for (auto pair :
291  llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
292  if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
293  *std::get<0>(pair), std::get<1>(pair)))) {
294  return failure();
295  }
296  }
297 
298  // Lowering default region.
299  if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
300  adaptor.getDefaultRegion(),
301  loweredSwitch.getDefaultRegion()))) {
302  return failure();
303  }
304 
305  rewriter.setInsertionPointAfter(indexSwitchOp);
306  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
307 
308  rewriter.replaceOp(indexSwitchOp, results);
309  return success();
310 }
311 
313  TypeConverter &typeConverter) {
314  patterns.add<ForLowering>(typeConverter, patterns.getContext());
315  patterns.add<IfLowering>(typeConverter, patterns.getContext());
316  patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
317 }
318 
319 void SCFToEmitCPass::runOnOperation() {
321  TypeConverter typeConverter;
322  // Fallback converter
323  // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324  // Type converters are called most to least recently inserted
325  typeConverter.addConversion([](Type t) { return t; });
326  populateEmitCSizeTTypeConversions(typeConverter);
328 
329  // Configure conversion to lower out SCF operations.
330  ConversionTarget target(getContext());
331  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
332  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
333  if (failed(
334  applyPartialConversion(getOperation(), target, std::move(patterns))))
335  signalPassFailure();
336 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the currently executing pattern.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:388
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:803
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:736
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
Definition: SCFToEmitC.cpp:312
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Definition: SCFToEmitC.cpp:272