19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
24 #define DEBUG_TYPE "inlining"
38 while (
auto nextCallSite = dyn_cast<CallSiteLoc>(lastCallee)) {
39 calleeInliningStack.push_back(nextCallSite);
40 lastCallee = nextCallSite.getCaller();
44 for (CallSiteLoc currentCallSite : reverse(calleeInliningStack))
58 auto [it, inserted] = mappedLocations.try_emplace(loc);
62 it->getSecond() = newLoc;
69 [&](
LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
73 for (
Block &block : inlinedBlocks) {
87 for (
auto &operand : op->getOpOperands())
89 operand.set(mappedOp);
91 for (
auto &block : inlinedBlocks)
92 block.walk(remapOperands);
100 bool wouldBeCloned)
const {
102 return handler->isLegalToInline(call, callable, wouldBeCloned);
110 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
118 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
124 return handler ? handler->shouldAnalyzeRecursively(op) :
true;
131 assert(handler &&
"expected valid dialect handler");
132 handler->handleTerminator(op, newDest);
140 assert(handler &&
"expected valid dialect handler");
141 handler->handleTerminator(op, valuesToRepl);
148 if (inlinedBlocks.empty()) {
152 assert(handler &&
"expected valid dialect handler");
153 return handler->allowSingleBlockOptimization(inlinedBlocks);
158 DictionaryAttr argumentAttrs)
const {
160 assert(handler &&
"expected valid dialect handler");
161 return handler->handleArgument(builder, call, callable, argument,
167 DictionaryAttr resultAttrs)
const {
169 assert(handler &&
"expected valid dialect handler");
170 return handler->handleResult(builder, call, callable, result, resultAttrs);
176 assert(handler &&
"expected valid dialect handler");
177 handler->processInlinedCallBlocks(call, inlinedBlocks);
182 Region *insertRegion,
bool shouldCloneInlinedRegion,
184 for (
auto &block : *src) {
185 for (
auto &op : block) {
188 shouldCloneInlinedRegion, valueMapping)) {
190 llvm::dbgs() <<
"* Illegal to inline because of op: ";
197 llvm::any_of(op.getRegions(), [&](
Region ®ion) {
198 return !isLegalToInline(interface, ®ion, insertRegion,
199 shouldCloneInlinedRegion, valueMapping);
212 CallOpInterface call,
213 CallableOpInterface callable,
217 callable.getCallableRegion()->getNumArguments(),
219 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
220 assert(arrayAttr.size() == argAttrs.size());
222 argAttrs[idx] = cast<DictionaryAttr>(attr);
226 for (
auto [blockArg, argAttr] :
227 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
229 builder, call, callable, mapper.
lookup(blockArg), argAttr);
230 assert(newArgument.
getType() == mapper.
lookup(blockArg).getType() &&
231 "expected the argument type to not change");
234 mapper.
map(blockArg, newArgument);
239 CallOpInterface call, CallableOpInterface callable,
244 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
245 assert(arrayAttr.size() == resAttrs.size());
247 resAttrs[idx] = cast<DictionaryAttr>(attr);
251 for (
auto [result, resAttr] : llvm::zip(results, resAttrs)) {
256 interface.
handleResult(builder, call, callable, result, resAttr);
257 assert(newResult.
getType() == result.getType() &&
258 "expected the result type to not change");
261 result.replaceUsesWithIf(newResult, [&](
OpOperand &operand) {
262 return resultUsers.count(operand.
getOwner());
272 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion,
273 CallOpInterface call = {}) {
274 assert(resultsToReplace.size() == regionResultTypes.size());
280 auto *srcEntryBlock = &src->
front();
281 if (llvm::any_of(srcEntryBlock->getArguments(),
287 if (!interface.
isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
289 !
isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
294 OpBuilder builder(inlineBlock, inlinePoint);
295 auto callable = dyn_cast<CallableOpInterface>(src->
getParentOp());
296 if (call && callable)
301 cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
302 shouldCloneInlinedRegion);
305 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
306 postInsertBlock->getIterator());
311 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
316 if (!shouldCloneInlinedRegion)
327 if (singleBlockFastPath && llvm::hasSingleElement(newBlocks)) {
330 builder.setInsertionPoint(firstBlockTerminator);
331 if (call && callable)
337 firstBlockTerminator->
erase();
342 postInsertBlock->
erase();
347 resultToRepl.value().replaceAllUsesWith(
348 postInsertBlock->
addArgument(regionResultTypes[resultToRepl.index()],
349 resultToRepl.value().getLoc()));
353 builder.setInsertionPointToStart(postInsertBlock);
354 if (call && callable)
359 for (
auto &newBlock : newBlocks)
366 firstNewBlock->
erase();
375 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion,
376 CallOpInterface call = {}) {
381 auto *entryBlock = &src->
front();
382 if (inlinedOperands.size() != entryBlock->getNumArguments())
387 for (
unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
393 mapper.
map(regionArg, inlinedOperands[i]);
398 inlinePoint, mapper, resultsToReplace,
399 resultsToReplace.
getTypes(), inlineLoc,
400 shouldCloneInlinedRegion, call);
408 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
410 ++inlinePoint->getIterator(), mapper, resultsToReplace,
411 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
419 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
421 interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
422 resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
429 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
430 bool shouldCloneInlinedRegion) {
432 ++inlinePoint->getIterator(), inlinedOperands,
433 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
441 std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
443 inlinePoint, inlinedOperands, resultsToReplace,
444 inlineLoc, shouldCloneInlinedRegion);
458 type, conversionLoc);
461 castOps.push_back(castOp);
478 CallOpInterface call, CallableOpInterface callable,
Region *src,
479 bool shouldCloneInlinedRegion) {
483 auto *entryBlock = &src->
front();
490 if (callOperands.size() != entryBlock->getNumArguments() ||
491 callResults.size() != callableResultTypes.size())
497 castOps.reserve(callOperands.size() + callResults.size());
500 auto cleanupState = [&] {
501 for (
auto *op : castOps) {
502 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
511 const auto *callInterface = interface.
getInterfaceFor(call->getDialect());
515 for (
unsigned i = 0, e = callOperands.size(); i != e; ++i) {
517 Value operand = callOperands[i];
522 if (operand.
getType() != regionArgType) {
524 operand, regionArgType, castLoc)))
525 return cleanupState();
527 mapper.
map(regionArg, operand);
532 for (
unsigned i = 0, e = callResults.size(); i != e; ++i) {
533 Value callResult = callResults[i];
534 if (callResult.
getType() == callableResultTypes[i])
541 callResult.
getType(), castLoc);
543 return cleanupState();
549 if (!interface.
isLegalToInline(call, callable, shouldCloneInlinedRegion))
550 return cleanupState();
553 if (failed(
inlineRegionImpl(interface, cloneCallback, src, call->getBlock(),
554 ++call->getIterator(), mapper, callResults,
555 callableResultTypes, call.getLoc(),
556 shouldCloneInlinedRegion, call)))
557 return cleanupState();
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, IRMapping &mapper)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, ValueRange results)
static LocationAttr stackLocations(Location callee, Location caller)
Combine callee location with caller location to create a stack that represents the call chain.
static LogicalResult inlineRegionImpl(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, Region *src, Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call={})
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a c...
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap all locations reachable from the inlined blocks with CallSiteLoc locations with the provided ca...
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, IRMapping &mapper)
This is an attribute/type replacer that is naively cached.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
This is the interface that must be implemented by the dialects of operations to be inlined.
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect,...
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This interface provides the hooks into the inlining interface.
virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const
virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, DictionaryAttr argumentAttrs) const
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual bool allowSingleBlockOptimization(iterator_range< Region::iterator > inlinedBlocks) const
Returns true if the inliner can assume a fast path of not creating a new block, if there is only one ...
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
result_type_iterator result_type_begin()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult inlineRegion(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, Region *src, Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc=std::nullopt, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.