MLIR 22.0.0git
InlinerExtension.cpp
Go to the documentation of this file.
1//===- InlinerExtension.cpp - Func Inliner Extension ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14
15using namespace mlir;
16using namespace mlir::func;
17
18//===----------------------------------------------------------------------===//
19// FuncDialect Interfaces
20//===----------------------------------------------------------------------===//
21namespace {
22/// This class defines the interface for handling inlining with func operations.
23struct FuncInlinerInterface : public DialectInlinerInterface {
25
26 //===--------------------------------------------------------------------===//
27 // Analysis Hooks
28 //===--------------------------------------------------------------------===//
29
30 /// Call operations can be inlined unless specified otherwise by attributes
31 /// on either the call or the callbale.
32 bool isLegalToInline(Operation *call, Operation *callable,
33 bool wouldBeCloned) const final {
34 auto callOp = dyn_cast<func::CallOp>(call);
35 auto funcOp = dyn_cast<func::FuncOp>(callable);
36 return !(callOp && callOp.getNoInline()) &&
37 !(funcOp && funcOp.getNoInline());
38 }
39
40 /// All operations can be inlined.
41 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
42 return true;
43 }
44
45 /// All function bodies can be inlined.
46 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
47 return true;
48 }
49
50 //===--------------------------------------------------------------------===//
51 // Transformation Hooks
52 //===--------------------------------------------------------------------===//
53
54 /// Handle the given inlined terminator by replacing it with a new operation
55 /// as necessary.
56 void handleTerminator(Operation *op, Block *newDest) const final {
57 // Only return needs to be handled here.
58 auto returnOp = dyn_cast<ReturnOp>(op);
59 if (!returnOp)
60 return;
61
62 // Replace the return with a branch to the dest.
63 OpBuilder builder(op);
64 cf::BranchOp::create(builder, op->getLoc(), newDest,
65 returnOp.getOperands());
66 op->erase();
67 }
68
69 /// Handle the given inlined terminator by replacing it with a new operation
70 /// as necessary.
71 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
72 // Only return needs to be handled here.
73 auto returnOp = cast<ReturnOp>(op);
74
75 // Replace the values directly with the return operands.
76 assert(returnOp.getNumOperands() == valuesToRepl.size());
77 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
78 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
79 }
80};
81} // namespace
82
83//===----------------------------------------------------------------------===//
84// Registration
85//===----------------------------------------------------------------------===//
86
88 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
89 dialect->addInterfaces<FuncInlinerInterface>();
90
91 // The inliner extension relies on the ControlFlow dialect.
92 ctx->getOrLoadDialect<cf::ControlFlowDialect>();
93 });
94}
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
void registerInlinerExtension(DialectRegistry &registry)
Register the extension used to support inlining the func dialect.
Include the generated interface declarations.