28 #include "llvm/ADT/SmallBitVector.h"
29 #include "llvm/Support/MathExtras.h"
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
41 static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42 return !ShapedType::isDynamic(strideOrOffset);
45 static FailureOr<LLVM::LLVMFuncOp>
62 return allocateBufferManuallyAlign(
63 rewriter, loc, sizeBytes, op,
64 getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
75 Value ptr = allocateBufferAutoAlign(
76 rewriter, loc, sizeBytes, op, &defaultLayout,
77 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
81 return std::make_tuple(ptr, ptr);
93 setRequiresNumElements();
105 auto allocaOp = cast<memref::AllocaOp>(op);
107 typeConverter->
convertType(allocaOp.getType().getElementType());
109 *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110 auto elementPtrType =
113 auto allocatedElementPtr =
114 rewriter.
create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115 allocaOp.getAlignment().value_or(0));
117 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
121 struct AllocaScopeOpLowering
126 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
129 Location loc = allocaScopeOp.getLoc();
134 auto *remainingOpsBlock =
136 Block *continueBlock;
137 if (allocaScopeOp.getNumResults() == 0) {
138 continueBlock = remainingOpsBlock;
141 remainingOpsBlock, allocaScopeOp.getResultTypes(),
143 allocaScopeOp.getLoc()));
148 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
149 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
155 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
162 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
164 returnOp, returnOp.getResults(), continueBlock);
168 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
177 struct AssumeAlignmentOpLowering
185 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
187 Value memref = adaptor.getMemref();
188 unsigned alignment = op.getAlignment();
191 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192 Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, {},
200 Value alignmentConst =
220 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
223 FailureOr<LLVM::LLVMFuncOp> freeFunc =
224 getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
225 if (failed(freeFunc))
228 if (
auto unrankedTy =
229 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
231 rewriter.
getContext(), unrankedTy.getMemorySpaceAsInt());
233 rewriter, op.getLoc(),
253 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
255 Type operandType = dimOp.getSource().getType();
256 if (isa<UnrankedMemRefType>(operandType)) {
257 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
258 operandType, dimOp, adaptor.getOperands(), rewriter);
259 if (failed(extractedSize))
261 rewriter.
replaceOp(dimOp, {*extractedSize});
264 if (isa<MemRefType>(operandType)) {
266 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
267 adaptor.getOperands(), rewriter)});
270 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
275 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
280 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
281 auto scalarMemRefType =
283 FailureOr<unsigned> maybeAddressSpace =
284 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
285 if (failed(maybeAddressSpace)) {
286 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
290 unsigned addressSpace = *maybeAddressSpace;
296 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
304 loc, indexPtrTy, elementType, underlyingRankedDesc,
313 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
316 .
create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
320 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
321 if (
auto idx = dimOp.getConstantIndex())
324 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
325 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
330 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
336 MemRefType memRefType = cast<MemRefType>(operandType);
337 Type indexType = getIndexType();
338 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
340 if (i >= 0 && i < memRefType.getRank()) {
341 if (memRefType.isDynamicDim(i)) {
344 return descriptor.size(rewriter, loc, i);
347 int64_t dimSize = memRefType.getDimSize(i);
351 Value index = adaptor.getIndex();
352 int64_t rank = memRefType.getRank();
354 return memrefDescriptor.size(rewriter, loc, index, rank);
361 template <
typename Derived>
365 using Base = LoadStoreOpLowering<Derived>;
367 LogicalResult match(Derived op)
const override {
368 MemRefType type = op.getMemRefType();
369 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
400 struct GenericAtomicRMWOpLowering
401 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
405 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
407 auto loc = atomicOp.getLoc();
408 Type valueType = typeConverter->
convertType(atomicOp.getResult().getType());
420 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
421 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
422 adaptor.getIndices(), rewriter);
424 loc, typeConverter->
convertType(memRefType.getElementType()), dataPtr);
425 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
431 auto loopArgument = loopBlock->getArgument(0);
433 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
443 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
444 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
445 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
446 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
448 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
449 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
453 loopBlock, newLoaded);
458 rewriter.
replaceOp(atomicOp, {newLoaded});
466 convertGlobalMemrefTypeToLLVM(MemRefType type,
474 Type arrayTy = elementType;
476 for (int64_t dim : llvm::reverse(type.getShape()))
482 struct GlobalMemrefOpLowering
487 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
489 MemRefType type = global.getType();
490 if (!isConvertibleAndHasIdentityMaps(type))
493 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
495 LLVM::Linkage linkage =
496 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
499 if (!global.isExternal() && !global.isUninitialized()) {
500 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
501 initialValue = elementsAttr;
505 if (type.getRank() == 0)
506 initialValue = elementsAttr.getSplatValue<
Attribute>();
509 uint64_t alignment = global.getAlignment().value_or(0);
510 FailureOr<unsigned> addressSpace =
511 getTypeConverter()->getMemRefAddressSpace(type);
512 if (failed(addressSpace))
513 return global.emitOpError(
514 "memory space cannot be converted to an integer address space");
516 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
517 initialValue, alignment, *addressSpace);
518 if (!global.isExternal() && global.isUninitialized()) {
519 rewriter.
createBlock(&newGlobal.getInitializerRegion());
521 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
522 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
541 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
542 MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
546 FailureOr<unsigned> maybeAddressSpace =
547 getTypeConverter()->getMemRefAddressSpace(type);
548 if (failed(maybeAddressSpace))
550 unsigned memSpace = *maybeAddressSpace;
552 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
555 rewriter.
create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
559 auto gep = rewriter.
create<LLVM::GEPOp>(
560 loc, ptrTy, arrayTy, addressOf,
566 auto intPtrType = getIntPtrType(memSpace);
567 Value deadBeefConst =
570 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
574 return std::make_tuple(deadBeefPtr, gep);
580 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
584 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
586 auto type = loadOp.getMemRefType();
589 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
590 adaptor.getIndices(), rewriter);
592 loadOp, typeConverter->
convertType(type.getElementType()), dataPtr, 0,
593 false, loadOp.getNontemporal());
600 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
604 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
606 auto type = op.getMemRefType();
608 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
609 adaptor.getIndices(), rewriter);
611 0,
false, op.getNontemporal());
618 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
622 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
624 auto type = prefetchOp.getMemRefType();
625 auto loc = prefetchOp.getLoc();
627 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
628 adaptor.getIndices(), rewriter);
632 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
636 localityHint, isData);
645 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
648 Type operandType = op.getMemref().getType();
649 if (dyn_cast<UnrankedMemRefType>(operandType)) {
651 rewriter.
replaceOp(op, {desc.rank(rewriter, loc)});
654 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
655 Type indexType = getIndexType();
658 rankedMemRefType.getRank())});
668 LogicalResult match(memref::CastOp memRefCastOp)
const override {
669 Type srcType = memRefCastOp.getOperand().getType();
670 Type dstType = memRefCastOp.getType();
677 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
678 return success(typeConverter->
convertType(srcType) ==
682 assert(isa<UnrankedMemRefType>(srcType) ||
683 isa<UnrankedMemRefType>(dstType));
686 return !(isa<UnrankedMemRefType>(srcType) &&
687 isa<UnrankedMemRefType>(dstType))
692 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
694 auto srcType = memRefCastOp.getOperand().getType();
695 auto dstType = memRefCastOp.getType();
696 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
697 auto loc = memRefCastOp.getLoc();
700 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
701 return rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
703 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
708 auto srcMemRefType = cast<MemRefType>(srcType);
709 int64_t rank = srcMemRefType.getRank();
711 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
712 loc, adaptor.getSource(), rewriter);
715 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
721 memRefDesc.setRank(rewriter, loc, rankVal);
723 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
726 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
732 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
735 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, targetStructType, ptr);
736 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
738 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
752 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
754 auto loc = op.getLoc();
755 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
760 Value numElements = rewriter.
create<LLVM::ConstantOp>(
762 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
763 auto size = srcDesc.size(rewriter, loc, pos);
764 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
768 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
771 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
773 Type elementType = typeConverter->
convertType(srcType.getElementType());
775 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
776 Value srcOffset = srcDesc.offset(rewriter, loc);
778 loc, srcBasePtr.
getType(), elementType, srcBasePtr, srcOffset);
780 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
781 Value targetOffset = targetDesc.offset(rewriter, loc);
783 loc, targetBasePtr.
getType(), elementType, targetBasePtr, targetOffset);
784 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
792 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
794 auto loc = op.getLoc();
795 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
796 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
799 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
800 auto rank = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
802 auto *typeConverter = getTypeConverter();
804 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
809 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank, ptr});
814 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
816 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
817 Value unrankedSource =
818 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
819 : adaptor.getSource();
820 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
821 Value unrankedTarget =
822 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
823 : adaptor.getTarget();
826 auto one = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
831 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
832 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
836 auto sourcePtr =
promote(unrankedSource);
837 auto targetPtr =
promote(unrankedTarget);
841 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
843 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
846 rewriter.
create<LLVM::CallOp>(loc, copyFn.value(),
850 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
858 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
860 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
861 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
864 auto memrefType = dyn_cast<mlir::MemRefType>(type);
869 (memrefType.getLayout().isIdentity() ||
870 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
874 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
875 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
877 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
881 struct MemorySpaceCastOpLowering
887 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
891 Type resultType = op.getDest().getType();
892 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
893 auto resultDescType =
894 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
895 Type newPtrType = resultDescType.getBody()[0];
901 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
903 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
905 resultTypeR, descVals);
909 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
912 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
913 FailureOr<unsigned> maybeSourceAddrSpace =
914 getTypeConverter()->getMemRefAddressSpace(sourceType);
915 if (failed(maybeSourceAddrSpace))
917 "non-integer source address space");
918 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
919 FailureOr<unsigned> maybeResultAddrSpace =
920 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
921 if (failed(maybeResultAddrSpace))
923 "non-integer result address space");
924 unsigned resultAddrSpace = *maybeResultAddrSpace;
927 Value rank = sourceDesc.rank(rewriter, loc);
928 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
932 rewriter, loc, typeConverter->convertType(resultTypeU));
933 result.setRank(rewriter, loc, rank);
936 result, resultAddrSpace, sizes);
937 Value resultUnderlyingSize = sizes.front();
938 Value resultUnderlyingDesc = rewriter.
create<LLVM::AllocaOp>(
939 loc, getVoidPtrType(), rewriter.
getI8Type(), resultUnderlyingSize);
940 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
943 auto sourceElemPtrType =
945 auto resultElemPtrType =
948 Value allocatedPtr = sourceDesc.allocatedPtr(
949 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
951 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
952 sourceUnderlyingDesc, sourceElemPtrType);
953 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
954 loc, resultElemPtrType, allocatedPtr);
955 alignedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
956 loc, resultElemPtrType, alignedPtr);
958 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
959 resultElemPtrType, allocatedPtr);
960 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
961 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
964 Value sourceIndexVals =
965 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
966 sourceUnderlyingDesc, sourceElemPtrType);
967 Value resultIndexVals =
968 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
969 resultUnderlyingDesc, resultElemPtrType);
971 int64_t bytesToSkip =
973 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
974 Value bytesToSkipConst = rewriter.
create<LLVM::ConstantOp>(
975 loc, getIndexType(), rewriter.
getIndexAttr(bytesToSkip));
977 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
978 rewriter.
create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
991 static void extractPointersAndOffset(
Location loc,
994 Value originalOperand,
995 Value convertedOperand,
997 Value *offset =
nullptr) {
999 if (isa<MemRefType>(operandType)) {
1001 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1002 *alignedPtr = desc.alignedPtr(rewriter, loc);
1003 if (offset !=
nullptr)
1004 *offset = desc.offset(rewriter, loc);
1010 cast<UnrankedMemRefType>(operandType));
1011 auto elementPtrType =
1017 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1020 rewriter, loc, underlyingDescPtr, elementPtrType);
1022 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1023 if (offset !=
nullptr) {
1025 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1029 struct MemRefReinterpretCastOpLowering
1035 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1037 Type srcType = castOp.getSource().getType();
1040 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1041 adaptor, &descriptor)))
1043 rewriter.
replaceOp(castOp, {descriptor});
1048 LogicalResult convertSourceMemRefToDescriptor(
1050 memref::ReinterpretCastOp castOp,
1051 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1052 MemRefType targetMemRefType =
1053 cast<MemRefType>(castOp.getResult().getType());
1054 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1056 if (!llvmTargetDescriptorTy)
1064 Value allocatedPtr, alignedPtr;
1065 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1066 castOp.getSource(), adaptor.getSource(),
1067 &allocatedPtr, &alignedPtr);
1068 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1069 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1072 if (castOp.isDynamicOffset(0))
1073 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1075 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1078 unsigned dynSizeId = 0;
1079 unsigned dynStrideId = 0;
1080 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1081 if (castOp.isDynamicSize(i))
1082 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1084 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1086 if (castOp.isDynamicStride(i))
1087 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1089 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1096 struct MemRefReshapeOpLowering
1101 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1103 Type srcType = reshapeOp.getSource().getType();
1106 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1107 adaptor, &descriptor)))
1109 rewriter.
replaceOp(reshapeOp, {descriptor});
1116 Type srcType, memref::ReshapeOp reshapeOp,
1117 memref::ReshapeOp::Adaptor adaptor,
1118 Value *descriptor)
const {
1119 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1120 if (shapeMemRefType.hasStaticShape()) {
1121 MemRefType targetMemRefType =
1122 cast<MemRefType>(reshapeOp.getResult().getType());
1123 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1125 if (!llvmTargetDescriptorTy)
1134 Value allocatedPtr, alignedPtr;
1135 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1136 reshapeOp.getSource(), adaptor.getSource(),
1137 &allocatedPtr, &alignedPtr);
1138 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1139 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1144 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1146 reshapeOp,
"failed to get stride and offset exprs");
1148 if (!isStaticStrideOrOffset(offset))
1150 "dynamic offset is unsupported");
1152 desc.setConstantOffset(rewriter, loc, offset);
1154 assert(targetMemRefType.getLayout().isIdentity() &&
1155 "Identity layout map is a precondition of a valid reshape op");
1157 Type indexType = getIndexType();
1158 Value stride =
nullptr;
1159 int64_t targetRank = targetMemRefType.getRank();
1160 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1161 if (!ShapedType::isDynamic(strides[i])) {
1166 }
else if (!stride) {
1176 if (!targetMemRefType.isDynamicDim(i)) {
1178 targetMemRefType.getDimSize(i));
1180 Value shapeOp = reshapeOp.getShape();
1182 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1183 Type indexType = getIndexType();
1184 if (dimSize.
getType() != indexType)
1186 rewriter, loc, indexType, dimSize);
1187 assert(dimSize &&
"Invalid memref element type");
1190 desc.setSize(rewriter, loc, i, dimSize);
1191 desc.setStride(rewriter, loc, i, stride);
1194 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
1204 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1207 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1208 unsigned addressSpace =
1209 *getTypeConverter()->getMemRefAddressSpace(targetType);
1214 rewriter, loc, typeConverter->
convertType(targetType));
1215 targetDesc.setRank(rewriter, loc, resultRank);
1218 targetDesc, addressSpace, sizes);
1219 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1222 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1225 Value allocatedPtr, alignedPtr, offset;
1226 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1227 reshapeOp.getSource(), adaptor.getSource(),
1228 &allocatedPtr, &alignedPtr, &offset);
1231 auto elementPtrType =
1235 elementPtrType, allocatedPtr);
1237 underlyingDescPtr, elementPtrType,
1240 underlyingDescPtr, elementPtrType,
1246 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1248 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1249 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1251 Value resultRankMinusOne =
1252 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1255 Type indexType = getTypeConverter()->getIndexType();
1259 {indexType, indexType}, {loc, loc});
1262 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1266 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
1275 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1284 loc, llvmIndexPtrType,
1285 typeConverter->
convertType(shapeMemRefType.getElementType()),
1286 shapeOperandPtr, indexArg);
1287 Value size = rewriter.
create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1289 targetSizesBase, indexArg, size);
1293 targetStridesBase, indexArg, strideArg);
1294 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1297 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1306 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1307 remainder, std::nullopt);
1312 *descriptor = targetDesc;
1319 template <
typename ReshapeOp>
1320 class ReassociatingReshapeOpConversion
1324 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1327 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1331 "reassociation operations should have been expanded beforehand");
1341 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1344 subViewOp,
"subview operations should have been expanded beforehand");
1360 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1362 auto loc = transposeOp.getLoc();
1366 if (transposeOp.getPermutation().isIdentity())
1367 return rewriter.
replaceOp(transposeOp, {viewMemRef}), success();
1371 typeConverter->
convertType(transposeOp.getIn().getType()));
1375 targetMemRef.setAllocatedPtr(rewriter, loc,
1376 viewMemRef.allocatedPtr(rewriter, loc));
1377 targetMemRef.setAlignedPtr(rewriter, loc,
1378 viewMemRef.alignedPtr(rewriter, loc));
1381 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1387 for (
const auto &en :
1389 int targetPos = en.index();
1390 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1391 targetMemRef.setSize(rewriter, loc, targetPos,
1392 viewMemRef.size(rewriter, loc, sourcePos));
1393 targetMemRef.setStride(rewriter, loc, targetPos,
1394 viewMemRef.stride(rewriter, loc, sourcePos));
1397 rewriter.
replaceOp(transposeOp, {targetMemRef});
1414 Type indexType)
const {
1415 assert(idx < shape.size());
1416 if (!ShapedType::isDynamic(shape[idx]))
1420 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1421 return dynamicSizes[nDynamic];
1430 Value runningStride,
unsigned idx,
Type indexType)
const {
1431 assert(idx < strides.size());
1432 if (!ShapedType::isDynamic(strides[idx]))
1435 return runningStride
1436 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1438 assert(!runningStride);
1443 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1445 auto loc = viewOp.getLoc();
1447 auto viewMemRefType = viewOp.getType();
1448 auto targetElementTy =
1449 typeConverter->
convertType(viewMemRefType.getElementType());
1450 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1451 if (!targetDescTy || !targetElementTy ||
1454 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1459 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1460 if (failed(successStrides))
1461 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1462 assert(offset == 0 &&
"expected offset to be 0");
1466 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1467 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1475 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1476 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1477 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1480 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1481 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1483 typeConverter->
convertType(srcMemRefType.getElementType()), alignedPtr,
1484 adaptor.getByteShift());
1486 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1488 Type indexType = getIndexType();
1492 targetMemRef.setOffset(
1497 if (viewMemRefType.getRank() == 0)
1498 return rewriter.
replaceOp(viewOp, {targetMemRef}), success();
1501 Value stride =
nullptr, nextSize =
nullptr;
1502 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1504 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1505 adaptor.getSizes(), i, indexType);
1506 targetMemRef.setSize(rewriter, loc, i, size);
1509 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1510 targetMemRef.setStride(rewriter, loc, i, stride);
1514 rewriter.
replaceOp(viewOp, {targetMemRef});
1525 static std::optional<LLVM::AtomicBinOp>
1526 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1527 switch (atomicOp.getKind()) {
1528 case arith::AtomicRMWKind::addf:
1529 return LLVM::AtomicBinOp::fadd;
1530 case arith::AtomicRMWKind::addi:
1531 return LLVM::AtomicBinOp::add;
1532 case arith::AtomicRMWKind::assign:
1533 return LLVM::AtomicBinOp::xchg;
1534 case arith::AtomicRMWKind::maximumf:
1535 return LLVM::AtomicBinOp::fmax;
1536 case arith::AtomicRMWKind::maxs:
1538 case arith::AtomicRMWKind::maxu:
1539 return LLVM::AtomicBinOp::umax;
1540 case arith::AtomicRMWKind::minimumf:
1541 return LLVM::AtomicBinOp::fmin;
1542 case arith::AtomicRMWKind::mins:
1544 case arith::AtomicRMWKind::minu:
1545 return LLVM::AtomicBinOp::umin;
1546 case arith::AtomicRMWKind::ori:
1547 return LLVM::AtomicBinOp::_or;
1548 case arith::AtomicRMWKind::andi:
1549 return LLVM::AtomicBinOp::_and;
1551 return std::nullopt;
1553 llvm_unreachable(
"Invalid AtomicRMWKind");
1556 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1560 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1562 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1565 auto memRefType = atomicOp.getMemRefType();
1568 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1571 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1572 adaptor.getIndices(), rewriter);
1574 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1575 LLVM::AtomicOrdering::acq_rel);
1581 class ConvertExtractAlignedPointerAsIndex
1588 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1596 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1602 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1605 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1610 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1617 class ExtractStridedMetadataOpLowering
1624 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1633 Location loc = extractStridedMetadataOp.getLoc();
1634 Value source = extractStridedMetadataOp.getSource();
1636 auto sourceMemRefType = cast<MemRefType>(source.
getType());
1637 int64_t rank = sourceMemRefType.getRank();
1639 results.reserve(2 + rank * 2);
1642 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1643 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1645 rewriter, loc, *getTypeConverter(),
1646 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1647 baseBuffer, alignedBuffer);
1648 results.push_back((
Value)dstMemRef);
1651 results.push_back(sourceMemRef.offset(rewriter, loc));
1654 for (
unsigned i = 0; i < rank; ++i)
1655 results.push_back(sourceMemRef.size(rewriter, loc, i));
1657 for (
unsigned i = 0; i < rank; ++i)
1658 results.push_back(sourceMemRef.stride(rewriter, loc, i));
1660 rewriter.
replaceOp(extractStridedMetadataOp, results);
1672 AllocaScopeOpLowering,
1673 AtomicRMWOpLowering,
1674 AssumeAlignmentOpLowering,
1675 ConvertExtractAlignedPointerAsIndex,
1677 ExtractStridedMetadataOpLowering,
1678 GenericAtomicRMWOpLowering,
1679 GlobalMemrefOpLowering,
1680 GetGlobalMemrefOpLowering,
1682 MemRefCastOpLowering,
1683 MemRefCopyOpLowering,
1684 MemorySpaceCastOpLowering,
1685 MemRefReinterpretCastOpLowering,
1686 MemRefReshapeOpLowering,
1689 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1690 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1694 ViewOpLowering>(converter);
1698 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1700 patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1704 struct FinalizeMemRefToLLVMConversionPass
1705 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
1706 FinalizeMemRefToLLVMConversionPass> {
1707 using FinalizeMemRefToLLVMConversionPassBase::
1708 FinalizeMemRefToLLVMConversionPassBase;
1710 void runOnOperation()
override {
1712 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1714 dataLayoutAnalysis.getAtOrAbove(op));
1719 options.useGenericFunctions = useGenericFunctions;
1722 options.overrideIndexBitwidth(indexBitwidth);
1725 &dataLayoutAnalysis);
1729 target.addLegalOp<func::FuncOp>();
1731 signalPassFailure();
1738 void loadDependentDialects(
MLIRContext *context)
const final {
1739 context->loadDialect<LLVM::LLVMDialect>();
1744 void populateConvertToLLVMConversionPatterns(
1755 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(Operation *moduleOp)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(Operation *moduleOp)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Lowering for AllocOp and AllocaOp.