19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
24 #define DEBUG_TYPE "inlining"
35 auto [it, inserted] = mappedLocations.try_emplace(loc);
39 it->getSecond() = newLoc;
46 [&](
LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
50 for (
Block &block : inlinedBlocks) {
64 for (
auto &operand : op->getOpOperands())
66 operand.set(mappedOp);
68 for (
auto &block : inlinedBlocks)
69 block.walk(remapOperands);
77 bool wouldBeCloned)
const {
79 return handler->isLegalToInline(call, callable, wouldBeCloned);
87 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
95 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
101 return handler ? handler->shouldAnalyzeRecursively(op) :
true;
108 assert(handler &&
"expected valid dialect handler");
109 handler->handleTerminator(op, newDest);
117 assert(handler &&
"expected valid dialect handler");
118 handler->handleTerminator(op, valuesToRepl);
123 DictionaryAttr argumentAttrs)
const {
125 assert(handler &&
"expected valid dialect handler");
126 return handler->handleArgument(builder, call, callable, argument,
132 DictionaryAttr resultAttrs)
const {
134 assert(handler &&
"expected valid dialect handler");
135 return handler->handleResult(builder, call, callable, result, resultAttrs);
141 assert(handler &&
"expected valid dialect handler");
142 handler->processInlinedCallBlocks(call, inlinedBlocks);
147 Region *insertRegion,
bool shouldCloneInlinedRegion,
149 for (
auto &block : *src) {
150 for (
auto &op : block) {
153 shouldCloneInlinedRegion, valueMapping)) {
155 llvm::dbgs() <<
"* Illegal to inline because of op: ";
162 llvm::any_of(op.getRegions(), [&](
Region ®ion) {
163 return !isLegalToInline(interface, ®ion, insertRegion,
164 shouldCloneInlinedRegion, valueMapping);
177 CallOpInterface call,
178 CallableOpInterface callable,
182 callable.getCallableRegion()->getNumArguments(),
184 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
185 assert(arrayAttr.size() == argAttrs.size());
187 argAttrs[idx] = cast<DictionaryAttr>(attr);
191 for (
auto [blockArg, argAttr] :
192 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
194 builder, call, callable, mapper.
lookup(blockArg), argAttr);
195 assert(newArgument.
getType() == mapper.
lookup(blockArg).getType() &&
196 "expected the argument type to not change");
199 mapper.
map(blockArg, newArgument);
204 CallOpInterface call, CallableOpInterface callable,
209 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
210 assert(arrayAttr.size() == resAttrs.size());
212 resAttrs[idx] = cast<DictionaryAttr>(attr);
217 for (
auto [result, resAttr] : llvm::zip(results, resAttrs)) {
221 resultUsers.insert(user);
224 interface.
handleResult(builder, call, callable, result, resAttr);
225 assert(newResult.
getType() == result.getType() &&
226 "expected the result type to not change");
229 result.replaceUsesWithIf(newResult, [&](
OpOperand &operand) {
230 return resultUsers.count(operand.
getOwner());
239 std::optional<Location> inlineLoc,
240 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
241 assert(resultsToReplace.size() == regionResultTypes.size());
247 auto *srcEntryBlock = &src->
front();
248 if (llvm::any_of(srcEntryBlock->getArguments(),
254 if (!interface.
isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
256 !
isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
261 OpBuilder builder(inlineBlock, inlinePoint);
262 auto callable = dyn_cast<CallableOpInterface>(src->
getParentOp());
263 if (call && callable)
270 if (shouldCloneInlinedRegion)
271 src->
cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
273 insertRegion->
getBlocks().splice(postInsertBlock->getIterator(),
278 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
279 postInsertBlock->getIterator());
284 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
289 if (!shouldCloneInlinedRegion)
298 if (std::next(newBlocks.begin()) == newBlocks.end()) {
301 builder.setInsertionPoint(firstBlockTerminator);
302 if (call && callable)
308 firstBlockTerminator->
erase();
313 postInsertBlock->
erase();
318 resultToRepl.value().replaceAllUsesWith(
319 postInsertBlock->
addArgument(regionResultTypes[resultToRepl.index()],
320 resultToRepl.value().getLoc()));
324 builder.setInsertionPointToStart(postInsertBlock);
325 if (call && callable)
330 for (
auto &newBlock : newBlocks)
337 firstNewBlock->
erase();
344 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
345 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
350 auto *entryBlock = &src->
front();
351 if (inlinedOperands.size() != entryBlock->getNumArguments())
356 for (
unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
362 mapper.
map(regionArg, inlinedOperands[i]);
367 resultsToReplace, resultsToReplace.
getTypes(),
368 inlineLoc, shouldCloneInlinedRegion, call);
375 std::optional<Location> inlineLoc,
376 bool shouldCloneInlinedRegion) {
378 ++inlinePoint->getIterator(), mapper, resultsToReplace,
379 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
386 std::optional<Location> inlineLoc,
387 bool shouldCloneInlinedRegion) {
389 resultsToReplace, regionResultTypes, inlineLoc,
390 shouldCloneInlinedRegion);
397 std::optional<Location> inlineLoc,
398 bool shouldCloneInlinedRegion) {
400 ++inlinePoint->getIterator(), inlinedOperands,
401 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
408 std::optional<Location> inlineLoc,
409 bool shouldCloneInlinedRegion) {
411 inlinedOperands, resultsToReplace, inlineLoc,
412 shouldCloneInlinedRegion);
426 type, conversionLoc);
429 castOps.push_back(castOp);
444 CallOpInterface call,
445 CallableOpInterface callable,
Region *src,
446 bool shouldCloneInlinedRegion) {
450 auto *entryBlock = &src->
front();
457 if (callOperands.size() != entryBlock->getNumArguments() ||
458 callResults.size() != callableResultTypes.size())
464 castOps.reserve(callOperands.size() + callResults.size());
467 auto cleanupState = [&] {
468 for (
auto *op : castOps) {
469 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
478 const auto *callInterface = interface.
getInterfaceFor(call->getDialect());
482 for (
unsigned i = 0, e = callOperands.size(); i != e; ++i) {
484 Value operand = callOperands[i];
489 if (operand.
getType() != regionArgType) {
491 operand, regionArgType, castLoc)))
492 return cleanupState();
494 mapper.
map(regionArg, operand);
499 for (
unsigned i = 0, e = callResults.size(); i != e; ++i) {
500 Value callResult = callResults[i];
501 if (callResult.
getType() == callableResultTypes[i])
508 callResult.
getType(), castLoc);
510 return cleanupState();
516 if (!interface.
isLegalToInline(call, callable, shouldCloneInlinedRegion))
517 return cleanupState();
521 ++call->getIterator(), mapper, callResults,
522 callableResultTypes, call.getLoc(),
523 shouldCloneInlinedRegion, call)))
524 return cleanupState();
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, IRMapping &mapper)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, ValueRange results)
static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call={})
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a c...
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap all locations reachable from the inlined blocks with CallSiteLoc locations with the provided ca...
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, IRMapping &mapper)
This is an attribute/type replacer that is naively cached.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
This is the interface that must be implemented by the dialects of operations to be inlined.
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect,...
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This interface provides the hooks into the inlining interface.
virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const
virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, DictionaryAttr argumentAttrs) const
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
result_type_iterator result_type_begin()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
BlockListType & getBlocks()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc=std::nullopt, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...