19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
24 #define DEBUG_TYPE "inlining"
38 while (
auto nextCallSite = dyn_cast<CallSiteLoc>(lastCallee)) {
39 calleeInliningStack.push_back(nextCallSite);
40 lastCallee = nextCallSite.getCaller();
44 for (CallSiteLoc currentCallSite : reverse(calleeInliningStack))
58 auto [it, inserted] = mappedLocations.try_emplace(loc);
62 it->getSecond() = newLoc;
69 [&](
LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
73 for (
Block &block : inlinedBlocks) {
87 for (
auto &operand : op->getOpOperands())
89 operand.set(mappedOp);
91 for (
auto &block : inlinedBlocks)
92 block.walk(remapOperands);
100 bool wouldBeCloned)
const {
102 return handler->isLegalToInline(call, callable, wouldBeCloned);
110 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
118 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
124 return handler ? handler->shouldAnalyzeRecursively(op) :
true;
131 assert(handler &&
"expected valid dialect handler");
132 handler->handleTerminator(op, newDest);
140 assert(handler &&
"expected valid dialect handler");
141 handler->handleTerminator(op, valuesToRepl);
148 if (inlinedBlocks.empty()) {
152 assert(handler &&
"expected valid dialect handler");
153 return handler->allowSingleBlockOptimization(inlinedBlocks);
158 DictionaryAttr argumentAttrs)
const {
160 assert(handler &&
"expected valid dialect handler");
161 return handler->handleArgument(builder, call, callable, argument,
167 DictionaryAttr resultAttrs)
const {
169 assert(handler &&
"expected valid dialect handler");
170 return handler->handleResult(builder, call, callable, result, resultAttrs);
176 assert(handler &&
"expected valid dialect handler");
177 handler->processInlinedCallBlocks(call, inlinedBlocks);
182 Region *insertRegion,
bool shouldCloneInlinedRegion,
184 for (
auto &block : *src) {
185 for (
auto &op : block) {
188 shouldCloneInlinedRegion, valueMapping)) {
190 llvm::dbgs() <<
"* Illegal to inline because of op: ";
197 llvm::any_of(op.getRegions(), [&](
Region ®ion) {
198 return !isLegalToInline(interface, ®ion, insertRegion,
199 shouldCloneInlinedRegion, valueMapping);
212 CallOpInterface call,
213 CallableOpInterface callable,
217 callable.getCallableRegion()->getNumArguments(),
219 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
220 assert(arrayAttr.size() == argAttrs.size());
222 argAttrs[idx] = cast<DictionaryAttr>(attr);
226 for (
auto [blockArg, argAttr] :
227 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
229 builder, call, callable, mapper.
lookup(blockArg), argAttr);
230 assert(newArgument.
getType() == mapper.
lookup(blockArg).getType() &&
231 "expected the argument type to not change");
234 mapper.
map(blockArg, newArgument);
239 CallOpInterface call, CallableOpInterface callable,
244 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
245 assert(arrayAttr.size() == resAttrs.size());
247 resAttrs[idx] = cast<DictionaryAttr>(attr);
252 for (
auto [result, resAttr] : llvm::zip(results, resAttrs)) {
257 interface.
handleResult(builder, call, callable, result, resAttr);
258 assert(newResult.
getType() == result.getType() &&
259 "expected the result type to not change");
262 result.replaceUsesWithIf(newResult, [&](
OpOperand &operand) {
263 return resultUsers.count(operand.
getOwner());
272 std::optional<Location> inlineLoc,
273 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
274 assert(resultsToReplace.size() == regionResultTypes.size());
280 auto *srcEntryBlock = &src->
front();
281 if (llvm::any_of(srcEntryBlock->getArguments(),
287 if (!interface.
isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
289 !
isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
294 OpBuilder builder(inlineBlock, inlinePoint);
295 auto callable = dyn_cast<CallableOpInterface>(src->
getParentOp());
296 if (call && callable)
303 if (shouldCloneInlinedRegion)
304 src->
cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
306 insertRegion->
getBlocks().splice(postInsertBlock->getIterator(),
311 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
312 postInsertBlock->getIterator());
317 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
322 if (!shouldCloneInlinedRegion)
333 if (singleBlockFastPath && std::next(newBlocks.begin()) == newBlocks.end()) {
336 builder.setInsertionPoint(firstBlockTerminator);
337 if (call && callable)
343 firstBlockTerminator->
erase();
348 postInsertBlock->
erase();
353 resultToRepl.value().replaceAllUsesWith(
354 postInsertBlock->
addArgument(regionResultTypes[resultToRepl.index()],
355 resultToRepl.value().getLoc()));
359 builder.setInsertionPointToStart(postInsertBlock);
360 if (call && callable)
365 for (
auto &newBlock : newBlocks)
372 firstNewBlock->
erase();
379 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
380 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
385 auto *entryBlock = &src->
front();
386 if (inlinedOperands.size() != entryBlock->getNumArguments())
391 for (
unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
397 mapper.
map(regionArg, inlinedOperands[i]);
402 resultsToReplace, resultsToReplace.
getTypes(),
403 inlineLoc, shouldCloneInlinedRegion, call);
410 std::optional<Location> inlineLoc,
411 bool shouldCloneInlinedRegion) {
413 ++inlinePoint->getIterator(), mapper, resultsToReplace,
414 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
421 std::optional<Location> inlineLoc,
422 bool shouldCloneInlinedRegion) {
424 resultsToReplace, regionResultTypes, inlineLoc,
425 shouldCloneInlinedRegion);
432 std::optional<Location> inlineLoc,
433 bool shouldCloneInlinedRegion) {
435 ++inlinePoint->getIterator(), inlinedOperands,
436 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
443 std::optional<Location> inlineLoc,
444 bool shouldCloneInlinedRegion) {
446 inlinedOperands, resultsToReplace, inlineLoc,
447 shouldCloneInlinedRegion);
461 type, conversionLoc);
464 castOps.push_back(castOp);
479 CallOpInterface call,
480 CallableOpInterface callable,
Region *src,
481 bool shouldCloneInlinedRegion) {
485 auto *entryBlock = &src->
front();
492 if (callOperands.size() != entryBlock->getNumArguments() ||
493 callResults.size() != callableResultTypes.size())
499 castOps.reserve(callOperands.size() + callResults.size());
502 auto cleanupState = [&] {
503 for (
auto *op : castOps) {
504 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
513 const auto *callInterface = interface.
getInterfaceFor(call->getDialect());
517 for (
unsigned i = 0, e = callOperands.size(); i != e; ++i) {
519 Value operand = callOperands[i];
524 if (operand.
getType() != regionArgType) {
526 operand, regionArgType, castLoc)))
527 return cleanupState();
529 mapper.
map(regionArg, operand);
534 for (
unsigned i = 0, e = callResults.size(); i != e; ++i) {
535 Value callResult = callResults[i];
536 if (callResult.
getType() == callableResultTypes[i])
543 callResult.
getType(), castLoc);
545 return cleanupState();
551 if (!interface.
isLegalToInline(call, callable, shouldCloneInlinedRegion))
552 return cleanupState();
556 ++call->getIterator(), mapper, callResults,
557 callableResultTypes, call.getLoc(),
558 shouldCloneInlinedRegion, call)))
559 return cleanupState();
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, IRMapping &mapper)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, ValueRange results)
static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call={})
static LocationAttr stackLocations(Location callee, Location caller)
Combine callee location with caller location to create a stack that represents the call chain.
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a c...
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap all locations reachable from the inlined blocks with CallSiteLoc locations with the provided ca...
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, IRMapping &mapper)
This is an attribute/type replacer that is naively cached.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
This is the interface that must be implemented by the dialects of operations to be inlined.
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect,...
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This interface provides the hooks into the inlining interface.
virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const
virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, DictionaryAttr argumentAttrs) const
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual bool allowSingleBlockOptimization(iterator_range< Region::iterator > inlinedBlocks) const
Returns true if the inliner can assume a fast path of not creating a new block, if there is only one ...
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
result_type_iterator result_type_begin()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
BlockListType & getBlocks()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc=std::nullopt, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...