20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/DebugLog.h"
22 #include "llvm/Support/raw_ostream.h"
25 #define DEBUG_TYPE "inlining"
39 while (
auto nextCallSite = dyn_cast<CallSiteLoc>(lastCallee)) {
40 calleeInliningStack.push_back(nextCallSite);
41 lastCallee = nextCallSite.getCaller();
45 for (CallSiteLoc currentCallSite : reverse(calleeInliningStack))
59 auto [it, inserted] = mappedLocations.try_emplace(loc);
63 it->getSecond() = newLoc;
70 [&](
LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
74 for (
Block &block : inlinedBlocks) {
88 for (
auto &operand : op->getOpOperands())
90 operand.set(mappedOp);
92 for (
auto &block : inlinedBlocks)
93 block.walk(remapOperands);
101 bool wouldBeCloned)
const {
103 return handler->isLegalToInline(call, callable, wouldBeCloned);
111 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
119 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
125 return handler ? handler->shouldAnalyzeRecursively(op) :
true;
132 assert(handler &&
"expected valid dialect handler");
133 handler->handleTerminator(op, newDest);
141 assert(handler &&
"expected valid dialect handler");
142 handler->handleTerminator(op, valuesToRepl);
149 if (inlinedBlocks.empty()) {
153 assert(handler &&
"expected valid dialect handler");
154 return handler->allowSingleBlockOptimization(inlinedBlocks);
159 DictionaryAttr argumentAttrs)
const {
161 assert(handler &&
"expected valid dialect handler");
162 return handler->handleArgument(builder, call, callable, argument,
168 DictionaryAttr resultAttrs)
const {
170 assert(handler &&
"expected valid dialect handler");
171 return handler->handleResult(builder, call, callable, result, resultAttrs);
177 assert(handler &&
"expected valid dialect handler");
178 handler->processInlinedCallBlocks(call, inlinedBlocks);
183 Region *insertRegion,
bool shouldCloneInlinedRegion,
185 for (
auto &block : *src) {
186 for (
auto &op : block) {
189 if (isa<UnrealizedConversionCastOp>(op))
194 shouldCloneInlinedRegion, valueMapping)) {
195 LDBG() <<
"* Illegal to inline because of op: "
201 llvm::any_of(op.getRegions(), [&](
Region &region) {
202 return !isLegalToInline(interface, &region, insertRegion,
203 shouldCloneInlinedRegion, valueMapping);
216 CallOpInterface call,
217 CallableOpInterface callable,
221 callable.getCallableRegion()->getNumArguments(),
223 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
224 assert(arrayAttr.size() == argAttrs.size());
226 argAttrs[idx] = cast<DictionaryAttr>(attr);
230 for (
auto [blockArg, argAttr] :
231 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
233 builder, call, callable, mapper.
lookup(blockArg), argAttr);
234 assert(newArgument.
getType() == mapper.
lookup(blockArg).getType() &&
235 "expected the argument type to not change");
238 mapper.
map(blockArg, newArgument);
243 CallOpInterface call, CallableOpInterface callable,
248 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
249 assert(arrayAttr.size() == resAttrs.size());
251 resAttrs[idx] = cast<DictionaryAttr>(attr);
255 for (
auto [result, resAttr] : llvm::zip(results, resAttrs)) {
260 interface.
handleResult(builder, call, callable, result, resAttr);
261 assert(newResult.
getType() == result.getType() &&
262 "expected the result type to not change");
265 result.replaceUsesWithIf(newResult, [&](
OpOperand &operand) {
266 return resultUsers.count(operand.
getOwner());
276 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion,
277 CallOpInterface call = {}) {
278 assert(resultsToReplace.size() == regionResultTypes.size());
284 auto *srcEntryBlock = &src->
front();
285 if (llvm::any_of(srcEntryBlock->getArguments(),
291 if (!interface.
isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
293 !
isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
298 OpBuilder builder(inlineBlock, inlinePoint);
299 auto callable = dyn_cast<CallableOpInterface>(src->
getParentOp());
300 if (call && callable)
305 cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
306 shouldCloneInlinedRegion);
309 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
310 postInsertBlock->getIterator());
315 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
320 if (!shouldCloneInlinedRegion)
331 if (singleBlockFastPath && llvm::hasSingleElement(newBlocks)) {
334 builder.setInsertionPoint(firstBlockTerminator);
335 if (call && callable)
341 firstBlockTerminator->
erase();
346 postInsertBlock->
erase();
351 resultToRepl.value().replaceAllUsesWith(
352 postInsertBlock->
addArgument(regionResultTypes[resultToRepl.index()],
353 resultToRepl.value().getLoc()));
357 builder.setInsertionPointToStart(postInsertBlock);
358 if (call && callable)
363 for (
auto &newBlock : newBlocks)
370 firstNewBlock->
erase();
379 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion,
380 CallOpInterface call = {}) {
385 auto *entryBlock = &src->
front();
386 if (inlinedOperands.size() != entryBlock->getNumArguments())
391 for (
unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
397 mapper.
map(regionArg, inlinedOperands[i]);
402 inlinePoint, mapper, resultsToReplace,
403 resultsToReplace.
getTypes(), inlineLoc,
404 shouldCloneInlinedRegion, call);
412 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
414 ++inlinePoint->getIterator(), mapper, resultsToReplace,
415 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
423 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
425 interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
426 resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
433 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
434 bool shouldCloneInlinedRegion) {
436 ++inlinePoint->getIterator(), inlinedOperands,
437 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
445 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
447 inlinePoint, inlinedOperands, resultsToReplace,
448 inlineLoc, shouldCloneInlinedRegion);
462 type, conversionLoc);
465 castOps.push_back(castOp);
482 CallOpInterface call, CallableOpInterface callable,
Region *src,
483 bool shouldCloneInlinedRegion) {
487 auto *entryBlock = &src->
front();
494 if (callOperands.size() != entryBlock->getNumArguments() ||
495 callResults.size() != callableResultTypes.size())
501 castOps.reserve(callOperands.size() + callResults.size());
504 auto cleanupState = [&] {
505 for (
auto *op : castOps) {
506 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
515 const auto *callInterface = interface.
getInterfaceFor(call->getDialect());
519 for (
unsigned i = 0, e = callOperands.size(); i != e; ++i) {
521 Value operand = callOperands[i];
526 if (operand.
getType() != regionArgType) {
528 operand, regionArgType, castLoc)))
529 return cleanupState();
531 mapper.
map(regionArg, operand);
536 for (
unsigned i = 0, e = callResults.size(); i != e; ++i) {
537 Value callResult = callResults[i];
538 if (callResult.
getType() == callableResultTypes[i])
545 callResult.
getType(), castLoc);
547 return cleanupState();
553 if (!interface.
isLegalToInline(call, callable, shouldCloneInlinedRegion))
554 return cleanupState();
558 ++call->getIterator(), mapper, callResults,
559 callableResultTypes, call.getLoc(),
560 shouldCloneInlinedRegion, call)))
561 return cleanupState();
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, IRMapping &mapper)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, ValueRange results)
static LocationAttr stackLocations(Location callee, Location caller)
Combine callee location with caller location to create a stack that represents the call chain.
static LogicalResult inlineRegionImpl(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, Region *src, Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call={})
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap all locations reachable from the inlined blocks with CallSiteLoc locations with the provided caller location.
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, IRMapping &mapper)
This is an attribute/type replacer that is naively cached.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
This is the interface that must be implemented by the dialects of operations to be inlined.
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect,...
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This interface provides the hooks into the inlining interface.
virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const
virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, DictionaryAttr argumentAttrs) const
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual bool allowSingleBlockOptimization(iterator_range< Region::iterator > inlinedBlocks) const
Returns true if the inliner can assume a fast path of not creating a new block, if there is only one ...
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
result_type_iterator result_type_begin()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult inlineRegion(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, Region *src, Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc=std::nullopt, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.