21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "llvm-inliner"
33 allocaOp->getUsers().end());
34 while (!stack.empty()) {
36 if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
38 if (isa<LLVM::BitcastOp>(op))
62 Block *callerEntryBlock =
nullptr;
76 Block *calleeEntryBlock = &(*inlinedBlocks.begin());
77 if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
81 bool shouldInsertLifetimes =
false;
82 bool hasDynamicAlloca =
false;
86 for (
auto allocaOp : calleeEntryBlock->
getOps<LLVM::AllocaOp>()) {
87 IntegerAttr arraySize;
89 hasDynamicAlloca =
true;
92 bool shouldInsertLifetime =
94 shouldInsertLifetimes |= shouldInsertLifetime;
95 allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
98 for (
Block &block : llvm::drop_begin(inlinedBlocks)) {
102 llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](
auto allocaOp) {
103 return !matchPattern(allocaOp.getArraySize(), m_Constant());
106 if (allocasToMove.empty() && !hasDynamicAlloca)
110 if (hasDynamicAlloca) {
115 stackPtr = builder.
create<LLVM::StackSaveOp>(
119 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
120 auto newConstant = builder.
create<LLVM::ConstantOp>(
121 allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
123 if (shouldInsertLifetime) {
126 builder.
create<LLVM::LifetimeStartOp>(
127 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
128 allocaOp.getResult());
131 allocaOp.getArraySizeMutable().assign(newConstant.getResult());
133 if (!shouldInsertLifetimes && !hasDynamicAlloca)
136 for (
Block &block : inlinedBlocks) {
140 if (hasDynamicAlloca)
141 builder.
create<LLVM::StackRestoreOp>(call->
getLoc(), stackPtr);
142 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
143 if (shouldInsertLifetime)
144 builder.
create<LLVM::LifetimeEndOp>(
145 allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
146 allocaOp.getResult());
166 walker.
addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
168 domainAttr.getContext(), domainAttr.getDescription());
171 walker.
addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
173 cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
174 scopeAttr.getDescription());
178 auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
183 walker.
walk(arrayAttr);
186 llvm::map_to_vector(arrayAttr, [&](
Attribute attr) {
187 return mapping.lookup(attr);
191 for (
Block &block : inlinedBlocks) {
193 if (
auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
194 aliasInterface.setAliasScopes(
195 convertScopeList(aliasInterface.getAliasScopesOrNull()));
196 aliasInterface.setNoAliasScopes(
197 convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
200 if (
auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
202 walker.
walk(noAliasScope.getScopeAttr());
204 noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
205 mapping.lookup(noAliasScope.getScopeAttr())));
221 llvm::append_range(result, lhs);
222 llvm::append_range(result, rhs);
229 static FailureOr<SmallVector<Value>>
241 if (controlFlowPredecessors)
245 if (isa<OpResult>(val)) {
246 result.push_back(val);
278 for (
Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
280 auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
283 ssaCopies.insert(ssaCopy);
285 if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
287 noAliasParams.insert(ssaCopy);
293 auto exit = llvm::make_scope_exit([&] {
294 for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) {
295 ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
301 if (noAliasParams.empty())
307 call->
getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
309 for (LLVM::SSACopyOp copyOp : noAliasParams) {
311 pointerScopes[copyOp] = scope;
318 for (
Block &inlinedBlock : inlinedBlocks) {
319 inlinedBlock.
walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
325 for (
Value pointer : pointerArgs) {
326 FailureOr<SmallVector<Value>> underlyingObjectSet =
328 if (failed(underlyingObjectSet))
331 std::inserter(basedOnPointers, basedOnPointers.begin()));
334 bool aliasesOtherKnownObject =
false;
344 if (llvm::any_of(basedOnPointers, [&](
Value object) {
348 if (
auto ssaCopy =
object.getDefiningOp<LLVM::SSACopyOp>()) {
351 aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy);
355 if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
356 object.getDefiningOp())) {
357 aliasesOtherKnownObject =
true;
367 for (LLVM::SSACopyOp noAlias : noAliasParams) {
368 if (basedOnPointers.contains(noAlias))
371 noAliasScopes.push_back(pointerScopes[noAlias]);
374 if (!noAliasScopes.empty())
375 aliasInterface.setNoAliasScopes(
403 if (aliasesOtherKnownObject ||
404 isa<LLVM::CallOp>(aliasInterface.getOperation()))
408 for (LLVM::SSACopyOp noAlias : noAliasParams)
409 if (basedOnPointers.contains(noAlias))
410 aliasScopes.push_back(pointerScopes[noAlias]);
412 if (!aliasScopes.empty())
413 aliasInterface.setAliasScopes(
425 auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
426 if (!callAliasInterface)
429 ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
430 ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
433 if (!aliasScopes && !noAliasScopes)
438 for (
Block &block : inlinedBlocks) {
439 block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
442 aliasInterface.getAliasScopesOrNull(), aliasScopes));
446 aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
463 auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
464 if (!callAccessGroupInterface)
467 auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
473 for (
Block &block : inlinedBlocks)
474 for (
auto accessGroupOpInterface :
475 block.getOps<LLVM::AccessGroupOpInterface>())
477 accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
489 auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
493 dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
507 replacer.
addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
508 -> std::pair<Attribute, WalkResult> {
509 FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
510 FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
511 if (!newStartLoc && !newEndLoc)
514 loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
515 loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
516 loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
517 loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
518 loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
519 loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
520 loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
521 loopAnnotation.getParallelAccesses());
526 for (
Block &block : inlinedBlocks)
535 uint64_t requestedAlignment,
537 uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
538 if (requestedAlignment <= allocaAlignment)
540 return allocaAlignment;
544 if (naturalStackAlignmentBits == 0 ||
547 8 * requestedAlignment <= naturalStackAlignmentBits ||
550 8 * allocaAlignment > naturalStackAlignmentBits) {
551 alloca.setAlignment(requestedAlignment);
552 allocaAlignment = requestedAlignment;
554 return allocaAlignment;
566 if (
auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
569 if (
auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
570 if (
auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
571 definingOp, addressOf.getGlobalNameAttr()))
572 return global.getAlignment().value_or(1);
579 if (
auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
582 auto blockArg = llvm::cast<BlockArgument>(value);
583 if (
Attribute alignAttr = func.getArgAttr(
584 blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
585 return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
595 uint64_t elementTypeSize,
596 uint64_t targetAlignment) {
607 allocaOp = builder.
create<LLVM::AllocaOp>(
608 loc, argument.
getType(), elementType, one, targetAlignment);
613 builder.
create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize,
625 uint64_t requestedAlignment) {
626 auto func = cast<LLVM::LLVMFuncOp>(callable);
627 LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
630 bool isReadOnly = memoryEffects &&
631 memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
632 memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
637 if (requestedAlignment <= minimumAlignment)
639 uint64_t currentAlignment =
641 if (currentAlignment >= requestedAlignment)
644 uint64_t targetAlignment =
std::max(requestedAlignment, minimumAlignment);
654 LLVMInlinerInterface(
Dialect *dialect)
657 disallowedFunctionAttrs({
665 bool wouldBeCloned)
const final {
668 if (!isa<LLVM::CallOp>(call)) {
669 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: call is not an '"
670 << LLVM::CallOp::getOperationName() <<
"' op\n");
673 auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
675 LLVM_DEBUG(llvm::dbgs()
676 <<
"Cannot inline: callable is not an '"
677 << LLVM::LLVMFuncOp::getOperationName() <<
"' op\n");
680 if (funcOp.isNoInline()) {
681 LLVM_DEBUG(llvm::dbgs()
682 <<
"Cannot inline: function is marked no_inline\n");
685 if (funcOp.isVarArg()) {
686 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: callable is variadic\n");
690 if (
auto attrs = funcOp.getArgAttrs()) {
691 for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
692 if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
693 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
694 <<
": inalloca arguments not supported\n");
700 if (funcOp.getPersonality()) {
701 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
702 <<
": unhandled function personality\n");
705 if (funcOp.getPassthrough()) {
707 if (llvm::any_of(*funcOp.getPassthrough(), [&](
Attribute attr) {
708 auto stringAttr = dyn_cast<StringAttr>(attr);
711 if (disallowedFunctionAttrs.contains(stringAttr)) {
712 LLVM_DEBUG(llvm::dbgs()
713 <<
"Cannot inline " << funcOp.getSymName()
714 <<
": found disallowed function attribute "
715 << stringAttr <<
"\n");
731 return !isa<LLVM::VaStartOp>(op);
736 void handleTerminator(
Operation *op,
Block *newDest)
const final {
738 auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
744 builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
753 auto returnOp = cast<LLVM::ReturnOp>(op);
756 assert(returnOp.getNumOperands() == valuesToRepl.size());
757 for (
auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
758 dst.replaceAllUsesWith(src);
763 DictionaryAttr argumentAttrs)
const final {
764 if (std::optional<NamedAttribute> attr =
765 argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
766 Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
767 uint64_t requestedAlignment = 1;
768 if (std::optional<NamedAttribute> alignAttr =
769 argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
770 requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
790 auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument);
791 if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName()))
792 copyOp->setDiscardableAttr(
793 builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
794 builder.getUnitAttr());
798 void processInlinedCallBlocks(
817 dialect->addInterfaces<LLVMInlinerInterface>();
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp)
Check whether the given alloca is an input to a lifetime intrinsic, optionally passing through one or more casts (bitcasts) on the way.
static void appendCallOpAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any alias scopes of the call operation to any inlined memory operation.
static void handleLoopAnnotations(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Updates locations inside loop annotations to reflect that they were inlined.
static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs)
Creates a new ArrayAttr by concatenating lhs with rhs.
static void createNewAliasScopesFromNoAliasParameter(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Creates a new AliasScopeAttr for every noalias parameter and attaches it to the appropriate inlined m...
static void handleAccessGroups(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any access groups of the call operation to any inlined memory operation.
static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, uint64_t requestedAlignment)
Handles a function argument marked with the byval attribute by introducing a memcpy or realigning the...
static void handleAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles all interactions with alias scopes during inlining.
static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, uint64_t requestedAlignment, DataLayout const &dataLayout)
If requestedAlignment is higher than the alignment specified on alloca, realigns alloca if this does ...
static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment, DataLayout const &dataLayout)
Tries to find and return the alignment of the pointer value by looking for an alignment attribute on ...
static FailureOr< SmallVector< Value > > getUnderlyingObjectSet(Value pointerValue)
Attempts to return the set of all underlying pointer values that pointerValue is based on.
static void deepCloneAliasScopes(iterator_range< Region::iterator > inlinedBlocks)
Maps all alias scopes in the inlined operations to deep clones of the scopes and domain.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc, Value argument, Type elementType, uint64_t elementTypeSize, uint64_t targetAlignment)
Introduces a new alloca and copies the memory pointed to by argument to the address of the new alloca...
static void handleInlinedAllocas(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
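Handles alloca operations in the inlined blocks: allocas with a constant size in the inlined entry block are moved out of it and their array size is rewritten to a constant, lifetime start/end markers are inserted around them where appropriate, and if any dynamically-sized alloca is present the inlined region is bracketed with a stack save/restore.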
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This is an attribute/type replacer that is naively cached.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI64IntegerAttr(int64_t value)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getStackAlignment() const
Returns the natural alignment of the stack in bits.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A trait of region holding operations that define a new scope for automatic allocations,...
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void moveAfter(Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
A class to signal how to proceed with the walk of the backward slice:
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkContinuation skip()
Creates a continuation that advances the walk without adding any predecessor values to the work list.
static WalkContinuation advanceTo(mlir::ValueRange nextValues)
Creates a continuation that adds the user-specified nextValues to the work list and advances the walk...
static WalkContinuation interrupt()
Creates a continuation that interrupts the walk.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void registerInlinerInterface(DialectRegistry &registry)
Register the LLVMInlinerInterface implementation of DialectInlinerInterface with the LLVM dialect.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< SmallVector< Value > > getControlFlowPredecessors(Value value)
Computes a vector of all control predecessors of value.
WalkContinuation walkSlice(mlir::ValueRange rootValues, WalkCallback walkCallback)
Walks the slice starting from the rootValues using a depth-first traversal.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This trait indicates that a terminator operation is "return-like".