19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/Support/Debug.h"
22 #define DEBUG_TYPE "llvm-inliner"
31 allocaOp->getUsers().end());
32 while (!stack.empty()) {
34 if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
36 if (isa<LLVM::BitcastOp>(op))
60 Block *callerEntryBlock =
nullptr;
74 Block *calleeEntryBlock = &(*inlinedBlocks.begin());
75 if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
79 bool shouldInsertLifetimes =
false;
80 bool hasDynamicAlloca =
false;
84 for (
auto allocaOp : calleeEntryBlock->
getOps<LLVM::AllocaOp>()) {
85 IntegerAttr arraySize;
87 hasDynamicAlloca =
true;
90 bool shouldInsertLifetime =
92 shouldInsertLifetimes |= shouldInsertLifetime;
93 allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
96 for (
Block &block : llvm::drop_begin(inlinedBlocks)) {
100 llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](
auto allocaOp) {
101 return !matchPattern(allocaOp.getArraySize(), m_Constant());
104 if (allocasToMove.empty() && !hasDynamicAlloca)
108 if (hasDynamicAlloca) {
113 stackPtr = builder.
create<LLVM::StackSaveOp>(
117 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
118 auto newConstant = builder.
create<LLVM::ConstantOp>(
119 allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
121 if (shouldInsertLifetime) {
124 builder.
create<LLVM::LifetimeStartOp>(
125 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
126 allocaOp.getResult());
129 allocaOp.getArraySizeMutable().assign(newConstant.getResult());
131 if (!shouldInsertLifetimes && !hasDynamicAlloca)
134 for (
Block &block : inlinedBlocks) {
138 if (hasDynamicAlloca)
139 builder.
create<LLVM::StackRestoreOp>(call->
getLoc(), stackPtr);
140 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
141 if (shouldInsertLifetime)
142 builder.
create<LLVM::LifetimeEndOp>(
143 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
144 allocaOp.getResult());
164 walker.
addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
166 domainAttr.getContext(), domainAttr.getDescription());
169 walker.
addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
171 cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
172 scopeAttr.getDescription());
176 auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
181 walker.
walk(arrayAttr);
184 llvm::map_to_vector(arrayAttr, [&](
Attribute attr) {
185 return mapping.lookup(attr);
189 for (
Block &block : inlinedBlocks) {
191 if (
auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
192 aliasInterface.setAliasScopes(
193 convertScopeList(aliasInterface.getAliasScopesOrNull()));
194 aliasInterface.setNoAliasScopes(
195 convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
198 if (
auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
200 walker.
walk(noAliasScope.getScopeAttr());
202 noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
203 mapping.lookup(noAliasScope.getScopeAttr())));
219 llvm::append_range(result, lhs);
220 llvm::append_range(result, rhs);
230 if (
auto gepOp = pointerValue.
getDefiningOp<LLVM::GEPOp>()) {
231 pointerValue = gepOp.getBase();
235 if (
auto addrCast = pointerValue.
getDefiningOp<LLVM::AddrSpaceCastOp>()) {
236 pointerValue = addrCast.getOperand();
256 Value current = workList.pop_back_val();
259 if (!seen.insert(current).second)
262 if (
auto selectOp = current.
getDefiningOp<LLVM::SelectOp>()) {
263 workList.push_back(selectOp.getTrueValue());
264 workList.push_back(selectOp.getFalseValue());
268 if (
auto blockArg = dyn_cast<BlockArgument>(current)) {
269 Block *parentBlock = blockArg.getParentBlock();
275 bool anyUnknown =
false;
277 iter != parentBlock->
pred_end(); iter++) {
278 auto branch = dyn_cast<BranchOpInterface>((*iter)->getTerminator());
280 result.push_back(blockArg);
285 Value operand = branch.getSuccessorOperands(
286 iter.getSuccessorIndex())[blockArg.getArgNumber()];
288 result.push_back(blockArg);
293 operands.push_back(operand);
297 llvm::append_range(workList, operands);
302 result.push_back(current);
303 }
while (!workList.empty());
320 for (
Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
322 auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
325 if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
328 noAliasParams.insert(ssaCopy);
333 if (noAliasParams.empty())
338 auto exit = llvm::make_scope_exit([&] {
339 for (LLVM::SSACopyOp ssaCopyOp : noAliasParams) {
340 ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
348 call->
getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
350 for (LLVM::SSACopyOp copyOp : noAliasParams) {
352 pointerScopes[copyOp] = scope;
359 for (
Block &inlinedBlock : inlinedBlocks) {
360 inlinedBlock.
walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
366 for (
Value pointer : pointerArgs)
368 std::inserter(basedOnPointers, basedOnPointers.begin()));
370 bool aliasesOtherKnownObject =
false;
380 if (llvm::any_of(basedOnPointers, [&](
Value object) {
384 if (noAliasParams.contains(
object.getDefiningOp<LLVM::SSACopyOp>()))
389 if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
390 object.getDefiningOp())) {
391 aliasesOtherKnownObject =
true;
401 for (LLVM::SSACopyOp noAlias : noAliasParams) {
402 if (basedOnPointers.contains(noAlias))
405 noAliasScopes.push_back(pointerScopes[noAlias]);
408 if (!noAliasScopes.empty())
409 aliasInterface.setNoAliasScopes(
437 if (aliasesOtherKnownObject ||
438 isa<LLVM::CallOp>(aliasInterface.getOperation()))
442 for (LLVM::SSACopyOp noAlias : noAliasParams)
443 if (basedOnPointers.contains(noAlias))
444 aliasScopes.push_back(pointerScopes[noAlias]);
446 if (!aliasScopes.empty())
447 aliasInterface.setAliasScopes(
459 auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
460 if (!callAliasInterface)
463 ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
464 ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
467 if (!aliasScopes && !noAliasScopes)
472 for (
Block &block : inlinedBlocks) {
473 block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
476 aliasInterface.getAliasScopesOrNull(), aliasScopes));
480 aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
497 auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
498 if (!callAccessGroupInterface)
501 auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
507 for (
Block &block : inlinedBlocks)
508 for (
auto accessGroupOpInterface :
509 block.getOps<LLVM::AccessGroupOpInterface>())
511 accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
523 auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
527 dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
541 replacer.
addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
542 -> std::pair<Attribute, WalkResult> {
543 FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
544 FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
545 if (!newStartLoc && !newEndLoc)
548 loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
549 loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
550 loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
551 loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
552 loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
553 loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
554 loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
555 loopAnnotation.getParallelAccesses());
560 for (
Block &block : inlinedBlocks)
569 uint64_t requestedAlignment,
571 uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
572 if (requestedAlignment <= allocaAlignment)
574 return allocaAlignment;
578 if (naturalStackAlignmentBits == 0 ||
581 8 * requestedAlignment <= naturalStackAlignmentBits ||
584 8 * allocaAlignment > naturalStackAlignmentBits) {
585 alloca.setAlignment(requestedAlignment);
586 allocaAlignment = requestedAlignment;
588 return allocaAlignment;
600 if (
auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
603 if (
auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
604 if (
auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
605 definingOp, addressOf.getGlobalNameAttr()))
606 return global.getAlignment().value_or(1);
613 if (
auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
616 auto blockArg = llvm::cast<BlockArgument>(value);
617 if (
Attribute alignAttr = func.getArgAttr(
618 blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
619 return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
629 uint64_t elementTypeSize,
630 uint64_t targetAlignment) {
641 allocaOp = builder.
create<LLVM::AllocaOp>(
642 loc, argument.
getType(), elementType, one, targetAlignment);
647 builder.
create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize,
659 uint64_t requestedAlignment) {
660 auto func = cast<LLVM::LLVMFuncOp>(callable);
661 LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
664 bool isReadOnly = memoryEffects &&
665 memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
666 memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
671 if (requestedAlignment <= minimumAlignment)
673 uint64_t currentAlignment =
675 if (currentAlignment >= requestedAlignment)
678 uint64_t targetAlignment =
std::max(requestedAlignment, minimumAlignment);
688 LLVMInlinerInterface(
Dialect *dialect)
691 disallowedFunctionAttrs({
699 bool wouldBeCloned)
const final {
702 if (!isa<LLVM::CallOp>(call)) {
703 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: call is not an '"
704 << LLVM::CallOp::getOperationName() <<
"' op\n");
707 auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
709 LLVM_DEBUG(llvm::dbgs()
710 <<
"Cannot inline: callable is not an '"
711 << LLVM::LLVMFuncOp::getOperationName() <<
"' op\n");
714 if (funcOp.isNoInline()) {
715 LLVM_DEBUG(llvm::dbgs()
716 <<
"Cannot inline: function is marked no_inline\n");
719 if (funcOp.isVarArg()) {
720 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: callable is variadic\n");
724 if (
auto attrs = funcOp.getArgAttrs()) {
725 for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
726 if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
727 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
728 <<
": inalloca arguments not supported\n");
734 if (funcOp.getPersonality()) {
735 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
736 <<
": unhandled function personality\n");
739 if (funcOp.getPassthrough()) {
741 if (llvm::any_of(*funcOp.getPassthrough(), [&](
Attribute attr) {
742 auto stringAttr = dyn_cast<StringAttr>(attr);
745 if (disallowedFunctionAttrs.contains(stringAttr)) {
746 LLVM_DEBUG(llvm::dbgs()
747 <<
"Cannot inline " << funcOp.getSymName()
748 <<
": found disallowed function attribute "
749 << stringAttr <<
"\n");
765 return !isa<LLVM::VaStartOp>(op);
770 void handleTerminator(
Operation *op,
Block *newDest)
const final {
772 auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
778 builder.create<LLVM::BrOp>(op->
getLoc(), returnOp.getOperands(), newDest);
787 auto returnOp = cast<LLVM::ReturnOp>(op);
790 assert(returnOp.getNumOperands() == valuesToRepl.size());
791 for (
auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
792 dst.replaceAllUsesWith(src);
797 DictionaryAttr argumentAttrs)
const final {
798 if (std::optional<NamedAttribute> attr =
799 argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
800 Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
801 uint64_t requestedAlignment = 1;
802 if (std::optional<NamedAttribute> alignAttr =
803 argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
804 requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
811 if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName())) {
812 if (argument.use_empty())
827 auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument);
828 copyOp->setDiscardableAttr(
829 builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
830 builder.getUnitAttr());
836 void processInlinedCallBlocks(
854 dialect->addInterfaces<LLVMInlinerInterface>();
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp)
Check whether the given alloca is an input to a lifetime intrinsic, optionally passing through one or...
static void appendCallOpAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any alias scopes of the call operation to any inlined memory operation.
static void handleLoopAnnotations(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Updates locations inside loop annotations to reflect that they were inlined.
static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs)
Creates a new ArrayAttr by concatenating lhs with rhs.
static void createNewAliasScopesFromNoAliasParameter(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Creates a new AliasScopeAttr for every noalias parameter and attaches it to the appropriate inlined m...
static void handleAccessGroups(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any access groups of the call operation to any inlined memory operation.
static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, uint64_t requestedAlignment)
Handles a function argument marked with the byval attribute by introducing a memcpy or realigning the...
static void handleAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles all interactions with alias scopes during inlining.
static SmallVector< Value > getUnderlyingObjectSet(Value pointerValue)
Attempts to return the set of all underlying pointer values that pointerValue is based on.
static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, uint64_t requestedAlignment, DataLayout const &dataLayout)
If requestedAlignment is higher than the alignment specified on alloca, realigns alloca if this does ...
static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment, DataLayout const &dataLayout)
Tries to find and return the alignment of the pointer value by looking for an alignment attribute on ...
static Value getUnderlyingObject(Value pointerValue)
Attempts to return the underlying pointer value that pointerValue is based on.
static void deepCloneAliasScopes(iterator_range< Region::iterator > inlinedBlocks)
Maps all alias scopes in the inlined operations to deep clones of the scopes and domain.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc, Value argument, Type elementType, uint64_t elementTypeSize, uint64_t targetAlignment)
Introduces a new alloca and copies the memory pointed to by argument to the address of the new alloca...
static void handleInlinedAllocas(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles alloca operations in the inlined blocks:
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This is an attribute/type replacer that is naively cached.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
pred_iterator pred_begin()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI64IntegerAttr(int64_t value)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getStackAlignment() const
Returns the natural alignment of the stack in bits.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A trait of region holding operations that define a new scope for automatic allocations,...
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void moveAfter(Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void addLLVMInlinerInterface(LLVMDialect *dialect)
Register the LLVMInlinerInterface implementation of DialectInlinerInterface with the LLVM dialect.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This trait indicates that a terminator operation is "return-like".