27 #include "llvm/ADT/SmallBitVector.h"
28 #include "llvm/Support/MathExtras.h"
32 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
33 #include "mlir/Conversion/Passes.h.inc"
40 static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
41 return !ShapedType::isDynamic(strideOrOffset);
44 static FailureOr<LLVM::LLVMFuncOp>
55 static FailureOr<LLVM::LLVMFuncOp>
65 static FailureOr<LLVM::LLVMFuncOp>
// NOTE(review): body fragment of an alignment helper (signature not visible in
// this chunk). Computes `input` rounded up to the next multiple of
// `alignment` using only integer ops:
//   bumped = input + (alignment - 1);  result = bumped - (bumped % alignment)
// `one` is defined on a line missing from this extraction. The stray "8x"
// prefixes are line-number residue from the lossy extraction — do not treat
// them as code.
83 Value bump = rewriter.
create<LLVM::SubOp>(loc, alignment, one);
84 Value bumped = rewriter.
create<LLVM::AddOp>(loc, input, bump);
85 Value mod = rewriter.
create<LLVM::URemOp>(loc, bumped, alignment);
86 return rewriter.
create<LLVM::SubOp>(loc, bumped, mod);
96 layout = &analysis->getAbove(op);
98 Type elementType = memRefType.getElementType();
99 if (
auto memRefElementType = dyn_cast<MemRefType>(elementType))
101 if (
auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
109 MemRefType memRefType,
Type elementPtrType,
111 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.
getType());
112 FailureOr<unsigned> maybeMemrefAddrSpace =
114 assert(succeeded(maybeMemrefAddrSpace) &&
"unsupported address space");
115 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
116 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
117 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
127 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
129 auto loc = op.getLoc();
130 MemRefType memRefType = op.getType();
131 if (!isConvertibleAndHasIdentityMaps(memRefType))
135 FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
136 rewriter, getTypeConverter(),
138 if (failed(allocFuncOp))
148 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
149 rewriter, sizes, strides, sizeBytes,
true);
151 Value alignment = getAlignment(rewriter, loc, op);
154 sizeBytes = rewriter.
create<LLVM::AddOp>(loc, sizeBytes, alignment);
159 assert(elementPtrType &&
"could not compute element ptr type");
161 rewriter.
create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
164 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
165 elementPtrType, *getTypeConverter());
166 Value alignedPtr = allocatedPtr;
170 rewriter.
create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
172 createAligned(rewriter, loc, allocatedInt, alignment);
174 rewriter.
create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
178 auto memRefDescriptor = this->createMemRefDescriptor(
179 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
182 rewriter.
replaceOp(op, {memRefDescriptor});
187 template <
typename OpType>
190 MemRefType memRefType = op.
getType();
192 if (
auto alignmentAttr = op.getAlignment()) {
193 Type indexType = getIndexType();
196 }
else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
201 alignment =
getSizeInBytes(loc, memRefType.getElementType(), rewriter);
211 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
213 auto loc = op.getLoc();
214 MemRefType memRefType = op.getType();
215 if (!isConvertibleAndHasIdentityMaps(memRefType))
219 FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
220 rewriter, getTypeConverter(),
222 if (failed(allocFuncOp))
232 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
233 rewriter, sizes, strides, sizeBytes, !
false);
235 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
237 Value allocAlignment =
242 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
243 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
246 auto results = rewriter.
create<LLVM::CallOp>(
247 loc, allocFuncOp.value(),
ValueRange({allocAlignment, sizeBytes}));
250 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
251 elementPtrType, *getTypeConverter());
254 auto memRefDescriptor = this->createMemRefDescriptor(
255 loc, memRefType, ptr, ptr, sizes, strides, rewriter);
258 rewriter.
replaceOp(op, {memRefDescriptor});
// Minimum alignment requested from the aligned-allocation runtime call.
263 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
// NOTE(review): fragment — interior lines are missing from this extraction.
// Computes the alignment for an aligned `memref.alloc`: if the op carries an
// explicit `alignment` attribute it is used (the return on that path is on a
// missing line); otherwise the result is
// max(kMinAlignedAllocAlignment, PowerOf2Ceil(element size in bytes)).
270 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
272 if (std::optional<uint64_t> alignment = op.getAlignment())
278 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
279 getTypeConverter(), op.getType(), op, defaultLayout);
// Round the element size up to a power of two and clamp from below.
280 return std::max(kMinAlignedAllocAlignment,
281 llvm::PowerOf2Ceil(eltSizeBytes));
// NOTE(review): fragment — some interior lines (loop braces / the dynamic-dim
// skip) are missing from this extraction. Checks whether the total static
// byte size of `type` (element size times the product of its static dim
// sizes; dynamic dims appear to be skipped — confirm against the full file)
// is an exact multiple of `factor`.
286 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
Operation *op,
288 uint64_t sizeDivisor =
289 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
290 for (
unsigned i = 0, e = type.getRank(); i < e; i++) {
291 if (type.isDynamicDim(i))
293 sizeDivisor = sizeDivisor * type.getDimSize(i);
295 return sizeDivisor % factor == 0;
310 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
312 auto loc = op.getLoc();
313 MemRefType memRefType = op.getType();
314 if (!isConvertibleAndHasIdentityMaps(memRefType))
324 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
325 rewriter, sizes, strides, size, !
true);
330 typeConverter->
convertType(op.getType().getElementType());
331 FailureOr<unsigned> maybeAddressSpace =
332 getTypeConverter()->getMemRefAddressSpace(op.getType());
333 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
334 unsigned addrSpace = *maybeAddressSpace;
335 auto elementPtrType =
338 auto allocatedElementPtr = rewriter.
create<LLVM::AllocaOp>(
339 loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));
342 auto memRefDescriptor = this->createMemRefDescriptor(
343 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
347 rewriter.
replaceOp(op, {memRefDescriptor});
352 struct AllocaScopeOpLowering
357 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
360 Location loc = allocaScopeOp.getLoc();
365 auto *remainingOpsBlock =
367 Block *continueBlock;
368 if (allocaScopeOp.getNumResults() == 0) {
369 continueBlock = remainingOpsBlock;
372 remainingOpsBlock, allocaScopeOp.getResultTypes(),
374 allocaScopeOp.getLoc()));
379 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
380 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
386 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
393 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
395 returnOp, returnOp.getResults(), continueBlock);
399 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
408 struct AssumeAlignmentOpLowering
416 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
418 Value memref = adaptor.getMemref();
419 unsigned alignment = op.getAlignment();
422 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
423 Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, {},
431 Value alignmentConst =
451 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
454 FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
455 rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
456 if (failed(freeFunc))
459 if (
auto unrankedTy =
460 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
462 rewriter.
getContext(), unrankedTy.getMemorySpaceAsInt());
464 rewriter, op.getLoc(),
484 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
486 Type operandType = dimOp.getSource().getType();
487 if (isa<UnrankedMemRefType>(operandType)) {
488 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
489 operandType, dimOp, adaptor.getOperands(), rewriter);
490 if (failed(extractedSize))
492 rewriter.
replaceOp(dimOp, {*extractedSize});
495 if (isa<MemRefType>(operandType)) {
497 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
498 adaptor.getOperands(), rewriter)});
501 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
506 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
511 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
512 auto scalarMemRefType =
514 FailureOr<unsigned> maybeAddressSpace =
515 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
516 if (failed(maybeAddressSpace)) {
517 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
521 unsigned addressSpace = *maybeAddressSpace;
527 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
535 loc, indexPtrTy, elementType, underlyingRankedDesc,
544 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
547 .
create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
// NOTE(review): fragment — return statements on the first two paths sit on
// missing lines. Resolves the dimension index of a `memref.dim` op to a
// constant when possible: first via the op's own `getConstantIndex()`, then
// by pattern-matching the index operand against an `LLVM::ConstantOp` and
// sign-extending its integer attribute.
551 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
552 if (
auto idx = dimOp.getConstantIndex())
555 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
556 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
561 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
567 MemRefType memRefType = cast<MemRefType>(operandType);
568 Type indexType = getIndexType();
569 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
571 if (i >= 0 && i < memRefType.getRank()) {
572 if (memRefType.isDynamicDim(i)) {
575 return descriptor.size(rewriter, loc, i);
578 int64_t dimSize = memRefType.getDimSize(i);
582 Value index = adaptor.getIndex();
583 int64_t rank = memRefType.getRank();
585 return memrefDescriptor.size(rewriter, loc, index, rank);
592 template <
typename Derived>
596 using Base = LoadStoreOpLowering<Derived>;
626 struct GenericAtomicRMWOpLowering
627 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
631 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
633 auto loc = atomicOp.getLoc();
634 Type valueType = typeConverter->
convertType(atomicOp.getResult().getType());
646 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
647 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
648 adaptor.getIndices(), rewriter);
650 loc, typeConverter->
convertType(memRefType.getElementType()), dataPtr);
651 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
657 auto loopArgument = loopBlock->getArgument(0);
659 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
669 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
670 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
671 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
672 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
674 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
675 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
679 loopBlock, newLoaded);
684 rewriter.
replaceOp(atomicOp, {newLoaded});
692 convertGlobalMemrefTypeToLLVM(MemRefType type,
700 Type arrayTy = elementType;
702 for (int64_t dim : llvm::reverse(type.getShape()))
708 struct GlobalMemrefOpLowering
713 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
715 MemRefType type = global.getType();
716 if (!isConvertibleAndHasIdentityMaps(type))
719 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
721 LLVM::Linkage linkage =
722 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
725 if (!global.isExternal() && !global.isUninitialized()) {
726 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
727 initialValue = elementsAttr;
731 if (type.getRank() == 0)
732 initialValue = elementsAttr.getSplatValue<
Attribute>();
735 uint64_t alignment = global.getAlignment().value_or(0);
736 FailureOr<unsigned> addressSpace =
737 getTypeConverter()->getMemRefAddressSpace(type);
738 if (failed(addressSpace))
739 return global.emitOpError(
740 "memory space cannot be converted to an integer address space");
742 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
743 initialValue, alignment, *addressSpace);
744 if (!global.isExternal() && global.isUninitialized()) {
745 rewriter.
createBlock(&newGlobal.getInitializerRegion());
747 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
748 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
757 struct GetGlobalMemrefOpLowering
764 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
766 auto loc = op.getLoc();
767 MemRefType memRefType = op.getType();
768 if (!isConvertibleAndHasIdentityMaps(memRefType))
778 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
779 rewriter, sizes, strides, sizeBytes, !
false);
781 MemRefType type = cast<MemRefType>(op.getResult().getType());
785 FailureOr<unsigned> maybeAddressSpace =
786 getTypeConverter()->getMemRefAddressSpace(type);
787 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
788 unsigned memSpace = *maybeAddressSpace;
790 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
793 rewriter.
create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());
797 auto gep = rewriter.
create<LLVM::GEPOp>(
798 loc, ptrTy, arrayTy, addressOf,
804 auto intPtrType = getIntPtrType(memSpace);
805 Value deadBeefConst =
808 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
813 auto memRefDescriptor = this->createMemRefDescriptor(
814 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
817 rewriter.
replaceOp(op, {memRefDescriptor});
// NOTE(review): fragment of the `memref.load` lowering pattern — interior
// lines are missing. Computes the strided element pointer for the memref and
// indices, then (on the missing replacement line) rewrites the load as an
// LLVM load of the converted element type, propagating the op's
// `nontemporal` flag.
824 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
828 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
830 auto type = loadOp.getMemRefType();
833 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
834 adaptor.getIndices(), rewriter);
836 loadOp, typeConverter->
convertType(type.getElementType()), dataPtr, 0,
837 false, loadOp.getNontemporal());
// NOTE(review): fragment of the `memref.store` lowering pattern — interior
// lines are missing. Mirrors the load lowering: computes the strided element
// pointer, then (on a missing line) rewrites the op as an LLVM store,
// propagating the `nontemporal` flag.
844 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
848 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
850 auto type = op.getMemRefType();
852 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
853 adaptor.getIndices(), rewriter);
855 0,
false, op.getNontemporal());
862 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
866 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
868 auto type = prefetchOp.getMemRefType();
869 auto loc = prefetchOp.getLoc();
871 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
872 adaptor.getIndices(), rewriter);
876 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
880 localityHint, isData);
889 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
892 Type operandType = op.getMemref().getType();
893 if (dyn_cast<UnrankedMemRefType>(operandType)) {
895 rewriter.
replaceOp(op, {desc.rank(rewriter, loc)});
898 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
899 Type indexType = getIndexType();
902 rankedMemRefType.getRank())});
913 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
915 Type srcType = memRefCastOp.getOperand().getType();
916 Type dstType = memRefCastOp.getType();
923 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
929 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
932 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
933 auto loc = memRefCastOp.getLoc();
936 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
937 rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
941 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
946 auto srcMemRefType = cast<MemRefType>(srcType);
947 int64_t rank = srcMemRefType.getRank();
949 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
950 loc, adaptor.getSource(), rewriter);
953 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
959 memRefDesc.setRank(rewriter, loc, rankVal);
961 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
964 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
970 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
973 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, targetStructType, ptr);
974 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
976 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
992 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
994 auto loc = op.getLoc();
995 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1000 Value numElements = rewriter.
create<LLVM::ConstantOp>(
1002 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
1003 auto size = srcDesc.size(rewriter, loc, pos);
1004 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
1008 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1011 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
1013 Type elementType = typeConverter->
convertType(srcType.getElementType());
1015 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
1016 Value srcOffset = srcDesc.offset(rewriter, loc);
1018 loc, srcBasePtr.
getType(), elementType, srcBasePtr, srcOffset);
1020 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
1021 Value targetOffset = targetDesc.offset(rewriter, loc);
1023 loc, targetBasePtr.
getType(), elementType, targetBasePtr, targetOffset);
1024 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
1032 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1034 auto loc = op.getLoc();
1035 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1036 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1039 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
1040 auto rank = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1042 auto *typeConverter = getTypeConverter();
1044 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
1049 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank, ptr});
1054 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
1056 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1057 Value unrankedSource =
1058 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1059 : adaptor.getSource();
1060 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1061 Value unrankedTarget =
1062 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1063 : adaptor.getTarget();
1066 auto one = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1071 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
1072 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
1076 auto sourcePtr =
promote(unrankedSource);
1077 auto targetPtr =
promote(unrankedTarget);
1081 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1083 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1084 sourcePtr.getType());
1087 rewriter.
create<LLVM::CallOp>(loc, copyFn.value(),
1091 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
1099 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1101 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1102 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1105 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1109 return memrefType &&
1110 (memrefType.getLayout().isIdentity() ||
1111 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1115 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1116 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1118 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1122 struct MemorySpaceCastOpLowering
1128 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1132 Type resultType = op.getDest().getType();
1133 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1134 auto resultDescType =
1135 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
1136 Type newPtrType = resultDescType.getBody()[0];
1142 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
1144 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
1146 resultTypeR, descVals);
1150 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1153 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1154 FailureOr<unsigned> maybeSourceAddrSpace =
1155 getTypeConverter()->getMemRefAddressSpace(sourceType);
1156 if (failed(maybeSourceAddrSpace))
1158 "non-integer source address space");
1159 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1160 FailureOr<unsigned> maybeResultAddrSpace =
1161 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1162 if (failed(maybeResultAddrSpace))
1164 "non-integer result address space");
1165 unsigned resultAddrSpace = *maybeResultAddrSpace;
1168 Value rank = sourceDesc.rank(rewriter, loc);
1169 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
1173 rewriter, loc, typeConverter->convertType(resultTypeU));
1174 result.setRank(rewriter, loc, rank);
1177 result, resultAddrSpace, sizes);
1178 Value resultUnderlyingSize = sizes.front();
1179 Value resultUnderlyingDesc = rewriter.
create<LLVM::AllocaOp>(
1180 loc, getVoidPtrType(), rewriter.
getI8Type(), resultUnderlyingSize);
1181 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1184 auto sourceElemPtrType =
1186 auto resultElemPtrType =
1189 Value allocatedPtr = sourceDesc.allocatedPtr(
1190 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1192 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
1193 sourceUnderlyingDesc, sourceElemPtrType);
1194 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
1195 loc, resultElemPtrType, allocatedPtr);
1196 alignedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
1197 loc, resultElemPtrType, alignedPtr);
1199 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1200 resultElemPtrType, allocatedPtr);
1201 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1202 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1205 Value sourceIndexVals =
1206 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1207 sourceUnderlyingDesc, sourceElemPtrType);
1208 Value resultIndexVals =
1209 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1210 resultUnderlyingDesc, resultElemPtrType);
1212 int64_t bytesToSkip =
1214 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1215 Value bytesToSkipConst = rewriter.
create<LLVM::ConstantOp>(
1216 loc, getIndexType(), rewriter.
getIndexAttr(bytesToSkip));
1218 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
1219 rewriter.
create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
1232 static void extractPointersAndOffset(
Location loc,
1235 Value originalOperand,
1236 Value convertedOperand,
1238 Value *offset =
nullptr) {
1240 if (isa<MemRefType>(operandType)) {
1242 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1243 *alignedPtr = desc.alignedPtr(rewriter, loc);
1244 if (offset !=
nullptr)
1245 *offset = desc.offset(rewriter, loc);
1251 cast<UnrankedMemRefType>(operandType));
1252 auto elementPtrType =
1258 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1261 rewriter, loc, underlyingDescPtr, elementPtrType);
1263 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1264 if (offset !=
nullptr) {
1266 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1270 struct MemRefReinterpretCastOpLowering
1276 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1278 Type srcType = castOp.getSource().getType();
1281 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1282 adaptor, &descriptor)))
1284 rewriter.
replaceOp(castOp, {descriptor});
1289 LogicalResult convertSourceMemRefToDescriptor(
1291 memref::ReinterpretCastOp castOp,
1292 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1293 MemRefType targetMemRefType =
1294 cast<MemRefType>(castOp.getResult().getType());
1295 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1297 if (!llvmTargetDescriptorTy)
1305 Value allocatedPtr, alignedPtr;
1306 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1307 castOp.getSource(), adaptor.getSource(),
1308 &allocatedPtr, &alignedPtr);
1309 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1310 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1313 if (castOp.isDynamicOffset(0))
1314 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1316 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1319 unsigned dynSizeId = 0;
1320 unsigned dynStrideId = 0;
1321 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1322 if (castOp.isDynamicSize(i))
1323 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1325 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1327 if (castOp.isDynamicStride(i))
1328 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1330 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1337 struct MemRefReshapeOpLowering
1342 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1344 Type srcType = reshapeOp.getSource().getType();
1347 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1348 adaptor, &descriptor)))
1350 rewriter.
replaceOp(reshapeOp, {descriptor});
1357 Type srcType, memref::ReshapeOp reshapeOp,
1358 memref::ReshapeOp::Adaptor adaptor,
1359 Value *descriptor)
const {
1360 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1361 if (shapeMemRefType.hasStaticShape()) {
1362 MemRefType targetMemRefType =
1363 cast<MemRefType>(reshapeOp.getResult().getType());
1364 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1366 if (!llvmTargetDescriptorTy)
1375 Value allocatedPtr, alignedPtr;
1376 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1377 reshapeOp.getSource(), adaptor.getSource(),
1378 &allocatedPtr, &alignedPtr);
1379 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1380 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1385 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1387 reshapeOp,
"failed to get stride and offset exprs");
1389 if (!isStaticStrideOrOffset(offset))
1391 "dynamic offset is unsupported");
1393 desc.setConstantOffset(rewriter, loc, offset);
1395 assert(targetMemRefType.getLayout().isIdentity() &&
1396 "Identity layout map is a precondition of a valid reshape op");
1398 Type indexType = getIndexType();
1399 Value stride =
nullptr;
1400 int64_t targetRank = targetMemRefType.getRank();
1401 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1402 if (!ShapedType::isDynamic(strides[i])) {
1407 }
else if (!stride) {
1417 if (!targetMemRefType.isDynamicDim(i)) {
1419 targetMemRefType.getDimSize(i));
1421 Value shapeOp = reshapeOp.getShape();
1423 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1424 Type indexType = getIndexType();
1425 if (dimSize.
getType() != indexType)
1427 rewriter, loc, indexType, dimSize);
1428 assert(dimSize &&
"Invalid memref element type");
1431 desc.setSize(rewriter, loc, i, dimSize);
1432 desc.setStride(rewriter, loc, i, stride);
1435 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
1445 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1448 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1449 unsigned addressSpace =
1450 *getTypeConverter()->getMemRefAddressSpace(targetType);
1455 rewriter, loc, typeConverter->
convertType(targetType));
1456 targetDesc.setRank(rewriter, loc, resultRank);
1459 targetDesc, addressSpace, sizes);
1460 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1463 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1466 Value allocatedPtr, alignedPtr, offset;
1467 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1468 reshapeOp.getSource(), adaptor.getSource(),
1469 &allocatedPtr, &alignedPtr, &offset);
1472 auto elementPtrType =
1476 elementPtrType, allocatedPtr);
1478 underlyingDescPtr, elementPtrType,
1481 underlyingDescPtr, elementPtrType,
1487 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1489 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1490 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1492 Value resultRankMinusOne =
1493 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1496 Type indexType = getTypeConverter()->getIndexType();
1500 {indexType, indexType}, {loc, loc});
1503 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1507 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
1516 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1525 loc, llvmIndexPtrType,
1526 typeConverter->
convertType(shapeMemRefType.getElementType()),
1527 shapeOperandPtr, indexArg);
1528 Value size = rewriter.
create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1530 targetSizesBase, indexArg, size);
1534 targetStridesBase, indexArg, strideArg);
1535 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1538 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1547 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1548 remainder, std::nullopt);
1553 *descriptor = targetDesc;
1560 template <
typename ReshapeOp>
1561 class ReassociatingReshapeOpConversion
1565 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1568 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1572 "reassociation operations should have been expanded beforehand");
1582 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1585 subViewOp,
"subview operations should have been expanded beforehand");
1601 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1603 auto loc = transposeOp.getLoc();
1607 if (transposeOp.getPermutation().isIdentity())
1608 return rewriter.
replaceOp(transposeOp, {viewMemRef}), success();
1612 typeConverter->
convertType(transposeOp.getIn().getType()));
1616 targetMemRef.setAllocatedPtr(rewriter, loc,
1617 viewMemRef.allocatedPtr(rewriter, loc));
1618 targetMemRef.setAlignedPtr(rewriter, loc,
1619 viewMemRef.alignedPtr(rewriter, loc));
1622 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1628 for (
const auto &en :
1630 int targetPos = en.index();
1631 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1632 targetMemRef.setSize(rewriter, loc, targetPos,
1633 viewMemRef.size(rewriter, loc, sourcePos));
1634 targetMemRef.setStride(rewriter, loc, targetPos,
1635 viewMemRef.stride(rewriter, loc, sourcePos));
1638 rewriter.
replaceOp(transposeOp, {targetMemRef});
1655 Type indexType)
const {
1656 assert(idx < shape.size());
1657 if (!ShapedType::isDynamic(shape[idx]))
1661 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1662 return dynamicSizes[nDynamic];
1671 Value runningStride,
unsigned idx,
Type indexType)
const {
1672 assert(idx < strides.size());
1673 if (!ShapedType::isDynamic(strides[idx]))
1676 return runningStride
1677 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1679 assert(!runningStride);
1684 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1686 auto loc = viewOp.getLoc();
1688 auto viewMemRefType = viewOp.getType();
1689 auto targetElementTy =
1690 typeConverter->
convertType(viewMemRefType.getElementType());
1691 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1692 if (!targetDescTy || !targetElementTy ||
1695 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1700 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1701 if (failed(successStrides))
1702 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1703 assert(offset == 0 &&
"expected offset to be 0");
1707 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1708 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1716 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1717 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1718 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1721 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1722 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1724 typeConverter->
convertType(srcMemRefType.getElementType()), alignedPtr,
1725 adaptor.getByteShift());
1727 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1729 Type indexType = getIndexType();
1733 targetMemRef.setOffset(
1738 if (viewMemRefType.getRank() == 0)
1739 return rewriter.
replaceOp(viewOp, {targetMemRef}), success();
1742 Value stride =
nullptr, nextSize =
nullptr;
1743 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1745 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1746 adaptor.getSizes(), i, indexType);
1747 targetMemRef.setSize(rewriter, loc, i, size);
1750 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1751 targetMemRef.setStride(rewriter, loc, i, stride);
1755 rewriter.
replaceOp(viewOp, {targetMemRef});
1766 static std::optional<LLVM::AtomicBinOp>
1767 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1768 switch (atomicOp.getKind()) {
1769 case arith::AtomicRMWKind::addf:
1770 return LLVM::AtomicBinOp::fadd;
1771 case arith::AtomicRMWKind::addi:
1772 return LLVM::AtomicBinOp::add;
1773 case arith::AtomicRMWKind::assign:
1774 return LLVM::AtomicBinOp::xchg;
1775 case arith::AtomicRMWKind::maximumf:
1776 return LLVM::AtomicBinOp::fmax;
1777 case arith::AtomicRMWKind::maxs:
1779 case arith::AtomicRMWKind::maxu:
1780 return LLVM::AtomicBinOp::umax;
1781 case arith::AtomicRMWKind::minimumf:
1782 return LLVM::AtomicBinOp::fmin;
1783 case arith::AtomicRMWKind::mins:
1785 case arith::AtomicRMWKind::minu:
1786 return LLVM::AtomicBinOp::umin;
1787 case arith::AtomicRMWKind::ori:
1788 return LLVM::AtomicBinOp::_or;
1789 case arith::AtomicRMWKind::andi:
1790 return LLVM::AtomicBinOp::_and;
1792 return std::nullopt;
1794 llvm_unreachable(
"Invalid AtomicRMWKind");
1797 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1801 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1803 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1806 auto memRefType = atomicOp.getMemRefType();
1809 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1812 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1813 adaptor.getIndices(), rewriter);
1815 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1816 LLVM::AtomicOrdering::acq_rel);
1822 class ConvertExtractAlignedPointerAsIndex
1829 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1837 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1843 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1846 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1851 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1858 class ExtractStridedMetadataOpLowering
1865 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1874 Location loc = extractStridedMetadataOp.getLoc();
1875 Value source = extractStridedMetadataOp.getSource();
1877 auto sourceMemRefType = cast<MemRefType>(source.
getType());
1878 int64_t rank = sourceMemRefType.getRank();
1880 results.reserve(2 + rank * 2);
1883 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1884 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1886 rewriter, loc, *getTypeConverter(),
1887 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1888 baseBuffer, alignedBuffer);
1889 results.push_back((
Value)dstMemRef);
1892 results.push_back(sourceMemRef.offset(rewriter, loc));
1895 for (
unsigned i = 0; i < rank; ++i)
1896 results.push_back(sourceMemRef.size(rewriter, loc, i));
1898 for (
unsigned i = 0; i < rank; ++i)
1899 results.push_back(sourceMemRef.stride(rewriter, loc, i));
1901 rewriter.
replaceOp(extractStridedMetadataOp, results);
1913 AllocaScopeOpLowering,
1914 AtomicRMWOpLowering,
1915 AssumeAlignmentOpLowering,
1916 ConvertExtractAlignedPointerAsIndex,
1918 ExtractStridedMetadataOpLowering,
1919 GenericAtomicRMWOpLowering,
1920 GlobalMemrefOpLowering,
1921 GetGlobalMemrefOpLowering,
1923 MemRefCastOpLowering,
1924 MemRefCopyOpLowering,
1925 MemorySpaceCastOpLowering,
1926 MemRefReinterpretCastOpLowering,
1927 MemRefReshapeOpLowering,
1930 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1931 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1935 ViewOpLowering>(converter);
1939 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1941 patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1945 struct FinalizeMemRefToLLVMConversionPass
1946 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
1947 FinalizeMemRefToLLVMConversionPass> {
1948 using FinalizeMemRefToLLVMConversionPassBase::
1949 FinalizeMemRefToLLVMConversionPassBase;
1951 void runOnOperation()
override {
1953 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1955 dataLayoutAnalysis.getAtOrAbove(op));
1960 options.useGenericFunctions = useGenericFunctions;
1963 options.overrideIndexBitwidth(indexBitwidth);
1966 &dataLayoutAnalysis);
1970 target.addLegalOp<func::FuncOp>();
1972 signalPassFailure();
1979 void loadDependentDialects(
MLIRContext *context)
const final {
1980 context->loadDialect<LLVM::LLVMDialect>();
1985 void populateConvertToLLVMConversionPatterns(
1996 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
result_range getResults()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.