22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "llvm-inliner"
34 allocaOp->getUsers().end());
35 while (!stack.empty()) {
37 if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
39 if (isa<LLVM::BitcastOp>(op))
63 Block *callerEntryBlock =
nullptr;
77 Block *calleeEntryBlock = &(*inlinedBlocks.begin());
78 if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
82 bool shouldInsertLifetimes =
false;
83 bool hasDynamicAlloca =
false;
87 for (
auto allocaOp : calleeEntryBlock->
getOps<LLVM::AllocaOp>()) {
88 IntegerAttr arraySize;
90 hasDynamicAlloca =
true;
93 bool shouldInsertLifetime =
95 shouldInsertLifetimes |= shouldInsertLifetime;
96 allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
99 for (
Block &block : llvm::drop_begin(inlinedBlocks)) {
100 if (hasDynamicAlloca)
103 llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](
auto allocaOp) {
104 return !matchPattern(allocaOp.getArraySize(), m_Constant());
107 if (allocasToMove.empty() && !hasDynamicAlloca)
111 if (hasDynamicAlloca) {
116 stackPtr = builder.
create<LLVM::StackSaveOp>(
120 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
121 auto newConstant = builder.
create<LLVM::ConstantOp>(
122 allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
124 if (shouldInsertLifetime) {
127 builder.
create<LLVM::LifetimeStartOp>(
128 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
129 allocaOp.getResult());
132 allocaOp.getArraySizeMutable().assign(newConstant.getResult());
134 if (!shouldInsertLifetimes && !hasDynamicAlloca)
137 for (
Block &block : inlinedBlocks) {
141 if (hasDynamicAlloca)
142 builder.
create<LLVM::StackRestoreOp>(call->
getLoc(), stackPtr);
143 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
144 if (shouldInsertLifetime)
145 builder.
create<LLVM::LifetimeEndOp>(
146 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
147 allocaOp.getResult());
167 walker.
addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
169 domainAttr.getContext(), domainAttr.getDescription());
172 walker.
addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
174 cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
175 scopeAttr.getDescription());
179 auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
184 walker.
walk(arrayAttr);
187 llvm::map_to_vector(arrayAttr, [&](
Attribute attr) {
188 return mapping.lookup(attr);
192 for (
Block &block : inlinedBlocks) {
194 if (
auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
195 aliasInterface.setAliasScopes(
196 convertScopeList(aliasInterface.getAliasScopesOrNull()));
197 aliasInterface.setNoAliasScopes(
198 convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
201 if (
auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
203 walker.
walk(noAliasScope.getScopeAttr());
205 noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
206 mapping.lookup(noAliasScope.getScopeAttr())));
222 llvm::append_range(result, lhs);
223 llvm::append_range(result, rhs);
230 static FailureOr<SmallVector<Value>>
242 if (controlFlowPredecessors)
246 if (isa<OpResult>(val)) {
247 result.push_back(val);
279 for (
Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
281 auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
284 ssaCopies.insert(ssaCopy);
286 if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
288 noAliasParams.insert(ssaCopy);
294 auto exit = llvm::make_scope_exit([&] {
295 for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) {
296 ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
302 if (noAliasParams.empty())
308 call->
getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
310 for (LLVM::SSACopyOp copyOp : noAliasParams) {
312 pointerScopes[copyOp] = scope;
319 for (
Block &inlinedBlock : inlinedBlocks) {
320 inlinedBlock.
walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
326 for (
Value pointer : pointerArgs) {
327 FailureOr<SmallVector<Value>> underlyingObjectSet =
329 if (failed(underlyingObjectSet))
332 std::inserter(basedOnPointers, basedOnPointers.begin()));
335 bool aliasesOtherKnownObject =
false;
345 if (llvm::any_of(basedOnPointers, [&](
Value object) {
349 if (
auto ssaCopy =
object.getDefiningOp<LLVM::SSACopyOp>()) {
352 aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy);
356 if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
357 object.getDefiningOp())) {
358 aliasesOtherKnownObject =
true;
368 for (LLVM::SSACopyOp noAlias : noAliasParams) {
369 if (basedOnPointers.contains(noAlias))
372 noAliasScopes.push_back(pointerScopes[noAlias]);
375 if (!noAliasScopes.empty())
376 aliasInterface.setNoAliasScopes(
404 if (aliasesOtherKnownObject ||
405 isa<LLVM::CallOp>(aliasInterface.getOperation()))
409 for (LLVM::SSACopyOp noAlias : noAliasParams)
410 if (basedOnPointers.contains(noAlias))
411 aliasScopes.push_back(pointerScopes[noAlias]);
413 if (!aliasScopes.empty())
414 aliasInterface.setAliasScopes(
426 auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
427 if (!callAliasInterface)
430 ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
431 ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
434 if (!aliasScopes && !noAliasScopes)
439 for (
Block &block : inlinedBlocks) {
440 block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
443 aliasInterface.getAliasScopesOrNull(), aliasScopes));
447 aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
464 auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
465 if (!callAccessGroupInterface)
468 auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
474 for (
Block &block : inlinedBlocks)
475 for (
auto accessGroupOpInterface :
476 block.getOps<LLVM::AccessGroupOpInterface>())
478 accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
490 auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
494 dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
508 replacer.
addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
509 -> std::pair<Attribute, WalkResult> {
510 FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
511 FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
512 if (!newStartLoc && !newEndLoc)
515 loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
516 loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
517 loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
518 loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
519 loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
520 loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
521 loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
522 loopAnnotation.getParallelAccesses());
527 for (
Block &block : inlinedBlocks)
536 uint64_t requestedAlignment,
538 uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
539 if (requestedAlignment <= allocaAlignment)
541 return allocaAlignment;
545 if (naturalStackAlignmentBits == 0 ||
548 8 * requestedAlignment <= naturalStackAlignmentBits ||
551 8 * allocaAlignment > naturalStackAlignmentBits) {
552 alloca.setAlignment(requestedAlignment);
553 allocaAlignment = requestedAlignment;
555 return allocaAlignment;
567 if (
auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
570 if (
auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
571 if (
auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
572 definingOp, addressOf.getGlobalNameAttr()))
573 return global.getAlignment().value_or(1);
580 if (
auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
583 auto blockArg = llvm::cast<BlockArgument>(value);
584 if (
Attribute alignAttr = func.getArgAttr(
585 blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
586 return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
596 uint64_t elementTypeSize,
597 uint64_t targetAlignment) {
608 allocaOp = builder.
create<LLVM::AllocaOp>(
609 loc, argument.
getType(), elementType, one, targetAlignment);
614 builder.
create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize,
626 uint64_t requestedAlignment) {
627 auto func = cast<LLVM::LLVMFuncOp>(callable);
628 LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
631 bool isReadOnly = memoryEffects &&
632 memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
633 memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
638 if (requestedAlignment <= minimumAlignment)
640 uint64_t currentAlignment =
642 if (currentAlignment >= requestedAlignment)
645 uint64_t targetAlignment =
std::max(requestedAlignment, minimumAlignment);
647 builder, argument.
getLoc(), argument, elementType,
648 dataLayout.
getTypeSize(elementType), targetAlignment);
655 LLVMInlinerInterface(
Dialect *dialect)
658 disallowedFunctionAttrs({
666 bool wouldBeCloned)
const final {
667 if (!isa<LLVM::CallOp>(call)) {
668 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: call is not an '"
669 << LLVM::CallOp::getOperationName() <<
"' op\n");
672 auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
674 LLVM_DEBUG(llvm::dbgs()
675 <<
"Cannot inline: callable is not an '"
676 << LLVM::LLVMFuncOp::getOperationName() <<
"' op\n");
679 if (funcOp.isNoInline()) {
680 LLVM_DEBUG(llvm::dbgs()
681 <<
"Cannot inline: function is marked no_inline\n");
684 if (funcOp.isVarArg()) {
685 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: callable is variadic\n");
689 if (
auto attrs = funcOp.getArgAttrs()) {
690 for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
691 if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
692 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
693 <<
": inalloca arguments not supported\n");
699 if (funcOp.getPersonality()) {
700 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
701 <<
": unhandled function personality\n");
704 if (funcOp.getPassthrough()) {
706 if (llvm::any_of(*funcOp.getPassthrough(), [&](
Attribute attr) {
707 auto stringAttr = dyn_cast<StringAttr>(attr);
710 if (disallowedFunctionAttrs.contains(stringAttr)) {
711 LLVM_DEBUG(llvm::dbgs()
712 <<
"Cannot inline " << funcOp.getSymName()
713 <<
": found disallowed function attribute "
714 << stringAttr <<
"\n");
730 return !isa<LLVM::VaStartOp>(op);
735 void handleTerminator(
Operation *op,
Block *newDest)
const final {
737 auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
743 builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
747 bool allowSingleBlockOptimization(
749 if (!inlinedBlocks.empty() &&
750 isa<LLVM::UnreachableOp>(inlinedBlocks.begin()->getTerminator()))
760 auto returnOp = cast<LLVM::ReturnOp>(op);
763 assert(returnOp.getNumOperands() == valuesToRepl.size());
764 for (
auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
765 dst.replaceAllUsesWith(src);
770 DictionaryAttr argumentAttrs)
const final {
771 if (std::optional<NamedAttribute> attr =
772 argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
773 Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
774 uint64_t requestedAlignment = 1;
775 if (std::optional<NamedAttribute> alignAttr =
776 argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
777 requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
797 auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument);
798 if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName()))
799 copyOp->setDiscardableAttr(
800 builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
801 builder.getUnitAttr());
805 void processInlinedCallBlocks(
824 dialect->addInterfaces<LLVMInlinerInterface>();
830 dialect->addInterfaces<LLVMInlinerInterface>();
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp)
Check whether the given alloca is an input to a lifetime intrinsic, optionally passing through one or...
static void appendCallOpAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any alias scopes of the call operation to any inlined memory operation.
static void handleLoopAnnotations(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Updates locations inside loop annotations to reflect that they were inlined.
static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs)
Creates a new ArrayAttr by concatenating lhs with rhs.
static void createNewAliasScopesFromNoAliasParameter(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Creates a new AliasScopeAttr for every noalias parameter and attaches it to the appropriate inlined m...
static void handleAccessGroups(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any access groups of the call operation to any inlined memory operation.
static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, uint64_t requestedAlignment)
Handles a function argument marked with the byval attribute by introducing a memcpy or realigning the...
static void handleAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles all interactions with alias scopes during inlining.
static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, uint64_t requestedAlignment, DataLayout const &dataLayout)
If requestedAlignment is higher than the alignment specified on alloca, realigns alloca if this does ...
static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment, DataLayout const &dataLayout)
Tries to find and return the alignment of the pointer value by looking for an alignment attribute on ...
static FailureOr< SmallVector< Value > > getUnderlyingObjectSet(Value pointerValue)
Attempts to return the set of all underlying pointer values that pointerValue is based on.
static void deepCloneAliasScopes(iterator_range< Region::iterator > inlinedBlocks)
Maps all alias scopes in the inlined operations to deep clones of the scopes and domain.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc, Value argument, Type elementType, uint64_t elementTypeSize, uint64_t targetAlignment)
Introduces a new alloca and copies the memory pointed to by argument to the address of the new alloca...
static void handleInlinedAllocas(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles alloca operations in the inlined blocks:
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This is an attribute/type replacer that is naively cached.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI64IntegerAttr(int64_t value)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getStackAlignment() const
Returns the natural alignment of the stack in bits.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A trait of region holding operations that define a new scope for automatic allocations,...
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void moveAfter(Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
A class to signal how to proceed with the walk of the backward slice:
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkContinuation skip()
Creates a continuation that advances the walk without adding any predecessor values to the work list.
static WalkContinuation advanceTo(mlir::ValueRange nextValues)
Creates a continuation that adds the user-specified nextValues to the work list and advances the walk...
static WalkContinuation interrupt()
Creates a continuation that interrupts the walk.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void registerInlinerInterface(DialectRegistry &registry)
Register the LLVMInlinerInterface implementation of DialectInlinerInterface with the LLVM dialect.
void registerInlinerInterface(DialectRegistry &registry)
Register the NVVMInlinerInterface implementation of DialectInlinerInterface with the NVVM dialect.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< SmallVector< Value > > getControlFlowPredecessors(Value value)
Computes a vector of all control predecessors of value.
WalkContinuation walkSlice(mlir::ValueRange rootValues, WalkCallback walkCallback)
Walks the slice starting from the rootValues using a depth-first traversal.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This trait indicates that a terminator operation is "return-like".