22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "llvm-inliner"
34 allocaOp->getUsers().end());
35 while (!stack.empty()) {
37 if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
39 if (isa<LLVM::BitcastOp>(op))
63 Block *callerEntryBlock =
nullptr;
77 Block *calleeEntryBlock = &(*inlinedBlocks.begin());
78 if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
82 bool shouldInsertLifetimes =
false;
83 bool hasDynamicAlloca =
false;
87 for (
auto allocaOp : calleeEntryBlock->
getOps<LLVM::AllocaOp>()) {
88 IntegerAttr arraySize;
90 hasDynamicAlloca =
true;
93 bool shouldInsertLifetime =
95 shouldInsertLifetimes |= shouldInsertLifetime;
96 allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
99 for (
Block &block : llvm::drop_begin(inlinedBlocks)) {
100 if (hasDynamicAlloca)
103 llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](
auto allocaOp) {
104 return !matchPattern(allocaOp.getArraySize(), m_Constant());
107 if (allocasToMove.empty() && !hasDynamicAlloca)
111 if (hasDynamicAlloca) {
116 stackPtr = LLVM::StackSaveOp::create(
121 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
123 LLVM::ConstantOp::create(builder, allocaOp->getLoc(),
124 allocaOp.getArraySize().
getType(), arraySize);
126 if (shouldInsertLifetime) {
129 LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(),
130 arraySize.getValue().getLimitedValue(),
131 allocaOp.getResult());
133 allocaOp->moveAfter(newConstant);
134 allocaOp.getArraySizeMutable().assign(newConstant.getResult());
136 if (!shouldInsertLifetimes && !hasDynamicAlloca)
139 for (
Block &block : inlinedBlocks) {
143 if (hasDynamicAlloca)
144 LLVM::StackRestoreOp::create(builder, call->
getLoc(), stackPtr);
145 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
146 if (shouldInsertLifetime)
147 LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(),
148 arraySize.getValue().getLimitedValue(),
149 allocaOp.getResult());
169 walker.
addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
171 domainAttr.getContext(), domainAttr.getDescription());
174 walker.
addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
176 cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
177 scopeAttr.getDescription());
181 auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
186 walker.
walk(arrayAttr);
189 llvm::map_to_vector(arrayAttr, [&](
Attribute attr) {
190 return mapping.lookup(attr);
194 for (
Block &block : inlinedBlocks) {
196 if (
auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
197 aliasInterface.setAliasScopes(
198 convertScopeList(aliasInterface.getAliasScopesOrNull()));
199 aliasInterface.setNoAliasScopes(
200 convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
203 if (
auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
205 walker.
walk(noAliasScope.getScopeAttr());
207 noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
208 mapping.lookup(noAliasScope.getScopeAttr())));
224 llvm::append_range(result, lhs);
225 llvm::append_range(result, rhs);
232 static FailureOr<SmallVector<Value>>
244 if (controlFlowPredecessors)
248 if (isa<OpResult>(val)) {
249 result.push_back(val);
281 for (
Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
283 auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
286 ssaCopies.insert(ssaCopy);
288 if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
290 noAliasParams.insert(ssaCopy);
296 auto exit = llvm::make_scope_exit([&] {
297 for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) {
298 ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
304 if (noAliasParams.empty())
310 call->
getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
312 for (LLVM::SSACopyOp copyOp : noAliasParams) {
314 pointerScopes[copyOp] = scope;
317 LLVM::NoAliasScopeDeclOp::create(builder, call->
getLoc(), scope);
322 for (
Block &inlinedBlock : inlinedBlocks) {
323 inlinedBlock.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
329 for (
Value pointer : pointerArgs) {
330 FailureOr<SmallVector<Value>> underlyingObjectSet =
332 if (failed(underlyingObjectSet))
335 std::inserter(basedOnPointers, basedOnPointers.begin()));
338 bool aliasesOtherKnownObject =
false;
348 if (llvm::any_of(basedOnPointers, [&](
Value object) {
352 if (
auto ssaCopy =
object.getDefiningOp<LLVM::SSACopyOp>()) {
355 aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy);
359 if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
360 object.getDefiningOp())) {
361 aliasesOtherKnownObject =
true;
371 for (LLVM::SSACopyOp noAlias : noAliasParams) {
372 if (basedOnPointers.contains(noAlias))
375 noAliasScopes.push_back(pointerScopes[noAlias]);
378 if (!noAliasScopes.empty())
379 aliasInterface.setNoAliasScopes(
407 if (aliasesOtherKnownObject ||
408 isa<LLVM::CallOp>(aliasInterface.getOperation()))
412 for (LLVM::SSACopyOp noAlias : noAliasParams)
413 if (basedOnPointers.contains(noAlias))
414 aliasScopes.push_back(pointerScopes[noAlias]);
416 if (!aliasScopes.empty())
417 aliasInterface.setAliasScopes(
429 auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
430 if (!callAliasInterface)
433 ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
434 ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
437 if (!aliasScopes && !noAliasScopes)
442 for (
Block &block : inlinedBlocks) {
443 block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
446 aliasInterface.getAliasScopesOrNull(), aliasScopes));
450 aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
467 auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
468 if (!callAccessGroupInterface)
471 auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
477 for (
Block &block : inlinedBlocks)
478 for (
auto accessGroupOpInterface :
479 block.getOps<LLVM::AccessGroupOpInterface>())
481 accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
493 auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
497 dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
511 replacer.
addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
512 -> std::pair<Attribute, WalkResult> {
513 FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
514 FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
515 if (!newStartLoc && !newEndLoc)
518 loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
519 loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
520 loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
521 loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
522 loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
523 loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
524 loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
525 loopAnnotation.getParallelAccesses());
530 for (
Block &block : inlinedBlocks)
539 uint64_t requestedAlignment,
541 uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
542 if (requestedAlignment <= allocaAlignment)
544 return allocaAlignment;
548 if (naturalStackAlignmentBits == 0 ||
551 8 * requestedAlignment <= naturalStackAlignmentBits ||
554 8 * allocaAlignment > naturalStackAlignmentBits) {
555 alloca.setAlignment(requestedAlignment);
556 allocaAlignment = requestedAlignment;
558 return allocaAlignment;
570 if (
auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
573 if (
auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
574 if (
auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
575 definingOp, addressOf.getGlobalNameAttr()))
576 return global.getAlignment().value_or(1);
583 if (
auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
586 auto blockArg = llvm::cast<BlockArgument>(value);
587 if (
Attribute alignAttr = func.getArgAttr(
588 blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
589 return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
599 uint64_t elementTypeSize,
600 uint64_t targetAlignment) {
609 Value one = LLVM::ConstantOp::create(builder, loc, builder.
getI64Type(),
611 allocaOp = LLVM::AllocaOp::create(builder, loc, argument.
getType(),
612 elementType, one, targetAlignment);
616 LLVM::ConstantOp::create(builder, loc, builder.
getI64Type(),
618 LLVM::MemcpyOp::create(builder, loc, allocaOp, argument, copySize,
630 uint64_t requestedAlignment) {
631 auto func = cast<LLVM::LLVMFuncOp>(callable);
632 LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
635 bool isReadOnly = memoryEffects &&
636 memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
637 memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
642 if (requestedAlignment <= minimumAlignment)
644 uint64_t currentAlignment =
646 if (currentAlignment >= requestedAlignment)
649 uint64_t targetAlignment =
std::max(requestedAlignment, minimumAlignment);
651 builder, argument.
getLoc(), argument, elementType,
652 dataLayout.
getTypeSize(elementType), targetAlignment);
659 LLVMInlinerInterface(
Dialect *dialect)
662 disallowedFunctionAttrs({
670 bool wouldBeCloned)
const final {
671 auto callOp = dyn_cast<LLVM::CallOp>(call);
673 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: call is not an '"
674 << LLVM::CallOp::getOperationName() <<
"' op\n");
677 if (callOp.getNoInline()) {
678 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: call is marked no_inline\n");
681 auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
683 LLVM_DEBUG(llvm::dbgs()
684 <<
"Cannot inline: callable is not an '"
685 << LLVM::LLVMFuncOp::getOperationName() <<
"' op\n");
688 if (funcOp.isNoInline()) {
689 LLVM_DEBUG(llvm::dbgs()
690 <<
"Cannot inline: function is marked no_inline\n");
693 if (funcOp.isVarArg()) {
694 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline: callable is variadic\n");
698 if (
auto attrs = funcOp.getArgAttrs()) {
699 for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
700 if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
701 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
702 <<
": inalloca arguments not supported\n");
708 if (funcOp.getPersonality()) {
709 LLVM_DEBUG(llvm::dbgs() <<
"Cannot inline " << funcOp.getSymName()
710 <<
": unhandled function personality\n");
713 if (funcOp.getPassthrough()) {
715 if (llvm::any_of(*funcOp.getPassthrough(), [&](
Attribute attr) {
716 auto stringAttr = dyn_cast<StringAttr>(attr);
719 if (disallowedFunctionAttrs.contains(stringAttr)) {
720 LLVM_DEBUG(llvm::dbgs()
721 <<
"Cannot inline " << funcOp.getSymName()
722 <<
": found disallowed function attribute "
723 << stringAttr <<
"\n");
741 return !(isa<LLVM::VaStartOp>(op) || isa<LLVM::BlockTagOp>(op));
746 void handleTerminator(
Operation *op,
Block *newDest)
const final {
748 auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
754 LLVM::BrOp::create(builder, op->getLoc(), returnOp.getOperands(), newDest);
758 bool allowSingleBlockOptimization(
760 if (!inlinedBlocks.empty() &&
761 isa<LLVM::UnreachableOp>(inlinedBlocks.begin()->getTerminator()))
771 auto returnOp = cast<LLVM::ReturnOp>(op);
774 assert(returnOp.getNumOperands() == valuesToRepl.size());
775 for (
auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
776 dst.replaceAllUsesWith(src);
781 DictionaryAttr argumentAttrs)
const final {
782 if (std::optional<NamedAttribute> attr =
783 argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
784 Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
785 uint64_t requestedAlignment = 1;
786 if (std::optional<NamedAttribute> alignAttr =
787 argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
788 requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
808 auto copyOp = LLVM::SSACopyOp::create(builder, call->getLoc(), argument);
809 if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName()))
810 copyOp->setDiscardableAttr(
811 builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
812 builder.getUnitAttr());
816 void processInlinedCallBlocks(
835 dialect->addInterfaces<LLVMInlinerInterface>();
841 dialect->addInterfaces<LLVMInlinerInterface>();
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp)
Check whether the given alloca is an input to a lifetime intrinsic, optionally passing through one or...
static void appendCallOpAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any alias scopes of the call operation to any inlined memory operation.
static void handleLoopAnnotations(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Updates locations inside loop annotations to reflect that they were inlined.
static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs)
Creates a new ArrayAttr by concatenating lhs with rhs.
static void createNewAliasScopesFromNoAliasParameter(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Creates a new AliasScopeAttr for every noalias parameter and attaches it to the appropriate inlined m...
static void handleAccessGroups(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any access groups of the call operation to any inlined memory operation.
static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, uint64_t requestedAlignment)
Handles a function argument marked with the byval attribute by introducing a memcpy or realigning the...
static void handleAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles all interactions with alias scopes during inlining.
static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, uint64_t requestedAlignment, DataLayout const &dataLayout)
If requestedAlignment is higher than the alignment specified on alloca, realigns alloca if this does ...
static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment, DataLayout const &dataLayout)
Tries to find and return the alignment of the pointer value by looking for an alignment attribute on ...
static FailureOr< SmallVector< Value > > getUnderlyingObjectSet(Value pointerValue)
Attempts to return the set of all underlying pointer values that pointerValue is based on.
static void deepCloneAliasScopes(iterator_range< Region::iterator > inlinedBlocks)
Maps all alias scopes in the inlined operations to deep clones of the scopes and domain.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc, Value argument, Type elementType, uint64_t elementTypeSize, uint64_t targetAlignment)
Introduces a new alloca and copies the memory pointed to by argument to the address of the new alloca...
static void handleInlinedAllocas(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles alloca operations in the inlined blocks:
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This is an attribute/type replacer that is naively cached.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getStackAlignment() const
Returns the natural alignment of the stack in bits.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
A trait of region holding operations that define a new scope for automatic allocations,...
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
A class to signal how to proceed with the walk of the backward slice:
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkContinuation skip()
Creates a continuation that advances the walk without adding any predecessor values to the work list.
static WalkContinuation advanceTo(mlir::ValueRange nextValues)
Creates a continuation that adds the user-specified nextValues to the work list and advances the walk...
static WalkContinuation interrupt()
Creates a continuation that interrupts the walk.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void registerInlinerInterface(DialectRegistry &registry)
Register the LLVMInlinerInterface implementation of DialectInlinerInterface with the LLVM dialect.
void registerInlinerInterface(DialectRegistry &registry)
Register the NVVMInlinerInterface implementation of DialectInlinerInterface with the NVVM dialect.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< SmallVector< Value > > getControlFlowPredecessors(Value value)
Computes a vector of all control predecessors of value.
WalkContinuation walkSlice(mlir::ValueRange rootValues, WalkCallback walkCallback)
Walks the slice starting from the rootValues using a depth-first traversal.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This trait indicates that a terminator operation is "return-like".