28 #include "llvm/ADT/SmallBitVector.h"
29 #include "llvm/Support/MathExtras.h"
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
41 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42 return !ShapedType::isDynamic(strideOrOffset);
62 return allocateBufferManuallyAlign(
63 rewriter, loc, sizeBytes, op,
64 getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
75 Value ptr = allocateBufferAutoAlign(
76 rewriter, loc, sizeBytes, op, &defaultLayout,
77 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
81 return std::make_tuple(ptr, ptr);
93 setRequiresNumElements();
105 auto allocaOp = cast<memref::AllocaOp>(op);
107 typeConverter->
convertType(allocaOp.getType().getElementType());
109 *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110 auto elementPtrType =
113 auto allocatedElementPtr =
114 rewriter.
create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115 allocaOp.getAlignment().value_or(0));
117 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
121 struct AllocaScopeOpLowering
126 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
129 Location loc = allocaScopeOp.getLoc();
134 auto *remainingOpsBlock =
136 Block *continueBlock;
137 if (allocaScopeOp.getNumResults() == 0) {
138 continueBlock = remainingOpsBlock;
141 remainingOpsBlock, allocaScopeOp.getResultTypes(),
143 allocaScopeOp.getLoc()));
148 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
149 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
155 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
162 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
164 returnOp, returnOp.getResults(), continueBlock);
168 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
177 struct AssumeAlignmentOpLowering
185 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
187 Value memref = adaptor.getMemref();
188 unsigned alignment = op.getAlignment();
191 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192 Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, {},
200 Value alignmentConst =
220 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
223 LLVM::LLVMFuncOp freeFunc =
224 getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
226 if (
auto unrankedTy =
227 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
229 rewriter.
getContext(), unrankedTy.getMemorySpaceAsInt());
231 rewriter, op.getLoc(),
250 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
252 Type operandType = dimOp.getSource().getType();
253 if (isa<UnrankedMemRefType>(operandType)) {
254 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
255 operandType, dimOp, adaptor.getOperands(), rewriter);
256 if (failed(extractedSize))
258 rewriter.
replaceOp(dimOp, {*extractedSize});
261 if (isa<MemRefType>(operandType)) {
263 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
264 adaptor.getOperands(), rewriter)});
267 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
272 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
277 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
278 auto scalarMemRefType =
280 FailureOr<unsigned> maybeAddressSpace =
281 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
282 if (failed(maybeAddressSpace)) {
283 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
287 unsigned addressSpace = *maybeAddressSpace;
293 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
301 loc, indexPtrTy, elementType, underlyingRankedDesc,
310 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
313 .
create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
317 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
318 if (
auto idx = dimOp.getConstantIndex())
321 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
322 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
327 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
333 MemRefType memRefType = cast<MemRefType>(operandType);
334 Type indexType = getIndexType();
335 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
337 if (i >= 0 && i < memRefType.getRank()) {
338 if (memRefType.isDynamicDim(i)) {
341 return descriptor.size(rewriter, loc, i);
344 int64_t dimSize = memRefType.getDimSize(i);
348 Value index = adaptor.getIndex();
349 int64_t rank = memRefType.getRank();
351 return memrefDescriptor.size(rewriter, loc, index, rank);
358 template <
typename Derived>
362 using Base = LoadStoreOpLowering<Derived>;
364 LogicalResult match(Derived op)
const override {
365 MemRefType type = op.getMemRefType();
366 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
397 struct GenericAtomicRMWOpLowering
398 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
402 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
404 auto loc = atomicOp.getLoc();
405 Type valueType = typeConverter->
convertType(atomicOp.getResult().getType());
417 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
418 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
419 adaptor.getIndices(), rewriter);
421 loc, typeConverter->
convertType(memRefType.getElementType()), dataPtr);
422 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
428 auto loopArgument = loopBlock->getArgument(0);
430 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
440 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
441 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
442 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
443 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
445 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
446 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
450 loopBlock, newLoaded);
455 rewriter.
replaceOp(atomicOp, {newLoaded});
463 convertGlobalMemrefTypeToLLVM(MemRefType type,
471 Type arrayTy = elementType;
473 for (int64_t dim : llvm::reverse(type.getShape()))
479 struct GlobalMemrefOpLowering
484 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
486 MemRefType type = global.getType();
487 if (!isConvertibleAndHasIdentityMaps(type))
490 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
492 LLVM::Linkage linkage =
493 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
496 if (!global.isExternal() && !global.isUninitialized()) {
497 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
498 initialValue = elementsAttr;
502 if (type.getRank() == 0)
503 initialValue = elementsAttr.getSplatValue<
Attribute>();
506 uint64_t alignment = global.getAlignment().value_or(0);
507 FailureOr<unsigned> addressSpace =
508 getTypeConverter()->getMemRefAddressSpace(type);
509 if (failed(addressSpace))
510 return global.emitOpError(
511 "memory space cannot be converted to an integer address space");
513 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
514 initialValue, alignment, *addressSpace);
515 if (!global.isExternal() && global.isUninitialized()) {
516 rewriter.
createBlock(&newGlobal.getInitializerRegion());
518 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
519 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
538 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
539 MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
543 FailureOr<unsigned> maybeAddressSpace =
544 getTypeConverter()->getMemRefAddressSpace(type);
545 if (failed(maybeAddressSpace))
547 unsigned memSpace = *maybeAddressSpace;
549 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
552 rewriter.
create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
556 auto gep = rewriter.
create<LLVM::GEPOp>(
557 loc, ptrTy, arrayTy, addressOf,
563 auto intPtrType = getIntPtrType(memSpace);
564 Value deadBeefConst =
567 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
571 return std::make_tuple(deadBeefPtr, gep);
577 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
581 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
583 auto type = loadOp.getMemRefType();
586 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
587 adaptor.getIndices(), rewriter);
589 loadOp, typeConverter->
convertType(type.getElementType()), dataPtr, 0,
590 false, loadOp.getNontemporal());
597 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
601 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
603 auto type = op.getMemRefType();
605 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
606 adaptor.getIndices(), rewriter);
608 0,
false, op.getNontemporal());
615 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
619 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
621 auto type = prefetchOp.getMemRefType();
622 auto loc = prefetchOp.getLoc();
624 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
625 adaptor.getIndices(), rewriter);
629 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
633 localityHint, isData);
642 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
645 Type operandType = op.getMemref().getType();
646 if (dyn_cast<UnrankedMemRefType>(operandType)) {
648 rewriter.
replaceOp(op, {desc.rank(rewriter, loc)});
651 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
652 Type indexType = getIndexType();
655 rankedMemRefType.getRank())});
665 LogicalResult match(memref::CastOp memRefCastOp)
const override {
666 Type srcType = memRefCastOp.getOperand().getType();
667 Type dstType = memRefCastOp.getType();
674 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
675 return success(typeConverter->
convertType(srcType) ==
679 assert(isa<UnrankedMemRefType>(srcType) ||
680 isa<UnrankedMemRefType>(dstType));
683 return !(isa<UnrankedMemRefType>(srcType) &&
684 isa<UnrankedMemRefType>(dstType))
689 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
691 auto srcType = memRefCastOp.getOperand().getType();
692 auto dstType = memRefCastOp.getType();
693 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
694 auto loc = memRefCastOp.getLoc();
697 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
698 return rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
700 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
705 auto srcMemRefType = cast<MemRefType>(srcType);
706 int64_t rank = srcMemRefType.getRank();
708 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
709 loc, adaptor.getSource(), rewriter);
712 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
718 memRefDesc.setRank(rewriter, loc, rankVal);
720 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
723 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
729 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
732 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, targetStructType, ptr);
733 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
735 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
749 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
751 auto loc = op.getLoc();
752 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
757 Value numElements = rewriter.
create<LLVM::ConstantOp>(
759 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
760 auto size = srcDesc.size(rewriter, loc, pos);
761 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
765 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
768 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
770 Type elementType = typeConverter->
convertType(srcType.getElementType());
772 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
773 Value srcOffset = srcDesc.offset(rewriter, loc);
775 loc, srcBasePtr.
getType(), elementType, srcBasePtr, srcOffset);
777 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
778 Value targetOffset = targetDesc.offset(rewriter, loc);
780 loc, targetBasePtr.
getType(), elementType, targetBasePtr, targetOffset);
781 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
789 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
791 auto loc = op.getLoc();
792 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
793 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
796 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
797 auto rank = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
799 auto *typeConverter = getTypeConverter();
801 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
806 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank, ptr});
811 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
813 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
814 Value unrankedSource =
815 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
816 : adaptor.getSource();
817 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
818 Value unrankedTarget =
819 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
820 : adaptor.getTarget();
823 auto one = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
828 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
829 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
833 auto sourcePtr =
promote(unrankedSource);
834 auto targetPtr =
promote(unrankedTarget);
838 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
840 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
841 rewriter.
create<LLVM::CallOp>(loc, copyFn,
845 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
853 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
855 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
856 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
859 auto memrefType = dyn_cast<mlir::MemRefType>(type);
864 (memrefType.getLayout().isIdentity() ||
865 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
869 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
870 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
872 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
876 struct MemorySpaceCastOpLowering
882 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
886 Type resultType = op.getDest().getType();
887 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
888 auto resultDescType =
889 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
890 Type newPtrType = resultDescType.getBody()[0];
896 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
898 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
900 resultTypeR, descVals);
904 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
907 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
908 FailureOr<unsigned> maybeSourceAddrSpace =
909 getTypeConverter()->getMemRefAddressSpace(sourceType);
910 if (failed(maybeSourceAddrSpace))
912 "non-integer source address space");
913 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
914 FailureOr<unsigned> maybeResultAddrSpace =
915 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
916 if (failed(maybeResultAddrSpace))
918 "non-integer result address space");
919 unsigned resultAddrSpace = *maybeResultAddrSpace;
922 Value rank = sourceDesc.rank(rewriter, loc);
923 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
927 rewriter, loc, typeConverter->convertType(resultTypeU));
928 result.setRank(rewriter, loc, rank);
931 result, resultAddrSpace, sizes);
932 Value resultUnderlyingSize = sizes.front();
933 Value resultUnderlyingDesc = rewriter.
create<LLVM::AllocaOp>(
934 loc, getVoidPtrType(), rewriter.
getI8Type(), resultUnderlyingSize);
935 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
938 auto sourceElemPtrType =
940 auto resultElemPtrType =
943 Value allocatedPtr = sourceDesc.allocatedPtr(
944 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
946 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
947 sourceUnderlyingDesc, sourceElemPtrType);
948 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
949 loc, resultElemPtrType, allocatedPtr);
950 alignedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
951 loc, resultElemPtrType, alignedPtr);
953 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
954 resultElemPtrType, allocatedPtr);
955 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
956 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
959 Value sourceIndexVals =
960 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
961 sourceUnderlyingDesc, sourceElemPtrType);
962 Value resultIndexVals =
963 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
964 resultUnderlyingDesc, resultElemPtrType);
966 int64_t bytesToSkip =
968 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
969 Value bytesToSkipConst = rewriter.
create<LLVM::ConstantOp>(
970 loc, getIndexType(), rewriter.
getIndexAttr(bytesToSkip));
972 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
973 rewriter.
create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
986 static void extractPointersAndOffset(
Location loc,
989 Value originalOperand,
990 Value convertedOperand,
992 Value *offset =
nullptr) {
994 if (isa<MemRefType>(operandType)) {
996 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
997 *alignedPtr = desc.alignedPtr(rewriter, loc);
998 if (offset !=
nullptr)
999 *offset = desc.offset(rewriter, loc);
1005 cast<UnrankedMemRefType>(operandType));
1006 auto elementPtrType =
1012 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1015 rewriter, loc, underlyingDescPtr, elementPtrType);
1017 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1018 if (offset !=
nullptr) {
1020 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1024 struct MemRefReinterpretCastOpLowering
1030 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1032 Type srcType = castOp.getSource().getType();
1035 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1036 adaptor, &descriptor)))
1038 rewriter.
replaceOp(castOp, {descriptor});
1043 LogicalResult convertSourceMemRefToDescriptor(
1045 memref::ReinterpretCastOp castOp,
1046 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1047 MemRefType targetMemRefType =
1048 cast<MemRefType>(castOp.getResult().getType());
1049 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1051 if (!llvmTargetDescriptorTy)
1059 Value allocatedPtr, alignedPtr;
1060 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1061 castOp.getSource(), adaptor.getSource(),
1062 &allocatedPtr, &alignedPtr);
1063 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1064 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1067 if (castOp.isDynamicOffset(0))
1068 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1070 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1073 unsigned dynSizeId = 0;
1074 unsigned dynStrideId = 0;
1075 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1076 if (castOp.isDynamicSize(i))
1077 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1079 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1081 if (castOp.isDynamicStride(i))
1082 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1084 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1091 struct MemRefReshapeOpLowering
1096 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1098 Type srcType = reshapeOp.getSource().getType();
1101 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1102 adaptor, &descriptor)))
1104 rewriter.
replaceOp(reshapeOp, {descriptor});
1111 Type srcType, memref::ReshapeOp reshapeOp,
1112 memref::ReshapeOp::Adaptor adaptor,
1113 Value *descriptor)
const {
1114 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1115 if (shapeMemRefType.hasStaticShape()) {
1116 MemRefType targetMemRefType =
1117 cast<MemRefType>(reshapeOp.getResult().getType());
1118 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1120 if (!llvmTargetDescriptorTy)
1129 Value allocatedPtr, alignedPtr;
1130 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1131 reshapeOp.getSource(), adaptor.getSource(),
1132 &allocatedPtr, &alignedPtr);
1133 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1134 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1141 reshapeOp,
"failed to get stride and offset exprs");
1143 if (!isStaticStrideOrOffset(offset))
1145 "dynamic offset is unsupported");
1147 desc.setConstantOffset(rewriter, loc, offset);
1149 assert(targetMemRefType.getLayout().isIdentity() &&
1150 "Identity layout map is a precondition of a valid reshape op");
1152 Type indexType = getIndexType();
1153 Value stride =
nullptr;
1154 int64_t targetRank = targetMemRefType.getRank();
1155 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1156 if (!ShapedType::isDynamic(strides[i])) {
1161 }
else if (!stride) {
1171 if (!targetMemRefType.isDynamicDim(i)) {
1173 targetMemRefType.getDimSize(i));
1175 Value shapeOp = reshapeOp.getShape();
1177 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1178 Type indexType = getIndexType();
1179 if (dimSize.
getType() != indexType)
1181 rewriter, loc, indexType, dimSize);
1182 assert(dimSize &&
"Invalid memref element type");
1185 desc.setSize(rewriter, loc, i, dimSize);
1186 desc.setStride(rewriter, loc, i, stride);
1189 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
1199 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1202 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1203 unsigned addressSpace =
1204 *getTypeConverter()->getMemRefAddressSpace(targetType);
1209 rewriter, loc, typeConverter->
convertType(targetType));
1210 targetDesc.setRank(rewriter, loc, resultRank);
1213 targetDesc, addressSpace, sizes);
1214 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1217 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1220 Value allocatedPtr, alignedPtr, offset;
1221 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1222 reshapeOp.getSource(), adaptor.getSource(),
1223 &allocatedPtr, &alignedPtr, &offset);
1226 auto elementPtrType =
1230 elementPtrType, allocatedPtr);
1232 underlyingDescPtr, elementPtrType,
1235 underlyingDescPtr, elementPtrType,
1241 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1243 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1244 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1246 Value resultRankMinusOne =
1247 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1250 Type indexType = getTypeConverter()->getIndexType();
1254 {indexType, indexType}, {loc, loc});
1257 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1261 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
1270 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1279 loc, llvmIndexPtrType,
1280 typeConverter->
convertType(shapeMemRefType.getElementType()),
1281 shapeOperandPtr, indexArg);
1282 Value size = rewriter.
create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1284 targetSizesBase, indexArg, size);
1288 targetStridesBase, indexArg, strideArg);
1289 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1292 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1301 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1302 remainder, std::nullopt);
1307 *descriptor = targetDesc;
1314 template <
typename ReshapeOp>
1315 class ReassociatingReshapeOpConversion
1319 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1322 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1326 "reassociation operations should have been expanded beforehand");
1336 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1339 subViewOp,
"subview operations should have been expanded beforehand");
1355 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1357 auto loc = transposeOp.getLoc();
1361 if (transposeOp.getPermutation().isIdentity())
1362 return rewriter.
replaceOp(transposeOp, {viewMemRef}), success();
1366 typeConverter->
convertType(transposeOp.getIn().getType()));
1370 targetMemRef.setAllocatedPtr(rewriter, loc,
1371 viewMemRef.allocatedPtr(rewriter, loc));
1372 targetMemRef.setAlignedPtr(rewriter, loc,
1373 viewMemRef.alignedPtr(rewriter, loc));
1376 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1382 for (
const auto &en :
1384 int targetPos = en.index();
1385 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1386 targetMemRef.setSize(rewriter, loc, targetPos,
1387 viewMemRef.size(rewriter, loc, sourcePos));
1388 targetMemRef.setStride(rewriter, loc, targetPos,
1389 viewMemRef.stride(rewriter, loc, sourcePos));
1392 rewriter.
replaceOp(transposeOp, {targetMemRef});
1409 Type indexType)
const {
1410 assert(idx < shape.size());
1411 if (!ShapedType::isDynamic(shape[idx]))
1415 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1416 return dynamicSizes[nDynamic];
1425 Value runningStride,
unsigned idx,
Type indexType)
const {
1426 assert(idx < strides.size());
1427 if (!ShapedType::isDynamic(strides[idx]))
1430 return runningStride
1431 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1433 assert(!runningStride);
1438 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1440 auto loc = viewOp.getLoc();
1442 auto viewMemRefType = viewOp.getType();
1443 auto targetElementTy =
1444 typeConverter->
convertType(viewMemRefType.getElementType());
1445 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1446 if (!targetDescTy || !targetElementTy ||
1449 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1455 if (failed(successStrides))
1456 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1457 assert(offset == 0 &&
"expected offset to be 0");
1461 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1462 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1470 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1471 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1472 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1475 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1476 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1478 typeConverter->
convertType(srcMemRefType.getElementType()), alignedPtr,
1479 adaptor.getByteShift());
1481 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1483 Type indexType = getIndexType();
1487 targetMemRef.setOffset(
1492 if (viewMemRefType.getRank() == 0)
1493 return rewriter.
replaceOp(viewOp, {targetMemRef}), success();
1496 Value stride =
nullptr, nextSize =
nullptr;
1497 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1499 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1500 adaptor.getSizes(), i, indexType);
1501 targetMemRef.setSize(rewriter, loc, i, size);
1504 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1505 targetMemRef.setStride(rewriter, loc, i, stride);
1509 rewriter.
replaceOp(viewOp, {targetMemRef});
1520 static std::optional<LLVM::AtomicBinOp>
1521 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1522 switch (atomicOp.getKind()) {
1523 case arith::AtomicRMWKind::addf:
1524 return LLVM::AtomicBinOp::fadd;
1525 case arith::AtomicRMWKind::addi:
1526 return LLVM::AtomicBinOp::add;
1527 case arith::AtomicRMWKind::assign:
1528 return LLVM::AtomicBinOp::xchg;
1529 case arith::AtomicRMWKind::maximumf:
1530 return LLVM::AtomicBinOp::fmax;
1531 case arith::AtomicRMWKind::maxs:
1533 case arith::AtomicRMWKind::maxu:
1534 return LLVM::AtomicBinOp::umax;
1535 case arith::AtomicRMWKind::minimumf:
1536 return LLVM::AtomicBinOp::fmin;
1537 case arith::AtomicRMWKind::mins:
1539 case arith::AtomicRMWKind::minu:
1540 return LLVM::AtomicBinOp::umin;
1541 case arith::AtomicRMWKind::ori:
1542 return LLVM::AtomicBinOp::_or;
1543 case arith::AtomicRMWKind::andi:
1544 return LLVM::AtomicBinOp::_and;
1546 return std::nullopt;
1548 llvm_unreachable(
"Invalid AtomicRMWKind");
1551 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1555 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1557 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1560 auto memRefType = atomicOp.getMemRefType();
1566 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1567 adaptor.getIndices(), rewriter);
1569 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1570 LLVM::AtomicOrdering::acq_rel);
1576 class ConvertExtractAlignedPointerAsIndex
1583 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1591 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1597 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1600 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1605 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1612 class ExtractStridedMetadataOpLowering
1619 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1628 Location loc = extractStridedMetadataOp.getLoc();
1629 Value source = extractStridedMetadataOp.getSource();
1631 auto sourceMemRefType = cast<MemRefType>(source.
getType());
1632 int64_t rank = sourceMemRefType.getRank();
1634 results.reserve(2 + rank * 2);
1637 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1638 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1640 rewriter, loc, *getTypeConverter(),
1641 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1642 baseBuffer, alignedBuffer);
1643 results.push_back((
Value)dstMemRef);
1646 results.push_back(sourceMemRef.offset(rewriter, loc));
1649 for (
unsigned i = 0; i < rank; ++i)
1650 results.push_back(sourceMemRef.size(rewriter, loc, i));
1652 for (
unsigned i = 0; i < rank; ++i)
1653 results.push_back(sourceMemRef.stride(rewriter, loc, i));
1655 rewriter.
replaceOp(extractStridedMetadataOp, results);
1667 AllocaScopeOpLowering,
1668 AtomicRMWOpLowering,
1669 AssumeAlignmentOpLowering,
1670 ConvertExtractAlignedPointerAsIndex,
1672 ExtractStridedMetadataOpLowering,
1673 GenericAtomicRMWOpLowering,
1674 GlobalMemrefOpLowering,
1675 GetGlobalMemrefOpLowering,
1677 MemRefCastOpLowering,
1678 MemRefCopyOpLowering,
1679 MemorySpaceCastOpLowering,
1680 MemRefReinterpretCastOpLowering,
1681 MemRefReshapeOpLowering,
1684 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1685 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1689 ViewOpLowering>(converter);
1693 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1695 patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1699 struct FinalizeMemRefToLLVMConversionPass
1700 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
1701 FinalizeMemRefToLLVMConversionPass> {
1702 using FinalizeMemRefToLLVMConversionPassBase::
1703 FinalizeMemRefToLLVMConversionPassBase;
1705 void runOnOperation()
override {
1707 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1709 dataLayoutAnalysis.getAtOrAbove(op));
1714 options.useGenericFunctions = useGenericFunctions;
1717 options.overrideIndexBitwidth(indexBitwidth);
1720 &dataLayoutAnalysis);
1724 target.addLegalOp<func::FuncOp>();
1726 signalPassFailure();
1733 void loadDependentDialects(
MLIRContext *context)
const final {
1734 context->loadDialect<LLVM::LLVMDialect>();
1739 void populateConvertToLLVMConversionPatterns(
1750 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowering.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at the end.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if the memory space cannot be converted to an integer.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as results.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and stride information extracted from the type.
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct into account.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent values.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer value.
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Lowering for AllocOp and AllocaOp.