19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
24 #define DEBUG_TYPE "inlining"
35 auto it = mappedLocations.find(op->
getLoc());
36 if (it == mappedLocations.end()) {
38 it = mappedLocations.try_emplace(op->
getLoc(), newLoc).first;
42 for (
auto &block : inlinedBlocks)
43 block.walk(remapOpLoc);
51 operand.set(mappedOp);
53 for (
auto &block : inlinedBlocks)
54 block.walk(remapOperands);
62 bool wouldBeCloned)
const {
64 return handler->isLegalToInline(call, callable, wouldBeCloned);
72 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
80 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
86 return handler ? handler->shouldAnalyzeRecursively(op) :
true;
93 assert(handler &&
"expected valid dialect handler");
94 handler->handleTerminator(op, newDest);
102 assert(handler &&
"expected valid dialect handler");
103 handler->handleTerminator(op, valuesToRepl);
108 DictionaryAttr argumentAttrs)
const {
110 assert(handler &&
"expected valid dialect handler");
111 return handler->handleArgument(builder, call, callable, argument,
117 DictionaryAttr resultAttrs)
const {
119 assert(handler &&
"expected valid dialect handler");
120 return handler->handleResult(builder, call, callable, result, resultAttrs);
126 assert(handler &&
"expected valid dialect handler");
127 handler->processInlinedCallBlocks(call, inlinedBlocks);
132 Region *insertRegion,
bool shouldCloneInlinedRegion,
134 for (
auto &block : *src) {
135 for (
auto &op : block) {
138 shouldCloneInlinedRegion, valueMapping)) {
140 llvm::dbgs() <<
"* Illegal to inline because of op: ";
148 return !isLegalToInline(interface, ®ion, insertRegion,
149 shouldCloneInlinedRegion, valueMapping);
162 CallOpInterface call,
163 CallableOpInterface callable,
167 callable.getCallableRegion()->getNumArguments(),
169 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
170 assert(arrayAttr.size() == argAttrs.size());
172 argAttrs[idx] = cast<DictionaryAttr>(attr);
176 for (
auto [blockArg, argAttr] :
177 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
179 builder, call, callable, mapper.
lookup(blockArg), argAttr);
180 assert(newArgument.
getType() == mapper.
lookup(blockArg).getType() &&
181 "expected the argument type to not change");
184 mapper.
map(blockArg, newArgument);
189 CallOpInterface call, CallableOpInterface callable,
194 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
195 assert(arrayAttr.size() == resAttrs.size());
197 resAttrs[idx] = cast<DictionaryAttr>(attr);
202 for (
auto [result, resAttr] : llvm::zip(results, resAttrs)) {
206 resultUsers.insert(user);
209 interface.
handleResult(builder, call, callable, result, resAttr);
210 assert(newResult.
getType() == result.getType() &&
211 "expected the result type to not change");
214 result.replaceUsesWithIf(newResult, [&](
OpOperand &operand) {
215 return resultUsers.count(operand.
getOwner());
224 std::optional<Location> inlineLoc,
225 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
226 assert(resultsToReplace.size() == regionResultTypes.size());
232 auto *srcEntryBlock = &src->
front();
233 if (llvm::any_of(srcEntryBlock->getArguments(),
239 if (!interface.
isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
241 !
isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
246 OpBuilder builder(inlineBlock, inlinePoint);
247 auto callable = dyn_cast<CallableOpInterface>(src->
getParentOp());
248 if (call && callable)
255 if (shouldCloneInlinedRegion)
256 src->
cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
258 insertRegion->
getBlocks().splice(postInsertBlock->getIterator(),
263 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
264 postInsertBlock->getIterator());
269 if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
274 if (!shouldCloneInlinedRegion)
283 if (std::next(newBlocks.begin()) == newBlocks.end()) {
286 builder.setInsertionPoint(firstBlockTerminator);
287 if (call && callable)
293 firstBlockTerminator->
erase();
298 postInsertBlock->
erase();
303 resultToRepl.value().replaceAllUsesWith(
304 postInsertBlock->
addArgument(regionResultTypes[resultToRepl.index()],
305 resultToRepl.value().getLoc()));
309 builder.setInsertionPointToStart(postInsertBlock);
310 if (call && callable)
315 for (
auto &newBlock : newBlocks)
322 firstNewBlock->
erase();
329 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
330 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
335 auto *entryBlock = &src->
front();
336 if (inlinedOperands.size() != entryBlock->getNumArguments())
341 for (
unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
345 if (inlinedOperands[i].getType() != regionArg.
getType())
347 mapper.
map(regionArg, inlinedOperands[i]);
352 resultsToReplace, resultsToReplace.
getTypes(),
353 inlineLoc, shouldCloneInlinedRegion, call);
360 std::optional<Location> inlineLoc,
361 bool shouldCloneInlinedRegion) {
363 ++inlinePoint->getIterator(), mapper, resultsToReplace,
364 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
371 std::optional<Location> inlineLoc,
372 bool shouldCloneInlinedRegion) {
374 resultsToReplace, regionResultTypes, inlineLoc,
375 shouldCloneInlinedRegion);
382 std::optional<Location> inlineLoc,
383 bool shouldCloneInlinedRegion) {
385 ++inlinePoint->getIterator(), inlinedOperands,
386 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
393 std::optional<Location> inlineLoc,
394 bool shouldCloneInlinedRegion) {
396 inlinedOperands, resultsToReplace, inlineLoc,
397 shouldCloneInlinedRegion);
411 type, conversionLoc);
414 castOps.push_back(castOp);
429 CallOpInterface call,
430 CallableOpInterface callable,
Region *src,
431 bool shouldCloneInlinedRegion) {
435 auto *entryBlock = &src->
front();
442 if (callOperands.size() != entryBlock->getNumArguments() ||
443 callResults.size() != callableResultTypes.size())
449 castOps.reserve(callOperands.size() + callResults.size());
452 auto cleanupState = [&] {
453 for (
auto *op : castOps) {
463 const auto *callInterface = interface.
getInterfaceFor(call->getDialect());
467 for (
unsigned i = 0, e = callOperands.size(); i != e; ++i) {
469 Value operand = callOperands[i];
474 if (operand.
getType() != regionArgType) {
476 operand, regionArgType, castLoc)))
477 return cleanupState();
479 mapper.
map(regionArg, operand);
484 for (
unsigned i = 0, e = callResults.size(); i != e; ++i) {
485 Value callResult = callResults[i];
486 if (callResult.
getType() == callableResultTypes[i])
493 callResult.
getType(), castLoc);
495 return cleanupState();
501 if (!interface.
isLegalToInline(call, callable, shouldCloneInlinedRegion))
502 return cleanupState();
506 ++call->getIterator(), mapper, callResults,
507 callableResultTypes, call.getLoc(),
508 shouldCloneInlinedRegion, call)))
509 return cleanupState();
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, IRMapping &mapper)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, ValueRange results)
static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc, bool shouldCloneInlinedRegion, CallOpInterface call={})
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap locations from the inlined blocks with CallSiteLoc locations with the provided caller location.
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, CallOpInterface call, CallableOpInterface callable, IRMapping &mapper)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
This is the interface that must be implemented by the dialects of operations to be inlined.
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect and a callable region; if no conversion can be generated, nullptr is returned.
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This interface provides the hooks into the inlining interface.
virtual Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const
virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, DictionaryAttr argumentAttrs) const
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call the hook on the handler for the dialect that 'op' is registered to.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
result_type_iterator result_type_begin()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
BlockListType & getBlocks()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional< Location > inlineLoc=std::nullopt, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
This class represents an efficient way to signal success or failure.