22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/DebugLog.h"
27 #define DEBUG_TYPE "llvm-inliner"
36 allocaOp->getUsers().end());
37 while (!stack.empty()) {
39 if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
41 if (isa<LLVM::BitcastOp>(op))
65 Block *callerEntryBlock =
nullptr;
79 Block *calleeEntryBlock = &(*inlinedBlocks.begin());
80 if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
84 bool shouldInsertLifetimes =
false;
85 bool hasDynamicAlloca =
false;
89 for (
auto allocaOp : calleeEntryBlock->
getOps<LLVM::AllocaOp>()) {
90 IntegerAttr arraySize;
92 hasDynamicAlloca =
true;
95 bool shouldInsertLifetime =
97 shouldInsertLifetimes |= shouldInsertLifetime;
98 allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
101 for (
Block &block : llvm::drop_begin(inlinedBlocks)) {
102 if (hasDynamicAlloca)
105 llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](
auto allocaOp) {
106 return !matchPattern(allocaOp.getArraySize(), m_Constant());
109 if (allocasToMove.empty() && !hasDynamicAlloca)
113 if (hasDynamicAlloca) {
118 stackPtr = LLVM::StackSaveOp::create(
123 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
125 LLVM::ConstantOp::create(builder, allocaOp->getLoc(),
126 allocaOp.getArraySize().
getType(), arraySize);
128 if (shouldInsertLifetime) {
131 LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(),
132 allocaOp.getResult());
134 allocaOp->moveAfter(newConstant);
135 allocaOp.getArraySizeMutable().assign(newConstant.getResult());
137 if (!shouldInsertLifetimes && !hasDynamicAlloca)
140 for (
Block &block : inlinedBlocks) {
144 if (hasDynamicAlloca)
145 LLVM::StackRestoreOp::create(builder, call->
getLoc(), stackPtr);
146 for (
auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
147 if (shouldInsertLifetime)
148 LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(),
149 allocaOp.getResult());
169 walker.
addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
171 domainAttr.getContext(), domainAttr.getDescription());
174 walker.
addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
176 cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
177 scopeAttr.getDescription());
181 auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
186 walker.
walk(arrayAttr);
189 llvm::map_to_vector(arrayAttr, [&](
Attribute attr) {
190 return mapping.lookup(attr);
194 for (
Block &block : inlinedBlocks) {
196 if (
auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
197 aliasInterface.setAliasScopes(
198 convertScopeList(aliasInterface.getAliasScopesOrNull()));
199 aliasInterface.setNoAliasScopes(
200 convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
203 if (
auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
205 walker.
walk(noAliasScope.getScopeAttr());
207 noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
208 mapping.lookup(noAliasScope.getScopeAttr())));
224 llvm::append_range(result, lhs);
225 llvm::append_range(result, rhs);
232 static FailureOr<SmallVector<Value>>
238 if (
auto viewOp = val.
getDefiningOp<ViewLikeOpInterface>()) {
239 if (val == viewOp.getViewDest())
240 return WalkContinuation::advanceTo(viewOp.getViewSource());
246 if (controlFlowPredecessors)
250 if (isa<OpResult>(val)) {
251 result.push_back(val);
283 for (
Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
285 auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
288 ssaCopies.insert(ssaCopy);
290 if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
292 noAliasParams.insert(ssaCopy);
298 auto exit = llvm::make_scope_exit([&] {
299 for (LLVM::SSACopyOp ssaCopyOp : ssaCopies) {
300 ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
306 if (noAliasParams.empty())
312 call->
getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
314 for (LLVM::SSACopyOp copyOp : noAliasParams) {
316 pointerScopes[copyOp] = scope;
319 LLVM::NoAliasScopeDeclOp::create(builder, call->
getLoc(), scope);
324 for (
Block &inlinedBlock : inlinedBlocks) {
325 inlinedBlock.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
331 for (
Value pointer : pointerArgs) {
332 FailureOr<SmallVector<Value>> underlyingObjectSet =
334 if (
failed(underlyingObjectSet))
337 std::inserter(basedOnPointers, basedOnPointers.begin()));
340 bool aliasesOtherKnownObject =
false;
350 if (llvm::any_of(basedOnPointers, [&](
Value object) {
354 if (
auto ssaCopy =
object.getDefiningOp<LLVM::SSACopyOp>()) {
357 aliasesOtherKnownObject |= !noAliasParams.contains(ssaCopy);
361 if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
362 object.getDefiningOp())) {
363 aliasesOtherKnownObject =
true;
373 for (LLVM::SSACopyOp noAlias : noAliasParams) {
374 if (basedOnPointers.contains(noAlias))
377 noAliasScopes.push_back(pointerScopes[noAlias]);
380 if (!noAliasScopes.empty())
381 aliasInterface.setNoAliasScopes(
409 if (aliasesOtherKnownObject ||
410 isa<LLVM::CallOp>(aliasInterface.getOperation()))
414 for (LLVM::SSACopyOp noAlias : noAliasParams)
415 if (basedOnPointers.contains(noAlias))
416 aliasScopes.push_back(pointerScopes[noAlias]);
418 if (!aliasScopes.empty())
419 aliasInterface.setAliasScopes(
431 auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
432 if (!callAliasInterface)
435 ArrayAttr aliasScopes = callAliasInterface.getAliasScopesOrNull();
436 ArrayAttr noAliasScopes = callAliasInterface.getNoAliasScopesOrNull();
439 if (!aliasScopes && !noAliasScopes)
444 for (
Block &block : inlinedBlocks) {
445 block.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
448 aliasInterface.getAliasScopesOrNull(), aliasScopes));
452 aliasInterface.getNoAliasScopesOrNull(), noAliasScopes));
469 auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
470 if (!callAccessGroupInterface)
473 auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull();
479 for (
Block &block : inlinedBlocks)
480 for (
auto accessGroupOpInterface :
481 block.getOps<LLVM::AccessGroupOpInterface>())
483 accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
495 auto fusedLoc = dyn_cast_if_present<FusedLoc>(funcLoc);
499 dyn_cast_if_present<LLVM::DISubprogramAttr>(fusedLoc.getMetadata());
513 replacer.
addReplacement([&](LLVM::LoopAnnotationAttr loopAnnotation)
514 -> std::pair<Attribute, WalkResult> {
515 FusedLoc newStartLoc = updateLoc(loopAnnotation.getStartLoc());
516 FusedLoc newEndLoc = updateLoc(loopAnnotation.getEndLoc());
517 if (!newStartLoc && !newEndLoc)
520 loopAnnotation.getContext(), loopAnnotation.getDisableNonforced(),
521 loopAnnotation.getVectorize(), loopAnnotation.getInterleave(),
522 loopAnnotation.getUnroll(), loopAnnotation.getUnrollAndJam(),
523 loopAnnotation.getLicm(), loopAnnotation.getDistribute(),
524 loopAnnotation.getPipeline(), loopAnnotation.getPeeled(),
525 loopAnnotation.getUnswitch(), loopAnnotation.getMustProgress(),
526 loopAnnotation.getIsVectorized(), newStartLoc, newEndLoc,
527 loopAnnotation.getParallelAccesses());
532 for (
Block &block : inlinedBlocks)
541 uint64_t requestedAlignment,
543 uint64_t allocaAlignment = alloca.getAlignment().value_or(1);
544 if (requestedAlignment <= allocaAlignment)
546 return allocaAlignment;
550 if (naturalStackAlignmentBits == 0 ||
553 8 * requestedAlignment <= naturalStackAlignmentBits ||
556 8 * allocaAlignment > naturalStackAlignmentBits) {
557 alloca.setAlignment(requestedAlignment);
558 allocaAlignment = requestedAlignment;
560 return allocaAlignment;
572 if (
auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
575 if (
auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
576 if (
auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
577 definingOp, addressOf.getGlobalNameAttr()))
578 return global.getAlignment().value_or(1);
585 if (
auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
588 auto blockArg = llvm::cast<BlockArgument>(value);
589 if (
Attribute alignAttr = func.getArgAttr(
590 blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
591 return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
601 uint64_t elementTypeSize,
602 uint64_t targetAlignment) {
611 Value one = LLVM::ConstantOp::create(builder, loc, builder.
getI64Type(),
613 allocaOp = LLVM::AllocaOp::create(builder, loc, argument.
getType(),
614 elementType, one, targetAlignment);
618 LLVM::ConstantOp::create(builder, loc, builder.
getI64Type(),
620 LLVM::MemcpyOp::create(builder, loc, allocaOp, argument, copySize,
632 uint64_t requestedAlignment) {
633 auto func = cast<LLVM::LLVMFuncOp>(callable);
634 LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryEffectsAttr();
637 bool isReadOnly = memoryEffects &&
638 memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
639 memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
644 if (requestedAlignment <= minimumAlignment)
646 uint64_t currentAlignment =
648 if (currentAlignment >= requestedAlignment)
651 uint64_t targetAlignment =
std::max(requestedAlignment, minimumAlignment);
653 builder, argument.
getLoc(), argument, elementType,
654 dataLayout.
getTypeSize(elementType), targetAlignment);
661 LLVMInlinerInterface(
Dialect *dialect)
664 disallowedFunctionAttrs({
672 bool wouldBeCloned)
const final {
673 auto callOp = dyn_cast<LLVM::CallOp>(call);
675 LDBG() <<
"Cannot inline: call is not an '"
676 << LLVM::CallOp::getOperationName() <<
"' op";
679 if (callOp.getNoInline()) {
680 LDBG() <<
"Cannot inline: call is marked no_inline";
683 auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
685 LDBG() <<
"Cannot inline: callable is not an '"
686 << LLVM::LLVMFuncOp::getOperationName() <<
"' op";
689 if (funcOp.isNoInline()) {
690 LDBG() <<
"Cannot inline: function is marked no_inline";
693 if (funcOp.isVarArg()) {
694 LDBG() <<
"Cannot inline: callable is variadic";
698 if (
auto attrs = funcOp.getArgAttrs()) {
699 for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
700 if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
701 LDBG() <<
"Cannot inline " << funcOp.getSymName()
702 <<
": inalloca arguments not supported";
708 if (funcOp.getPersonality()) {
709 LDBG() <<
"Cannot inline " << funcOp.getSymName()
710 <<
": unhandled function personality";
713 if (funcOp.getPassthrough()) {
715 if (llvm::any_of(*funcOp.getPassthrough(), [&](
Attribute attr) {
716 auto stringAttr = dyn_cast<StringAttr>(attr);
719 if (disallowedFunctionAttrs.contains(stringAttr)) {
720 LDBG() <<
"Cannot inline " << funcOp.getSymName()
721 <<
": found disallowed function attribute " << stringAttr;
739 return !(isa<LLVM::VaStartOp>(op) || isa<LLVM::BlockTagOp>(op));
744 void handleTerminator(
Operation *op,
Block *newDest)
const final {
746 auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
752 LLVM::BrOp::create(builder, op->getLoc(), returnOp.getOperands(), newDest);
756 bool allowSingleBlockOptimization(
758 if (!inlinedBlocks.empty() &&
759 isa<LLVM::UnreachableOp>(inlinedBlocks.begin()->getTerminator()))
769 auto returnOp = cast<LLVM::ReturnOp>(op);
772 assert(returnOp.getNumOperands() == valuesToRepl.size());
773 for (
auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
774 dst.replaceAllUsesWith(src);
779 DictionaryAttr argumentAttrs)
const final {
780 if (std::optional<NamedAttribute> attr =
781 argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
782 Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
783 uint64_t requestedAlignment = 1;
784 if (std::optional<NamedAttribute> alignAttr =
785 argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
786 requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
806 auto copyOp = LLVM::SSACopyOp::create(builder, call->getLoc(), argument);
807 if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName()))
808 copyOp->setDiscardableAttr(
809 builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
810 builder.getUnitAttr());
814 void processInlinedCallBlocks(
833 dialect->addInterfaces<LLVMInlinerInterface>();
839 dialect->addInterfaces<LLVMInlinerInterface>();
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp)
Check whether the given alloca is an input to a lifetime intrinsic, optionally passing through one or...
static void appendCallOpAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any alias scopes of the call operation to any inlined memory operation.
static void handleLoopAnnotations(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Updates locations inside loop annotations to reflect that they were inlined.
static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs)
Creates a new ArrayAttr by concatenating lhs with rhs.
static void createNewAliasScopesFromNoAliasParameter(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Creates a new AliasScopeAttr for every noalias parameter and attaches it to the appropriate inlined m...
static void handleAccessGroups(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Appends any access groups of the call operation to any inlined memory operation.
static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, uint64_t requestedAlignment)
Handles a function argument marked with the byval attribute by introducing a memcpy or realigning the...
static void handleAliasScopes(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles all interactions with alias scopes during inlining.
static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, uint64_t requestedAlignment, DataLayout const &dataLayout)
If requestedAlignment is higher than the alignment specified on alloca, realigns alloca if this does ...
static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment, DataLayout const &dataLayout)
Tries to find and return the alignment of the pointer value by looking for an alignment attribute on ...
static FailureOr< SmallVector< Value > > getUnderlyingObjectSet(Value pointerValue)
Attempts to return the set of all underlying pointer values that pointerValue is based on.
static void deepCloneAliasScopes(iterator_range< Region::iterator > inlinedBlocks)
Maps all alias scopes in the inlined operations to deep clones of the scopes and domain.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc, Value argument, Type elementType, uint64_t elementTypeSize, uint64_t targetAlignment)
Introduces a new alloca and copies the memory pointed to by argument to the address of the new alloca...
static void handleInlinedAllocas(Operation *call, iterator_range< Region::iterator > inlinedBlocks)
Handles alloca operations in the inlined blocks:
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This is an attribute/type replacer that is naively cached.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getStackAlignment() const
Returns the natural alignment of the stack in bits.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
A trait of region holding operations that define a new scope for automatic allocations,...
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
A class to signal how to proceed with the walk of the backward slice:
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkContinuation skip()
Creates a continuation that advances the walk without adding any predecessor values to the work list.
static WalkContinuation advanceTo(mlir::ValueRange nextValues)
Creates a continuation that adds the user-specified nextValues to the work list and advances the walk...
static WalkContinuation interrupt()
Creates a continuation that interrupts the walk.
static WalkResult advance()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void registerInlinerInterface(DialectRegistry &registry)
Register the LLVMInlinerInterface implementation of DialectInlinerInterface with the LLVM dialect.
void registerInlinerInterface(DialectRegistry &registry)
Register the NVVMInlinerInterface implementation of DialectInlinerInterface with the NVVM dialect.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< SmallVector< Value > > getControlFlowPredecessors(Value value)
Computes a vector of all control predecessors of value.
WalkContinuation walkSlice(mlir::ValueRange rootValues, WalkCallback walkCallback)
Walks the slice starting from the rootValues using a depth-first traversal.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This trait indicates that a terminator operation is "return-like".