29 #include "llvm/ADT/SmallBitVector.h"
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
41 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42 return !ShapedType::isDynamic(strideOrOffset);
62 return allocateBufferManuallyAlign(
63 rewriter, loc, sizeBytes, op,
64 getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
75 Value ptr = allocateBufferAutoAlign(
76 rewriter, loc, sizeBytes, op, &defaultLayout,
77 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
81 return std::make_tuple(ptr, ptr);
93 setRequiresNumElements();
105 auto allocaOp = cast<memref::AllocaOp>(op);
107 typeConverter->
convertType(allocaOp.getType().getElementType());
109 *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110 auto elementPtrType =
113 auto allocatedElementPtr =
114 rewriter.
create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115 allocaOp.getAlignment().value_or(0));
117 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
121 struct AllocaScopeOpLowering
126 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
129 Location loc = allocaScopeOp.getLoc();
134 auto *remainingOpsBlock =
136 Block *continueBlock;
137 if (allocaScopeOp.getNumResults() == 0) {
138 continueBlock = remainingOpsBlock;
141 remainingOpsBlock, allocaScopeOp.getResultTypes(),
143 allocaScopeOp.getLoc()));
148 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
149 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
155 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
162 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
164 returnOp, returnOp.getResults(), continueBlock);
168 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
177 struct AssumeAlignmentOpLowering
185 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
187 Value memref = adaptor.getMemref();
188 unsigned alignment = op.getAlignment();
191 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192 Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, {},
202 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
206 Value ptrValue = rewriter.
create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
207 rewriter.
create<LLVM::AssumeOp>(
208 loc, rewriter.
create<LLVM::ICmpOp>(
209 loc, LLVM::ICmpPredicate::eq,
210 rewriter.
create<LLVM::AndOp>(loc, ptrValue, mask), zero));
227 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
230 LLVM::LLVMFuncOp freeFunc =
233 if (
auto unrankedTy =
234 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
236 rewriter.
getContext(), unrankedTy.getMemorySpaceAsInt());
257 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
259 Type operandType = dimOp.getSource().getType();
260 if (isa<UnrankedMemRefType>(operandType)) {
262 operandType, dimOp, adaptor.getOperands(), rewriter);
263 if (
failed(extractedSize))
265 rewriter.
replaceOp(dimOp, {*extractedSize});
268 if (isa<MemRefType>(operandType)) {
270 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
271 adaptor.getOperands(), rewriter)});
274 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
279 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
284 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
285 auto scalarMemRefType =
288 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
289 if (
failed(maybeAddressSpace)) {
290 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
294 unsigned addressSpace = *maybeAddressSpace;
300 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
308 loc, indexPtrTy, elementType, underlyingRankedDesc,
317 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
320 .
create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
324 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
325 if (
auto idx = dimOp.getConstantIndex())
328 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
329 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
334 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
340 MemRefType memRefType = cast<MemRefType>(operandType);
342 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
344 if (i >= 0 && i < memRefType.getRank()) {
345 if (memRefType.isDynamicDim(i)) {
348 return descriptor.size(rewriter, loc, i);
351 int64_t dimSize = memRefType.getDimSize(i);
355 Value index = adaptor.getIndex();
356 int64_t rank = memRefType.getRank();
358 return memrefDescriptor.size(rewriter, loc, index, rank);
365 template <
typename Derived>
369 using Base = LoadStoreOpLowering<Derived>;
372 MemRefType type = op.getMemRefType();
373 return isConvertibleAndHasIdentityMaps(type) ?
success() :
failure();
404 struct GenericAtomicRMWOpLowering
405 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
409 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
411 auto loc = atomicOp.getLoc();
412 Type valueType = typeConverter->
convertType(atomicOp.getResult().getType());
424 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
425 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
426 adaptor.getIndices(), rewriter);
428 loc, typeConverter->
convertType(memRefType.getElementType()), dataPtr);
429 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
435 auto loopArgument = loopBlock->getArgument(0);
437 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
447 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
448 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
449 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
450 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
452 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
453 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
457 loopBlock, newLoaded);
462 rewriter.
replaceOp(atomicOp, {newLoaded});
470 convertGlobalMemrefTypeToLLVM(MemRefType type,
478 Type arrayTy = elementType;
480 for (int64_t dim : llvm::reverse(type.getShape()))
486 struct GlobalMemrefOpLowering
491 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
493 MemRefType type = global.getType();
494 if (!isConvertibleAndHasIdentityMaps(type))
497 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
499 LLVM::Linkage linkage =
500 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
503 if (!global.isExternal() && !global.isUninitialized()) {
504 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
505 initialValue = elementsAttr;
509 if (type.getRank() == 0)
510 initialValue = elementsAttr.getSplatValue<
Attribute>();
513 uint64_t alignment = global.getAlignment().value_or(0);
515 getTypeConverter()->getMemRefAddressSpace(type);
517 return global.emitOpError(
518 "memory space cannot be converted to an integer address space");
520 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
521 initialValue, alignment, *addressSpace);
522 if (!global.isExternal() && global.isUninitialized()) {
524 newGlobal.getInitializerRegion().push_back(blk);
527 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
528 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
547 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
548 MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
553 getTypeConverter()->getMemRefAddressSpace(type);
554 if (
failed(maybeAddressSpace))
556 unsigned memSpace = *maybeAddressSpace;
558 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
561 rewriter.
create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
565 auto gep = rewriter.
create<LLVM::GEPOp>(
566 loc, ptrTy, arrayTy, addressOf,
572 auto intPtrType = getIntPtrType(memSpace);
573 Value deadBeefConst =
576 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
580 return std::make_tuple(deadBeefPtr, gep);
586 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
590 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
592 auto type = loadOp.getMemRefType();
595 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
596 adaptor.getIndices(), rewriter);
598 loadOp, typeConverter->
convertType(type.getElementType()), dataPtr, 0,
599 false, loadOp.getNontemporal());
606 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
610 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
612 auto type = op.getMemRefType();
614 Value dataPtr = getStridedElementPtr(op.
getLoc(), type, adaptor.getMemref(),
615 adaptor.getIndices(), rewriter);
617 0,
false, op.getNontemporal());
624 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
628 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
630 auto type = prefetchOp.getMemRefType();
631 auto loc = prefetchOp.getLoc();
633 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
634 adaptor.getIndices(), rewriter);
638 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
642 localityHint, isData);
651 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
654 Type operandType = op.getMemref().getType();
655 if (dyn_cast<UnrankedMemRefType>(operandType)) {
657 rewriter.
replaceOp(op, {desc.rank(rewriter, loc)});
660 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
664 rankedMemRefType.getRank())});
674 LogicalResult match(memref::CastOp memRefCastOp)
const override {
675 Type srcType = memRefCastOp.getOperand().getType();
676 Type dstType = memRefCastOp.getType();
683 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
688 assert(isa<UnrankedMemRefType>(srcType) ||
689 isa<UnrankedMemRefType>(dstType));
692 return !(isa<UnrankedMemRefType>(srcType) &&
693 isa<UnrankedMemRefType>(dstType))
698 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
700 auto srcType = memRefCastOp.getOperand().getType();
701 auto dstType = memRefCastOp.getType();
702 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
703 auto loc = memRefCastOp.getLoc();
706 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
707 return rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
709 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
714 auto srcMemRefType = cast<MemRefType>(srcType);
715 int64_t rank = srcMemRefType.getRank();
717 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
718 loc, adaptor.getSource(), rewriter);
721 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
727 memRefDesc.setRank(rewriter, loc, rankVal);
729 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
732 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
738 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
741 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, targetStructType, ptr);
742 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
744 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
758 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
761 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
766 Value numElements = rewriter.
create<LLVM::ConstantOp>(
768 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
769 auto size = srcDesc.size(rewriter, loc, pos);
770 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
774 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
777 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
779 Type elementType = typeConverter->
convertType(srcType.getElementType());
781 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
782 Value srcOffset = srcDesc.offset(rewriter, loc);
784 loc, srcBasePtr.
getType(), elementType, srcBasePtr, srcOffset);
786 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
787 Value targetOffset = targetDesc.offset(rewriter, loc);
789 loc, targetBasePtr.
getType(), elementType, targetBasePtr, targetOffset);
790 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
798 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
801 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
802 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
805 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
808 auto *typeConverter = getTypeConverter();
810 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
815 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank, ptr});
820 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
822 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
823 Value unrankedSource =
824 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
825 : adaptor.getSource();
826 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
827 Value unrankedTarget =
828 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
829 : adaptor.getTarget();
837 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
838 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
842 auto sourcePtr =
promote(unrankedSource);
843 auto targetPtr =
promote(unrankedTarget);
847 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
850 rewriter.
create<LLVM::CallOp>(loc, copyFn,
854 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
862 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
864 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
865 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
868 auto memrefType = dyn_cast<mlir::MemRefType>(type);
873 (memrefType.getLayout().isIdentity() ||
874 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
878 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
879 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
881 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
885 struct MemorySpaceCastOpLowering
891 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
895 Type resultType = op.getDest().getType();
896 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
897 auto resultDescType =
898 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
899 Type newPtrType = resultDescType.getBody()[0];
905 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
907 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
909 resultTypeR, descVals);
913 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
916 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
918 getTypeConverter()->getMemRefAddressSpace(sourceType);
919 if (
failed(maybeSourceAddrSpace))
921 "non-integer source address space");
922 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
924 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
925 if (
failed(maybeResultAddrSpace))
927 "non-integer result address space");
928 unsigned resultAddrSpace = *maybeResultAddrSpace;
931 Value rank = sourceDesc.rank(rewriter, loc);
932 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
936 rewriter, loc, typeConverter->convertType(resultTypeU));
937 result.setRank(rewriter, loc, rank);
940 result, resultAddrSpace, sizes);
941 Value resultUnderlyingSize = sizes.front();
942 Value resultUnderlyingDesc = rewriter.
create<LLVM::AllocaOp>(
943 loc, getVoidPtrType(), rewriter.
getI8Type(), resultUnderlyingSize);
944 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
947 auto sourceElemPtrType =
949 auto resultElemPtrType =
952 Value allocatedPtr = sourceDesc.allocatedPtr(
953 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
955 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
956 sourceUnderlyingDesc, sourceElemPtrType);
957 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
958 loc, resultElemPtrType, allocatedPtr);
959 alignedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
960 loc, resultElemPtrType, alignedPtr);
962 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
963 resultElemPtrType, allocatedPtr);
964 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
965 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
968 Value sourceIndexVals =
969 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
970 sourceUnderlyingDesc, sourceElemPtrType);
971 Value resultIndexVals =
972 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
973 resultUnderlyingDesc, resultElemPtrType);
975 int64_t bytesToSkip =
977 ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
978 Value bytesToSkipConst = rewriter.
create<LLVM::ConstantOp>(
981 loc,
getIndexType(), resultUnderlyingSize, bytesToSkipConst);
982 rewriter.
create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
995 static void extractPointersAndOffset(
Location loc,
998 Value originalOperand,
999 Value convertedOperand,
1001 Value *offset =
nullptr) {
1003 if (isa<MemRefType>(operandType)) {
1005 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1006 *alignedPtr = desc.alignedPtr(rewriter, loc);
1007 if (offset !=
nullptr)
1008 *offset = desc.offset(rewriter, loc);
1014 cast<UnrankedMemRefType>(operandType));
1015 auto elementPtrType =
1021 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1024 rewriter, loc, underlyingDescPtr, elementPtrType);
1026 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1027 if (offset !=
nullptr) {
1029 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1033 struct MemRefReinterpretCastOpLowering
1039 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1041 Type srcType = castOp.getSource().getType();
1044 if (
failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1045 adaptor, &descriptor)))
1047 rewriter.
replaceOp(castOp, {descriptor});
1054 memref::ReinterpretCastOp castOp,
1055 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1056 MemRefType targetMemRefType =
1057 cast<MemRefType>(castOp.getResult().getType());
1058 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1060 if (!llvmTargetDescriptorTy)
1068 Value allocatedPtr, alignedPtr;
1069 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1070 castOp.getSource(), adaptor.getSource(),
1071 &allocatedPtr, &alignedPtr);
1072 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1073 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1076 if (castOp.isDynamicOffset(0))
1077 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1079 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1082 unsigned dynSizeId = 0;
1083 unsigned dynStrideId = 0;
1084 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1085 if (castOp.isDynamicSize(i))
1086 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1088 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1090 if (castOp.isDynamicStride(i))
1091 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1093 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1100 struct MemRefReshapeOpLowering
1105 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1107 Type srcType = reshapeOp.getSource().getType();
1110 if (
failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1111 adaptor, &descriptor)))
1113 rewriter.
replaceOp(reshapeOp, {descriptor});
1120 Type srcType, memref::ReshapeOp reshapeOp,
1121 memref::ReshapeOp::Adaptor adaptor,
1122 Value *descriptor)
const {
1123 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1124 if (shapeMemRefType.hasStaticShape()) {
1125 MemRefType targetMemRefType =
1126 cast<MemRefType>(reshapeOp.getResult().getType());
1127 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1129 if (!llvmTargetDescriptorTy)
1138 Value allocatedPtr, alignedPtr;
1139 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1140 reshapeOp.getSource(), adaptor.getSource(),
1141 &allocatedPtr, &alignedPtr);
1142 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1143 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1150 reshapeOp,
"failed to get stride and offset exprs");
1152 if (!isStaticStrideOrOffset(offset))
1154 "dynamic offset is unsupported");
1156 desc.setConstantOffset(rewriter, loc, offset);
1158 assert(targetMemRefType.getLayout().isIdentity() &&
1159 "Identity layout map is a precondition of a valid reshape op");
1162 Value stride =
nullptr;
1163 int64_t targetRank = targetMemRefType.getRank();
1164 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1165 if (!ShapedType::isDynamic(strides[i])) {
1170 }
else if (!stride) {
1180 if (!targetMemRefType.isDynamicDim(i)) {
1182 targetMemRefType.getDimSize(i));
1184 Value shapeOp = reshapeOp.getShape();
1186 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1188 if (dimSize.
getType() != indexType)
1190 rewriter, loc, indexType, dimSize);
1191 assert(dimSize &&
"Invalid memref element type");
1194 desc.setSize(rewriter, loc, i, dimSize);
1195 desc.setStride(rewriter, loc, i, stride);
1198 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
1208 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1211 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1212 unsigned addressSpace =
1213 *getTypeConverter()->getMemRefAddressSpace(targetType);
1218 rewriter, loc, typeConverter->
convertType(targetType));
1219 targetDesc.setRank(rewriter, loc, resultRank);
1222 targetDesc, addressSpace, sizes);
1223 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1226 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1229 Value allocatedPtr, alignedPtr, offset;
1230 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1231 reshapeOp.getSource(), adaptor.getSource(),
1232 &allocatedPtr, &alignedPtr, &offset);
1235 auto elementPtrType =
1239 elementPtrType, allocatedPtr);
1241 underlyingDescPtr, elementPtrType,
1244 underlyingDescPtr, elementPtrType,
1250 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1252 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1253 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1255 Value resultRankMinusOne =
1256 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1259 Type indexType = getTypeConverter()->getIndexType();
1263 {indexType, indexType}, {loc, loc});
1266 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1270 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
1279 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1288 loc, llvmIndexPtrType,
1289 typeConverter->
convertType(shapeMemRefType.getElementType()),
1290 shapeOperandPtr, indexArg);
1291 Value size = rewriter.
create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1293 targetSizesBase, indexArg, size);
1297 targetStridesBase, indexArg, strideArg);
1298 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1301 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1310 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1311 remainder, std::nullopt);
1316 *descriptor = targetDesc;
1323 template <
typename ReshapeOp>
1324 class ReassociatingReshapeOpConversion
1328 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1331 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1335 "reassociation operations should have been expanded beforehand");
1345 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1348 subViewOp,
"subview operations should have been expanded beforehand");
1364 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1366 auto loc = transposeOp.getLoc();
1370 if (transposeOp.getPermutation().isIdentity())
1375 typeConverter->
convertType(transposeOp.getIn().getType()));
1379 targetMemRef.setAllocatedPtr(rewriter, loc,
1380 viewMemRef.allocatedPtr(rewriter, loc));
1381 targetMemRef.setAlignedPtr(rewriter, loc,
1382 viewMemRef.alignedPtr(rewriter, loc));
1385 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1391 for (
const auto &en :
1393 int targetPos = en.index();
1394 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1395 targetMemRef.setSize(rewriter, loc, targetPos,
1396 viewMemRef.size(rewriter, loc, sourcePos));
1397 targetMemRef.setStride(rewriter, loc, targetPos,
1398 viewMemRef.stride(rewriter, loc, sourcePos));
1401 rewriter.
replaceOp(transposeOp, {targetMemRef});
1418 Type indexType)
const {
1419 assert(idx < shape.size());
1420 if (!ShapedType::isDynamic(shape[idx]))
1424 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1425 return dynamicSizes[nDynamic];
1434 Value runningStride,
unsigned idx,
Type indexType)
const {
1435 assert(idx < strides.size());
1436 if (!ShapedType::isDynamic(strides[idx]))
1439 return runningStride
1440 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1442 assert(!runningStride);
1447 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1449 auto loc = viewOp.getLoc();
1451 auto viewMemRefType = viewOp.getType();
1452 auto targetElementTy =
1453 typeConverter->
convertType(viewMemRefType.getElementType());
1454 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1455 if (!targetDescTy || !targetElementTy ||
1458 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1464 if (
failed(successStrides))
1465 return viewOp.emitWarning(
"cannot cast to non-strided shape"),
failure();
1466 assert(offset == 0 &&
"expected offset to be 0");
1470 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1471 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1479 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1480 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1481 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1484 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1485 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1487 typeConverter->
convertType(srcMemRefType.getElementType()), alignedPtr,
1488 adaptor.getByteShift());
1490 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1496 targetMemRef.setOffset(
1501 if (viewMemRefType.getRank() == 0)
1505 Value stride =
nullptr, nextSize =
nullptr;
1506 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1508 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1509 adaptor.getSizes(), i, indexType);
1510 targetMemRef.setSize(rewriter, loc, i, size);
1513 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1514 targetMemRef.setStride(rewriter, loc, i, stride);
1518 rewriter.
replaceOp(viewOp, {targetMemRef});
1529 static std::optional<LLVM::AtomicBinOp>
1530 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1531 switch (atomicOp.getKind()) {
1532 case arith::AtomicRMWKind::addf:
1533 return LLVM::AtomicBinOp::fadd;
1534 case arith::AtomicRMWKind::addi:
1535 return LLVM::AtomicBinOp::add;
1536 case arith::AtomicRMWKind::assign:
1537 return LLVM::AtomicBinOp::xchg;
1538 case arith::AtomicRMWKind::maximumf:
1539 return LLVM::AtomicBinOp::fmax;
1540 case arith::AtomicRMWKind::maxs:
1542 case arith::AtomicRMWKind::maxu:
1543 return LLVM::AtomicBinOp::umax;
1544 case arith::AtomicRMWKind::minimumf:
1545 return LLVM::AtomicBinOp::fmin;
1546 case arith::AtomicRMWKind::mins:
1548 case arith::AtomicRMWKind::minu:
1549 return LLVM::AtomicBinOp::umin;
1550 case arith::AtomicRMWKind::ori:
1551 return LLVM::AtomicBinOp::_or;
1552 case arith::AtomicRMWKind::andi:
1553 return LLVM::AtomicBinOp::_and;
1555 return std::nullopt;
1557 llvm_unreachable(
"Invalid AtomicRMWKind");
1560 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1564 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1566 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1569 auto memRefType = atomicOp.getMemRefType();
1575 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1576 adaptor.getIndices(), rewriter);
1578 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1579 LLVM::AtomicOrdering::acq_rel);
1585 class ConvertExtractAlignedPointerAsIndex
1592 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1597 extractOp, getTypeConverter()->getIndexType(),
1598 desc.alignedPtr(rewriter, extractOp->getLoc()));
1605 class ExtractStridedMetadataOpLowering
1612 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1621 Location loc = extractStridedMetadataOp.getLoc();
1622 Value source = extractStridedMetadataOp.getSource();
1624 auto sourceMemRefType = cast<MemRefType>(source.
getType());
1625 int64_t rank = sourceMemRefType.getRank();
1627 results.reserve(2 + rank * 2);
1630 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1631 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1633 rewriter, loc, *getTypeConverter(),
1634 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1635 baseBuffer, alignedBuffer);
1636 results.push_back((
Value)dstMemRef);
1639 results.push_back(sourceMemRef.offset(rewriter, loc));
1642 for (
unsigned i = 0; i < rank; ++i)
1643 results.push_back(sourceMemRef.size(rewriter, loc, i));
1645 for (
unsigned i = 0; i < rank; ++i)
1646 results.push_back(sourceMemRef.stride(rewriter, loc, i));
1648 rewriter.
replaceOp(extractStridedMetadataOp, results);
1660 AllocaScopeOpLowering,
1661 AtomicRMWOpLowering,
1662 AssumeAlignmentOpLowering,
1663 ConvertExtractAlignedPointerAsIndex,
1665 ExtractStridedMetadataOpLowering,
1666 GenericAtomicRMWOpLowering,
1667 GlobalMemrefOpLowering,
1668 GetGlobalMemrefOpLowering,
1670 MemRefCastOpLowering,
1671 MemRefCopyOpLowering,
1672 MemorySpaceCastOpLowering,
1673 MemRefReinterpretCastOpLowering,
1674 MemRefReshapeOpLowering,
1677 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1678 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1682 ViewOpLowering>(converter);
1686 patterns.
add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1688 patterns.
add<AllocOpLowering, DeallocOpLowering>(converter);
1692 struct FinalizeMemRefToLLVMConversionPass
1693 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
1694 FinalizeMemRefToLLVMConversionPass> {
1695 using FinalizeMemRefToLLVMConversionPassBase::
1696 FinalizeMemRefToLLVMConversionPassBase;
1698 void runOnOperation()
override {
1700 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1702 dataLayoutAnalysis.getAtOrAbove(op));
1707 options.useGenericFunctions = useGenericFunctions;
1710 options.overrideIndexBitwidth(indexBitwidth);
1713 &dataLayoutAnalysis);
1717 target.addLegalOp<func::FuncOp>();
1719 signalPassFailure();
1726 void loadDependentDialects(
MLIRContext *context)
const final {
1727 context->loadDialect<LLVM::LLVMDialect>();
1732 void populateConvertToLLVMConversionPatterns(
1743 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
result_range getResults()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, Type unrankedDescriptorType)
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Lowering for AllocOp and AllocaOp.
This class represents an efficient way to signal success or failure.