MLIR  21.0.0git
SCFToEmitC.cpp
Go to the documentation of this file.
1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert SCF ops (scf.for, scf.if and
// scf.index_switch) into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Transforms/Passes.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_SCFTOEMITC
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 using namespace mlir::scf;
35 
36 namespace {
37 
38 /// Implement the interface to convert SCF to EmitC.
39 struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
41 
42  /// Hook for derived dialect interface to provide conversion patterns
43  /// and mark dialect legal for the conversion target.
44  void populateConvertToEmitCConversionPatterns(
45  ConversionTarget &target, TypeConverter &typeConverter,
46  RewritePatternSet &patterns) const final {
47  populateEmitCSizeTTypeConversions(typeConverter);
49  }
50 };
51 } // namespace
52 
54  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
55  dialect->addInterfaces<SCFToEmitCDialectInterface>();
56  });
57 }
58 
59 namespace {
60 
61 struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
62  void runOnOperation() override;
63 };
64 
65 // Lower scf::for to emitc::for, implementing result values using
66 // emitc::variable's updated within the loop body.
67 struct ForLowering : public OpConversionPattern<ForOp> {
69 
70  LogicalResult
71  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override;
73 };
74 
75 // Create an uninitialized emitc::variable op for each result of the given op.
76 template <typename T>
77 static LogicalResult
78 createVariablesForResults(T op, const TypeConverter *typeConverter,
79  ConversionPatternRewriter &rewriter,
80  SmallVector<Value> &resultVariables) {
81  if (!op.getNumResults())
82  return success();
83 
84  Location loc = op->getLoc();
85  MLIRContext *context = op.getContext();
86 
87  OpBuilder::InsertionGuard guard(rewriter);
88  rewriter.setInsertionPoint(op);
89 
90  for (OpResult result : op.getResults()) {
91  Type resultType = typeConverter->convertType(result.getType());
92  if (!resultType)
93  return rewriter.notifyMatchFailure(op, "result type conversion failed");
94  Type varType = emitc::LValueType::get(resultType);
95  emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
96  emitc::VariableOp var =
97  rewriter.create<emitc::VariableOp>(loc, varType, noInit);
98  resultVariables.push_back(var);
99  }
100 
101  return success();
102 }
103 
104 // Create a series of assign ops assigning given values to given variables at
105 // the current insertion point of given rewriter.
106 static void assignValues(ValueRange values, ValueRange variables,
107  ConversionPatternRewriter &rewriter, Location loc) {
108  for (auto [value, var] : llvm::zip(values, variables))
109  rewriter.create<emitc::AssignOp>(loc, var, value);
110 }
111 
112 SmallVector<Value> loadValues(const SmallVector<Value> &variables,
113  PatternRewriter &rewriter, Location loc) {
114  return llvm::map_to_vector<>(variables, [&](Value var) {
115  Type type = cast<emitc::LValueType>(var.getType()).getValueType();
116  return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
117  });
118 }
119 
120 static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
121  ConversionPatternRewriter &rewriter,
122  scf::YieldOp yield) {
123  Location loc = yield.getLoc();
124 
125  OpBuilder::InsertionGuard guard(rewriter);
126  rewriter.setInsertionPoint(yield);
127 
128  SmallVector<Value> yieldOperands;
129  if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
130  return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
131  }
132 
133  assignValues(yieldOperands, resultVariables, rewriter, loc);
134 
135  rewriter.create<emitc::YieldOp>(loc);
136  rewriter.eraseOp(yield);
137 
138  return success();
139 }
140 
141 // Lower the contents of an scf::if/scf::index_switch regions to an
142 // emitc::if/emitc::switch region. The contents of the lowering region is
143 // moved into the respective lowered region, but the scf::yield is replaced not
144 // only with an emitc::yield, but also with a sequence of emitc::assign ops that
145 // set the yielded values into the result variables.
146 static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
147  ConversionPatternRewriter &rewriter,
148  Region &region, Region &loweredRegion) {
149  rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
150  Operation *terminator = loweredRegion.back().getTerminator();
151  return lowerYield(op, resultVariables, rewriter,
152  cast<scf::YieldOp>(terminator));
153 }
154 
155 LogicalResult
156 ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
157  ConversionPatternRewriter &rewriter) const {
158  Location loc = forOp.getLoc();
159 
160  // Create an emitc::variable op for each result. These variables will be
161  // assigned to by emitc::assign ops within the loop body.
162  SmallVector<Value> resultVariables;
163  if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
164  resultVariables)))
165  return rewriter.notifyMatchFailure(forOp,
166  "create variables for results failed");
167 
168  assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
169 
170  emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
171  loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
172 
173  Block *loweredBody = loweredFor.getBody();
174 
175  // Erase the auto-generated terminator for the lowered for op.
176  rewriter.eraseOp(loweredBody->getTerminator());
177 
178  IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
179  rewriter.setInsertionPointToEnd(loweredBody);
180 
181  SmallVector<Value> iterArgsValues =
182  loadValues(resultVariables, rewriter, loc);
183 
184  rewriter.restoreInsertionPoint(ip);
185 
186  // Convert the original region types into the new types by adding unrealized
187  // casts in the beginning of the loop. This performs the conversion in place.
188  if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
189  *getTypeConverter(), nullptr))) {
190  return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
191  }
192 
193  // Register the replacements for the block arguments and inline the body of
194  // the scf.for loop into the body of the emitc::for loop.
195  Block *scfBody = &(forOp.getRegion().front());
196  SmallVector<Value> replacingValues;
197  replacingValues.push_back(loweredFor.getInductionVar());
198  replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
199  rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
200 
201  auto result = lowerYield(forOp, resultVariables, rewriter,
202  cast<scf::YieldOp>(loweredBody->getTerminator()));
203 
204  if (failed(result)) {
205  return result;
206  }
207 
208  // Load variables into SSA values after the for loop.
209  SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
210 
211  rewriter.replaceOp(forOp, resultValues);
212  return success();
213 }
214 
215 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
216 // updated within the then and else regions.
217 struct IfLowering : public OpConversionPattern<IfOp> {
219 
220  LogicalResult
221  matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
222  ConversionPatternRewriter &rewriter) const override;
223 };
224 
225 } // namespace
226 
227 LogicalResult
228 IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
229  ConversionPatternRewriter &rewriter) const {
230  Location loc = ifOp.getLoc();
231 
232  // Create an emitc::variable op for each result. These variables will be
233  // assigned to by emitc::assign ops within the then & else regions.
234  SmallVector<Value> resultVariables;
235  if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
236  resultVariables)))
237  return rewriter.notifyMatchFailure(ifOp,
238  "create variables for results failed");
239 
240  // Utility function to lower the contents of an scf::if region to an emitc::if
241  // region. The contents of the scf::if regions is moved into the respective
242  // emitc::if regions, but the scf::yield is replaced not only with an
243  // emitc::yield, but also with a sequence of emitc::assign ops that set the
244  // yielded values into the result variables.
245  auto lowerRegion = [&resultVariables, &rewriter,
246  &ifOp](Region &region, Region &loweredRegion) {
247  rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
248  Operation *terminator = loweredRegion.back().getTerminator();
249  auto result = lowerYield(ifOp, resultVariables, rewriter,
250  cast<scf::YieldOp>(terminator));
251  if (failed(result)) {
252  return result;
253  }
254  return success();
255  };
256 
257  Region &thenRegion = adaptor.getThenRegion();
258  Region &elseRegion = adaptor.getElseRegion();
259 
260  bool hasElseBlock = !elseRegion.empty();
261 
262  auto loweredIf =
263  rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
264 
265  Region &loweredThenRegion = loweredIf.getThenRegion();
266  auto result = lowerRegion(thenRegion, loweredThenRegion);
267  if (failed(result)) {
268  return result;
269  }
270 
271  if (hasElseBlock) {
272  Region &loweredElseRegion = loweredIf.getElseRegion();
273  auto result = lowerRegion(elseRegion, loweredElseRegion);
274  if (failed(result)) {
275  return result;
276  }
277  }
278 
279  rewriter.setInsertionPointAfter(ifOp);
280  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
281 
282  rewriter.replaceOp(ifOp, results);
283  return success();
284 }
285 
286 // Lower scf::index_switch to emitc::switch, implementing result values as
287 // emitc::variable's updated within the case and default regions.
288 struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
290 
291  LogicalResult
292  matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
293  ConversionPatternRewriter &rewriter) const override;
294 };
295 
297  IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
298  ConversionPatternRewriter &rewriter) const {
299  Location loc = indexSwitchOp.getLoc();
300 
301  // Create an emitc::variable op for each result. These variables will be
302  // assigned to by emitc::assign ops within the case and default regions.
303  SmallVector<Value> resultVariables;
304  if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
305  rewriter, resultVariables))) {
306  return rewriter.notifyMatchFailure(indexSwitchOp,
307  "create variables for results failed");
308  }
309 
310  auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
311  loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
312 
313  // Lowering all case regions.
314  for (auto pair :
315  llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
316  if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
317  *std::get<0>(pair), std::get<1>(pair)))) {
318  return failure();
319  }
320  }
321 
322  // Lowering default region.
323  if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
324  adaptor.getDefaultRegion(),
325  loweredSwitch.getDefaultRegion()))) {
326  return failure();
327  }
328 
329  rewriter.setInsertionPointAfter(indexSwitchOp);
330  SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
331 
332  rewriter.replaceOp(indexSwitchOp, results);
333  return success();
334 }
335 
337  TypeConverter &typeConverter) {
338  patterns.add<ForLowering>(typeConverter, patterns.getContext());
339  patterns.add<IfLowering>(typeConverter, patterns.getContext());
340  patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
341 }
342 
343 void SCFToEmitCPass::runOnOperation() {
345  TypeConverter typeConverter;
346  // Fallback for other types.
347  typeConverter.addConversion([](Type type) -> std::optional<Type> {
348  if (!emitc::isSupportedEmitCType(type))
349  return {};
350  return type;
351  });
352  populateEmitCSizeTTypeConversions(typeConverter);
354 
355  // Configure conversion to lower out SCF operations.
356  ConversionTarget target(getContext());
357  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
358  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
359  if (failed(
360  applyPartialConversion(getOperation(), target, std::move(patterns))))
361  signalPassFailure();
362 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:388
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This is a value defined by a result of an operation.
Definition: Value.h:433
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition: EmitC.cpp:62
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
Collect a set of patterns to convert SCF operations to the EmitC dialect.
Definition: SCFToEmitC.cpp:336
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertSCFToEmitCInterface(DialectRegistry &registry)
Definition: SCFToEmitC.cpp:53
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Definition: SCFToEmitC.cpp:296