28 #include "llvm/ADT/SmallBitVector.h"
29 #include "llvm/Support/MathExtras.h"
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
41 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42 return !ShapedType::isDynamic(strideOrOffset);
62 return allocateBufferManuallyAlign(
63 rewriter, loc, sizeBytes, op,
64 getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
75 Value ptr = allocateBufferAutoAlign(
76 rewriter, loc, sizeBytes, op, &defaultLayout,
77 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
81 return std::make_tuple(ptr, ptr);
93 setRequiresNumElements();
105 auto allocaOp = cast<memref::AllocaOp>(op);
107 typeConverter->
convertType(allocaOp.getType().getElementType());
109 *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110 auto elementPtrType =
113 auto allocatedElementPtr =
114 rewriter.
create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115 allocaOp.getAlignment().value_or(0));
117 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
121 struct AllocaScopeOpLowering
126 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
129 Location loc = allocaScopeOp.getLoc();
134 auto *remainingOpsBlock =
136 Block *continueBlock;
137 if (allocaScopeOp.getNumResults() == 0) {
138 continueBlock = remainingOpsBlock;
141 remainingOpsBlock, allocaScopeOp.getResultTypes(),
143 allocaScopeOp.getLoc()));
148 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
149 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
155 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
162 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
164 returnOp, returnOp.getResults(), continueBlock);
168 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
177 struct AssumeAlignmentOpLowering
185 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
187 Value memref = adaptor.getMemref();
188 unsigned alignment = op.getAlignment();
191 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192 Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, {},
202 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
206 Value ptrValue = rewriter.
create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
207 rewriter.
create<LLVM::AssumeOp>(
208 loc, rewriter.
create<LLVM::ICmpOp>(
209 loc, LLVM::ICmpPredicate::eq,
210 rewriter.
create<LLVM::AndOp>(loc, ptrValue, mask), zero));
227 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
230 LLVM::LLVMFuncOp freeFunc =
231 getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
233 if (
auto unrankedTy =
234 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
236 rewriter.
getContext(), unrankedTy.getMemorySpaceAsInt());
238 rewriter, op.getLoc(),
257 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
259 Type operandType = dimOp.getSource().getType();
260 if (isa<UnrankedMemRefType>(operandType)) {
261 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
262 operandType, dimOp, adaptor.getOperands(), rewriter);
263 if (failed(extractedSize))
265 rewriter.
replaceOp(dimOp, {*extractedSize});
268 if (isa<MemRefType>(operandType)) {
270 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
271 adaptor.getOperands(), rewriter)});
274 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
279 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
284 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
285 auto scalarMemRefType =
287 FailureOr<unsigned> maybeAddressSpace =
288 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
289 if (failed(maybeAddressSpace)) {
290 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
294 unsigned addressSpace = *maybeAddressSpace;
300 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
308 loc, indexPtrTy, elementType, underlyingRankedDesc,
317 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
320 .
create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
324 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
325 if (
auto idx = dimOp.getConstantIndex())
328 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
329 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
334 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
340 MemRefType memRefType = cast<MemRefType>(operandType);
341 Type indexType = getIndexType();
342 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
344 if (i >= 0 && i < memRefType.getRank()) {
345 if (memRefType.isDynamicDim(i)) {
348 return descriptor.size(rewriter, loc, i);
351 int64_t dimSize = memRefType.getDimSize(i);
355 Value index = adaptor.getIndex();
356 int64_t rank = memRefType.getRank();
358 return memrefDescriptor.size(rewriter, loc, index, rank);
365 template <
typename Derived>
369 using Base = LoadStoreOpLowering<Derived>;
371 LogicalResult match(Derived op)
const override {
372 MemRefType type = op.getMemRefType();
373 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
404 struct GenericAtomicRMWOpLowering
405 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
409 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
411 auto loc = atomicOp.getLoc();
412 Type valueType = typeConverter->
convertType(atomicOp.getResult().getType());
424 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
425 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
426 adaptor.getIndices(), rewriter);
428 loc, typeConverter->
convertType(memRefType.getElementType()), dataPtr);
429 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
435 auto loopArgument = loopBlock->getArgument(0);
437 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
447 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
448 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
449 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
450 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
452 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
453 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
457 loopBlock, newLoaded);
462 rewriter.
replaceOp(atomicOp, {newLoaded});
470 convertGlobalMemrefTypeToLLVM(MemRefType type,
478 Type arrayTy = elementType;
480 for (int64_t dim : llvm::reverse(type.getShape()))
486 struct GlobalMemrefOpLowering
491 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
493 MemRefType type = global.getType();
494 if (!isConvertibleAndHasIdentityMaps(type))
497 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
499 LLVM::Linkage linkage =
500 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
503 if (!global.isExternal() && !global.isUninitialized()) {
504 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
505 initialValue = elementsAttr;
509 if (type.getRank() == 0)
510 initialValue = elementsAttr.getSplatValue<
Attribute>();
513 uint64_t alignment = global.getAlignment().value_or(0);
514 FailureOr<unsigned> addressSpace =
515 getTypeConverter()->getMemRefAddressSpace(type);
516 if (failed(addressSpace))
517 return global.emitOpError(
518 "memory space cannot be converted to an integer address space");
520 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
521 initialValue, alignment, *addressSpace);
522 if (!global.isExternal() && global.isUninitialized()) {
523 rewriter.
createBlock(&newGlobal.getInitializerRegion());
525 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
526 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
545 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
546 MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
550 FailureOr<unsigned> maybeAddressSpace =
551 getTypeConverter()->getMemRefAddressSpace(type);
552 if (failed(maybeAddressSpace))
554 unsigned memSpace = *maybeAddressSpace;
556 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
559 rewriter.
create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
563 auto gep = rewriter.
create<LLVM::GEPOp>(
564 loc, ptrTy, arrayTy, addressOf,
570 auto intPtrType = getIntPtrType(memSpace);
571 Value deadBeefConst =
574 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
578 return std::make_tuple(deadBeefPtr, gep);
584 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
588 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
590 auto type = loadOp.getMemRefType();
593 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
594 adaptor.getIndices(), rewriter);
596 loadOp, typeConverter->
convertType(type.getElementType()), dataPtr, 0,
597 false, loadOp.getNontemporal());
604 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
608 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
610 auto type = op.getMemRefType();
612 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
613 adaptor.getIndices(), rewriter);
615 0,
false, op.getNontemporal());
622 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
626 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
628 auto type = prefetchOp.getMemRefType();
629 auto loc = prefetchOp.getLoc();
631 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
632 adaptor.getIndices(), rewriter);
636 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
640 localityHint, isData);
649 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
652 Type operandType = op.getMemref().getType();
653 if (dyn_cast<UnrankedMemRefType>(operandType)) {
655 rewriter.
replaceOp(op, {desc.rank(rewriter, loc)});
658 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
659 Type indexType = getIndexType();
662 rankedMemRefType.getRank())});
672 LogicalResult match(memref::CastOp memRefCastOp)
const override {
673 Type srcType = memRefCastOp.getOperand().getType();
674 Type dstType = memRefCastOp.getType();
681 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
682 return success(typeConverter->
convertType(srcType) ==
686 assert(isa<UnrankedMemRefType>(srcType) ||
687 isa<UnrankedMemRefType>(dstType));
690 return !(isa<UnrankedMemRefType>(srcType) &&
691 isa<UnrankedMemRefType>(dstType))
696 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
698 auto srcType = memRefCastOp.getOperand().getType();
699 auto dstType = memRefCastOp.getType();
700 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
701 auto loc = memRefCastOp.getLoc();
704 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
705 return rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
707 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
712 auto srcMemRefType = cast<MemRefType>(srcType);
713 int64_t rank = srcMemRefType.getRank();
715 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
716 loc, adaptor.getSource(), rewriter);
719 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
725 memRefDesc.setRank(rewriter, loc, rankVal);
727 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
730 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
736 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
739 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, targetStructType, ptr);
740 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
742 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
756 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
758 auto loc = op.getLoc();
759 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
764 Value numElements = rewriter.
create<LLVM::ConstantOp>(
766 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
767 auto size = srcDesc.size(rewriter, loc, pos);
768 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
772 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
775 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
777 Type elementType = typeConverter->
convertType(srcType.getElementType());
779 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
780 Value srcOffset = srcDesc.offset(rewriter, loc);
782 loc, srcBasePtr.
getType(), elementType, srcBasePtr, srcOffset);
784 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
785 Value targetOffset = targetDesc.offset(rewriter, loc);
787 loc, targetBasePtr.
getType(), elementType, targetBasePtr, targetOffset);
788 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
796 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
798 auto loc = op.getLoc();
799 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
800 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
803 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
804 auto rank = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
806 auto *typeConverter = getTypeConverter();
808 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
813 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank, ptr});
818 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
820 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
821 Value unrankedSource =
822 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
823 : adaptor.getSource();
824 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
825 Value unrankedTarget =
826 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
827 : adaptor.getTarget();
830 auto one = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
835 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
836 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
840 auto sourcePtr =
promote(unrankedSource);
841 auto targetPtr =
promote(unrankedTarget);
845 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
847 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
848 rewriter.
create<LLVM::CallOp>(loc, copyFn,
852 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
860 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
862 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
863 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
866 auto memrefType = dyn_cast<mlir::MemRefType>(type);
871 (memrefType.getLayout().isIdentity() ||
872 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
876 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
877 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
879 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
883 struct MemorySpaceCastOpLowering
889 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
893 Type resultType = op.getDest().getType();
894 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
895 auto resultDescType =
896 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
897 Type newPtrType = resultDescType.getBody()[0];
903 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
905 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
907 resultTypeR, descVals);
911 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
914 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
915 FailureOr<unsigned> maybeSourceAddrSpace =
916 getTypeConverter()->getMemRefAddressSpace(sourceType);
917 if (failed(maybeSourceAddrSpace))
919 "non-integer source address space");
920 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
921 FailureOr<unsigned> maybeResultAddrSpace =
922 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
923 if (failed(maybeResultAddrSpace))
925 "non-integer result address space");
926 unsigned resultAddrSpace = *maybeResultAddrSpace;
929 Value rank = sourceDesc.rank(rewriter, loc);
930 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
934 rewriter, loc, typeConverter->convertType(resultTypeU));
935 result.setRank(rewriter, loc, rank);
938 result, resultAddrSpace, sizes);
939 Value resultUnderlyingSize = sizes.front();
940 Value resultUnderlyingDesc = rewriter.
create<LLVM::AllocaOp>(
941 loc, getVoidPtrType(), rewriter.
getI8Type(), resultUnderlyingSize);
942 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
945 auto sourceElemPtrType =
947 auto resultElemPtrType =
950 Value allocatedPtr = sourceDesc.allocatedPtr(
951 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
953 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
954 sourceUnderlyingDesc, sourceElemPtrType);
955 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
956 loc, resultElemPtrType, allocatedPtr);
957 alignedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
958 loc, resultElemPtrType, alignedPtr);
960 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
961 resultElemPtrType, allocatedPtr);
962 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
963 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
966 Value sourceIndexVals =
967 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
968 sourceUnderlyingDesc, sourceElemPtrType);
969 Value resultIndexVals =
970 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
971 resultUnderlyingDesc, resultElemPtrType);
973 int64_t bytesToSkip =
975 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
976 Value bytesToSkipConst = rewriter.
create<LLVM::ConstantOp>(
977 loc, getIndexType(), rewriter.
getIndexAttr(bytesToSkip));
979 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
980 rewriter.
create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
993 static void extractPointersAndOffset(
Location loc,
996 Value originalOperand,
997 Value convertedOperand,
999 Value *offset =
nullptr) {
1001 if (isa<MemRefType>(operandType)) {
1003 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1004 *alignedPtr = desc.alignedPtr(rewriter, loc);
1005 if (offset !=
nullptr)
1006 *offset = desc.offset(rewriter, loc);
1012 cast<UnrankedMemRefType>(operandType));
1013 auto elementPtrType =
1019 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1022 rewriter, loc, underlyingDescPtr, elementPtrType);
1024 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1025 if (offset !=
nullptr) {
1027 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1031 struct MemRefReinterpretCastOpLowering
1037 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1039 Type srcType = castOp.getSource().getType();
1042 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1043 adaptor, &descriptor)))
1045 rewriter.
replaceOp(castOp, {descriptor});
1050 LogicalResult convertSourceMemRefToDescriptor(
1052 memref::ReinterpretCastOp castOp,
1053 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1054 MemRefType targetMemRefType =
1055 cast<MemRefType>(castOp.getResult().getType());
1056 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1058 if (!llvmTargetDescriptorTy)
1066 Value allocatedPtr, alignedPtr;
1067 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1068 castOp.getSource(), adaptor.getSource(),
1069 &allocatedPtr, &alignedPtr);
1070 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1071 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1074 if (castOp.isDynamicOffset(0))
1075 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1077 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1080 unsigned dynSizeId = 0;
1081 unsigned dynStrideId = 0;
1082 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1083 if (castOp.isDynamicSize(i))
1084 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1086 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1088 if (castOp.isDynamicStride(i))
1089 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1091 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1098 struct MemRefReshapeOpLowering
1103 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1105 Type srcType = reshapeOp.getSource().getType();
1108 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1109 adaptor, &descriptor)))
1111 rewriter.
replaceOp(reshapeOp, {descriptor});
1118 Type srcType, memref::ReshapeOp reshapeOp,
1119 memref::ReshapeOp::Adaptor adaptor,
1120 Value *descriptor)
const {
1121 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1122 if (shapeMemRefType.hasStaticShape()) {
1123 MemRefType targetMemRefType =
1124 cast<MemRefType>(reshapeOp.getResult().getType());
1125 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1127 if (!llvmTargetDescriptorTy)
1136 Value allocatedPtr, alignedPtr;
1137 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1138 reshapeOp.getSource(), adaptor.getSource(),
1139 &allocatedPtr, &alignedPtr);
1140 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1141 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1148 reshapeOp,
"failed to get stride and offset exprs");
1150 if (!isStaticStrideOrOffset(offset))
1152 "dynamic offset is unsupported");
1154 desc.setConstantOffset(rewriter, loc, offset);
1156 assert(targetMemRefType.getLayout().isIdentity() &&
1157 "Identity layout map is a precondition of a valid reshape op");
1159 Type indexType = getIndexType();
1160 Value stride =
nullptr;
1161 int64_t targetRank = targetMemRefType.getRank();
1162 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1163 if (!ShapedType::isDynamic(strides[i])) {
1168 }
else if (!stride) {
1178 if (!targetMemRefType.isDynamicDim(i)) {
1180 targetMemRefType.getDimSize(i));
1182 Value shapeOp = reshapeOp.getShape();
1184 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1185 Type indexType = getIndexType();
1186 if (dimSize.
getType() != indexType)
1188 rewriter, loc, indexType, dimSize);
1189 assert(dimSize &&
"Invalid memref element type");
1192 desc.setSize(rewriter, loc, i, dimSize);
1193 desc.setStride(rewriter, loc, i, stride);
1196 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
1206 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1209 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1210 unsigned addressSpace =
1211 *getTypeConverter()->getMemRefAddressSpace(targetType);
1216 rewriter, loc, typeConverter->
convertType(targetType));
1217 targetDesc.setRank(rewriter, loc, resultRank);
1220 targetDesc, addressSpace, sizes);
1221 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1224 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1227 Value allocatedPtr, alignedPtr, offset;
1228 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1229 reshapeOp.getSource(), adaptor.getSource(),
1230 &allocatedPtr, &alignedPtr, &offset);
1233 auto elementPtrType =
1237 elementPtrType, allocatedPtr);
1239 underlyingDescPtr, elementPtrType,
1242 underlyingDescPtr, elementPtrType,
1248 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1250 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1251 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1253 Value resultRankMinusOne =
1254 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1257 Type indexType = getTypeConverter()->getIndexType();
1261 {indexType, indexType}, {loc, loc});
1264 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1268 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
1277 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1286 loc, llvmIndexPtrType,
1287 typeConverter->
convertType(shapeMemRefType.getElementType()),
1288 shapeOperandPtr, indexArg);
1289 Value size = rewriter.
create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1291 targetSizesBase, indexArg, size);
1295 targetStridesBase, indexArg, strideArg);
1296 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1299 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1308 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1309 remainder, std::nullopt);
1314 *descriptor = targetDesc;
1321 template <
typename ReshapeOp>
1322 class ReassociatingReshapeOpConversion
1326 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1329 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1333 "reassociation operations should have been expanded beforehand");
1343 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1346 subViewOp,
"subview operations should have been expanded beforehand");
1362 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1364 auto loc = transposeOp.getLoc();
1368 if (transposeOp.getPermutation().isIdentity())
1369 return rewriter.
replaceOp(transposeOp, {viewMemRef}), success();
1373 typeConverter->
convertType(transposeOp.getIn().getType()));
1377 targetMemRef.setAllocatedPtr(rewriter, loc,
1378 viewMemRef.allocatedPtr(rewriter, loc));
1379 targetMemRef.setAlignedPtr(rewriter, loc,
1380 viewMemRef.alignedPtr(rewriter, loc));
1383 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1389 for (
const auto &en :
1391 int targetPos = en.index();
1392 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1393 targetMemRef.setSize(rewriter, loc, targetPos,
1394 viewMemRef.size(rewriter, loc, sourcePos));
1395 targetMemRef.setStride(rewriter, loc, targetPos,
1396 viewMemRef.stride(rewriter, loc, sourcePos));
1399 rewriter.
replaceOp(transposeOp, {targetMemRef});
1416 Type indexType)
const {
1417 assert(idx < shape.size());
1418 if (!ShapedType::isDynamic(shape[idx]))
1422 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1423 return dynamicSizes[nDynamic];
1432 Value runningStride,
unsigned idx,
Type indexType)
const {
1433 assert(idx < strides.size());
1434 if (!ShapedType::isDynamic(strides[idx]))
1437 return runningStride
1438 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1440 assert(!runningStride);
1445 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1447 auto loc = viewOp.getLoc();
1449 auto viewMemRefType = viewOp.getType();
1450 auto targetElementTy =
1451 typeConverter->
convertType(viewMemRefType.getElementType());
1452 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1453 if (!targetDescTy || !targetElementTy ||
1456 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1462 if (failed(successStrides))
1463 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1464 assert(offset == 0 &&
"expected offset to be 0");
1468 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1469 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1477 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1478 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1479 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1482 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1483 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1485 typeConverter->
convertType(srcMemRefType.getElementType()), alignedPtr,
1486 adaptor.getByteShift());
1488 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1490 Type indexType = getIndexType();
1494 targetMemRef.setOffset(
1499 if (viewMemRefType.getRank() == 0)
1500 return rewriter.
replaceOp(viewOp, {targetMemRef}), success();
1503 Value stride =
nullptr, nextSize =
nullptr;
1504 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1506 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1507 adaptor.getSizes(), i, indexType);
1508 targetMemRef.setSize(rewriter, loc, i, size);
1511 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1512 targetMemRef.setStride(rewriter, loc, i, stride);
1516 rewriter.
replaceOp(viewOp, {targetMemRef});
1527 static std::optional<LLVM::AtomicBinOp>
1528 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1529 switch (atomicOp.getKind()) {
1530 case arith::AtomicRMWKind::addf:
1531 return LLVM::AtomicBinOp::fadd;
1532 case arith::AtomicRMWKind::addi:
1533 return LLVM::AtomicBinOp::add;
1534 case arith::AtomicRMWKind::assign:
1535 return LLVM::AtomicBinOp::xchg;
1536 case arith::AtomicRMWKind::maximumf:
1537 return LLVM::AtomicBinOp::fmax;
1538 case arith::AtomicRMWKind::maxs:
1540 case arith::AtomicRMWKind::maxu:
1541 return LLVM::AtomicBinOp::umax;
1542 case arith::AtomicRMWKind::minimumf:
1543 return LLVM::AtomicBinOp::fmin;
1544 case arith::AtomicRMWKind::mins:
1546 case arith::AtomicRMWKind::minu:
1547 return LLVM::AtomicBinOp::umin;
1548 case arith::AtomicRMWKind::ori:
1549 return LLVM::AtomicBinOp::_or;
1550 case arith::AtomicRMWKind::andi:
1551 return LLVM::AtomicBinOp::_and;
1553 return std::nullopt;
1555 llvm_unreachable(
"Invalid AtomicRMWKind");
1558 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1562 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1564 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1567 auto memRefType = atomicOp.getMemRefType();
1573 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1574 adaptor.getIndices(), rewriter);
1576 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1577 LLVM::AtomicOrdering::acq_rel);
1583 class ConvertExtractAlignedPointerAsIndex
1590 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1598 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1604 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1607 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1612 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1619 class ExtractStridedMetadataOpLowering
1626 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1635 Location loc = extractStridedMetadataOp.getLoc();
1636 Value source = extractStridedMetadataOp.getSource();
1638 auto sourceMemRefType = cast<MemRefType>(source.
getType());
1639 int64_t rank = sourceMemRefType.getRank();
1641 results.reserve(2 + rank * 2);
1644 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1645 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1647 rewriter, loc, *getTypeConverter(),
1648 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1649 baseBuffer, alignedBuffer);
1650 results.push_back((
Value)dstMemRef);
1653 results.push_back(sourceMemRef.offset(rewriter, loc));
1656 for (
unsigned i = 0; i < rank; ++i)
1657 results.push_back(sourceMemRef.size(rewriter, loc, i));
1659 for (
unsigned i = 0; i < rank; ++i)
1660 results.push_back(sourceMemRef.stride(rewriter, loc, i));
1662 rewriter.
replaceOp(extractStridedMetadataOp, results);
1674 AllocaScopeOpLowering,
1675 AtomicRMWOpLowering,
1676 AssumeAlignmentOpLowering,
1677 ConvertExtractAlignedPointerAsIndex,
1679 ExtractStridedMetadataOpLowering,
1680 GenericAtomicRMWOpLowering,
1681 GlobalMemrefOpLowering,
1682 GetGlobalMemrefOpLowering,
1684 MemRefCastOpLowering,
1685 MemRefCopyOpLowering,
1686 MemorySpaceCastOpLowering,
1687 MemRefReinterpretCastOpLowering,
1688 MemRefReshapeOpLowering,
1691 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1692 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1696 ViewOpLowering>(converter);
1700 patterns.
add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1702 patterns.
add<AllocOpLowering, DeallocOpLowering>(converter);
1706 struct FinalizeMemRefToLLVMConversionPass
1707 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
1708 FinalizeMemRefToLLVMConversionPass> {
1709 using FinalizeMemRefToLLVMConversionPassBase::
1710 FinalizeMemRefToLLVMConversionPassBase;
1712 void runOnOperation()
override {
1714 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1716 dataLayoutAnalysis.getAtOrAbove(op));
1721 options.useGenericFunctions = useGenericFunctions;
1724 options.overrideIndexBitwidth(indexBitwidth);
1727 &dataLayoutAnalysis);
1731 target.addLegalOp<func::FuncOp>();
1733 signalPassFailure();
1740 void loadDependentDialects(
MLIRContext *context)
const final {
1741 context->loadDialect<LLVM::LLVMDialect>();
1746 void populateConvertToLLVMConversionPatterns(
1757 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Lowering for AllocOp and AllocaOp.