27 #include "llvm/ADT/SmallBitVector.h"
28 #include "llvm/Support/MathExtras.h"
32 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
33 #include "mlir/Conversion/Passes.h.inc"
39 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
43 static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
44 return !ShapedType::isDynamic(strideOrOffset);
// NOTE(review): fragmentary view — only the result-type lines of three
// allocation-function getters (orig. lines 47/58/68) survive here; their
// names and bodies are not visible in this extract.
47 static FailureOr<LLVM::LLVMFuncOp>
58 static FailureOr<LLVM::LLVMFuncOp>
68 static FailureOr<LLVM::LLVMFuncOp>
// Body of an alignment helper: rounds `input` up to the nearest multiple of
// `alignment` via (input + alignment - 1) - ((input + alignment - 1) %
// alignment). Assumes `alignment` is non-zero — division by zero otherwise;
// TODO confirm callers guarantee this.
86 Value bump = rewriter.
create<LLVM::SubOp>(loc, alignment, one);
87 Value bumped = rewriter.
create<LLVM::AddOp>(loc, input, bump);
88 Value mod = rewriter.
create<LLVM::URemOp>(loc, bumped, alignment);
89 return rewriter.
create<LLVM::SubOp>(loc, bumped, mod);
// NOTE(review): fragments of two helpers. First (orig. ~99-104): element-type
// inspection that distinguishes ranked and unranked memref element types.
99 layout = &analysis->getAbove(op);
101 Type elementType = memRefType.getElementType();
102 if (
auto memRefElementType = dyn_cast<MemRefType>(elementType))
104 if (
auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
// Second (orig. ~112-120): casts the pointer returned by an allocation
// function into the memref's address space when the two differ.
112 MemRefType memRefType,
Type elementPtrType,
114 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.
getType());
115 FailureOr<unsigned> maybeMemrefAddrSpace =
// Address space must have been convertible to an integer; hard assert here
// because conversion legality was presumably checked earlier — TODO confirm.
117 assert(succeeded(maybeMemrefAddrSpace) &&
"unsupported address space");
118 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
// Insert an addrspacecast only when the spaces actually disagree; the
// create<LLVM::AddrSpaceCastOp>( call is truncated in this extract.
119 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
120 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
// Lowers memref.alloc to a call to a (not-aligned) allocation function and
// builds the resulting memref descriptor. NOTE(review): fragmentary — many
// original lines (e.g. 135-137, 142-150) are missing from this extract.
130 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
132 auto loc = op.getLoc();
133 MemRefType memRefType = op.getType();
// Only identity-layout, convertible memrefs are handled by this pattern.
134 if (!isConvertibleAndHasIdentityMaps(memRefType))
138 FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
139 rewriter, getTypeConverter(),
141 if (failed(allocFuncOp))
151 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
152 rewriter, sizes, strides, sizeBytes,
true);
// Over-allocate by `alignment` so the aligned pointer can be computed from
// the raw malloc result below.
154 Value alignment = getAlignment(rewriter, loc, op);
157 sizeBytes = rewriter.
create<LLVM::AddOp>(loc, sizeBytes, alignment);
162 assert(elementPtrType &&
"could not compute element ptr type");
164 rewriter.
create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
167 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
168 elementPtrType, *getTypeConverter());
169 Value alignedPtr = allocatedPtr;
// Compute the aligned pointer via ptrtoint -> round up -> inttoptr.
173 rewriter.
create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
175 createAligned(rewriter, loc, allocatedInt, alignment);
177 rewriter.
create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
181 auto memRefDescriptor = this->createMemRefDescriptor(
182 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
185 rewriter.
replaceOp(op, {memRefDescriptor});
// Computes the alignment Value for an alloc-like op: prefer the op's explicit
// `alignment` attribute; otherwise, for non-scalar element types, fall back to
// the element size in bytes. NOTE(review): fragmentary extract.
190 template <
typename OpType>
193 MemRefType memRefType = op.
getType();
195 if (
auto alignmentAttr = op.getAlignment()) {
196 Type indexType = getIndexType();
199 }
else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
// Element size presumably serves as a conservative alignment here — TODO
// confirm against the full upstream source.
204 alignment =
getSizeInBytes(loc, memRefType.getElementType(), rewriter);
// Lowers memref.alloc to an aligned allocation function (aligned_alloc-style:
// takes {alignment, size}). NOTE(review): fragmentary extract.
214 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
216 auto loc = op.getLoc();
217 MemRefType memRefType = op.getType();
218 if (!isConvertibleAndHasIdentityMaps(memRefType))
222 FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
223 rewriter, getTypeConverter(),
225 if (failed(allocFuncOp))
235 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
236 rewriter, sizes, strides, sizeBytes, !
false);
238 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
240 Value allocAlignment =
// aligned_alloc requires size to be a multiple of alignment; round the byte
// size up when the static shape does not already guarantee it.
245 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
246 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
249 auto results = rewriter.
create<LLVM::CallOp>(
250 loc, allocFuncOp.value(),
ValueRange({allocAlignment, sizeBytes}));
253 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
254 elementPtrType, *getTypeConverter());
// Aligned allocation: allocated and aligned pointers coincide.
257 auto memRefDescriptor = this->createMemRefDescriptor(
258 loc, memRefType, ptr, ptr, sizes, strides, rewriter);
261 rewriter.
replaceOp(op, {memRefDescriptor});
// Minimum alignment passed to aligned allocation functions.
266 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
// Chooses the allocation alignment: the op's explicit attribute if present,
// else max(16, next power of two of the element size in bytes).
273 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
275 if (std::optional<uint64_t> alignment = op.getAlignment())
281 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
282 getTypeConverter(), op.getType(), op, defaultLayout);
283 return std::max(kMinAlignedAllocAlignment,
284 llvm::PowerOf2Ceil(eltSizeBytes));
// Returns true when the memref's statically-known byte size is a multiple of
// `factor`. Dynamic dimensions are skipped (orig. line 295, not visible here,
// presumably `continue`s) — TODO confirm.
289 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
Operation *op,
291 uint64_t sizeDivisor =
292 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
293 for (
unsigned i = 0, e = type.getRank(); i < e; i++) {
294 if (type.isDynamicDim(i))
296 sizeDivisor = sizeDivisor * type.getDimSize(i);
298 return sizeDivisor % factor == 0;
// Lowers memref.alloca to llvm.alloca in the memref's address space.
// NOTE(review): fragmentary extract.
313 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
315 auto loc = op.getLoc();
316 MemRefType memRefType = op.getType();
317 if (!isConvertibleAndHasIdentityMaps(memRefType))
327 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
328 rewriter, sizes, strides, size, !
true);
333 typeConverter->
convertType(op.getType().getElementType());
334 FailureOr<unsigned> maybeAddressSpace =
335 getTypeConverter()->getMemRefAddressSpace(op.getType());
336 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
337 unsigned addrSpace = *maybeAddressSpace;
338 auto elementPtrType =
// Stack allocation of `size` elements; alignment 0 lets LLVM pick a default
// when the op carries no explicit alignment attribute.
341 auto allocatedElementPtr = rewriter.
create<LLVM::AllocaOp>(
342 loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));
// Alloca results are naturally aligned: allocated == aligned pointer.
345 auto memRefDescriptor = this->createMemRefDescriptor(
346 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
350 rewriter.
replaceOp(op, {memRefDescriptor});
// Lowers memref.alloca_scope to llvm.stacksave / llvm.stackrestore bracketing
// the inlined body region. NOTE(review): fragmentary extract.
355 struct AllocaScopeOpLowering
360 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
363 Location loc = allocaScopeOp.getLoc();
368 auto *remainingOpsBlock =
370 Block *continueBlock;
// Without results the continuation is simply the split-off remainder block;
// with results a dedicated block with block arguments is created (orig.
// lines 373-377, partially missing here).
371 if (allocaScopeOp.getNumResults() == 0) {
372 continueBlock = remainingOpsBlock;
375 remainingOpsBlock, allocaScopeOp.getResultTypes(),
377 allocaScopeOp.getLoc()));
382 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
383 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
// Save the stack pointer before entering the scope body.
389 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
396 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
398 returnOp, returnOp.getResults(), continueBlock);
// Restore the stack pointer on scope exit, releasing scope-local allocas.
402 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
// Lowers memref.assume_alignment; presumably emits an alignment assumption on
// the memref's aligned pointer — body largely missing from this extract.
411 struct AssumeAlignmentOpLowering
419 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
421 Value memref = adaptor.getMemref();
422 unsigned alignment = op.getAlignment();
423 auto loc = op.getLoc();
425 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
434 Value alignmentConst =
// Lowers memref.dealloc to a call to the free function. For unranked memrefs
// the allocated pointer is first extracted from the underlying descriptor.
// NOTE(review): fragmentary extract.
453 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
456 FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
457 rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
458 if (failed(freeFunc))
461 if (
auto unrankedTy =
462 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
464 rewriter.
getContext(), unrankedTy.getMemorySpaceAsInt());
466 rewriter, op.getLoc(),
// Lowers memref.dim, dispatching on ranked vs. unranked source memrefs.
// NOTE(review): fragmentary extract.
486 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
488 Type operandType = dimOp.getSource().getType();
489 if (isa<UnrankedMemRefType>(operandType)) {
490 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
491 operandType, dimOp, adaptor.getOperands(), rewriter);
492 if (failed(extractedSize))
494 rewriter.
replaceOp(dimOp, {*extractedSize});
497 if (isa<MemRefType>(operandType)) {
499 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
500 adaptor.getOperands(), rewriter)});
// Verifier guarantees the operand is one of the two memref types.
503 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
// Loads the requested size from the unranked memref's underlying ranked
// descriptor via GEP + load on the sizes array.
508 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
513 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
514 auto scalarMemRefType =
516 FailureOr<unsigned> maybeAddressSpace =
517 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
518 if (failed(maybeAddressSpace)) {
519 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
523 unsigned addressSpace = *maybeAddressSpace;
537 loc, indexPtrTy, elementType, underlyingRankedDesc,
546 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
549 .
create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
// Returns the constant dimension index if it can be derived either from the
// op itself or from a defining LLVM constant.
553 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
554 if (
auto idx = dimOp.getConstantIndex())
557 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
558 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
// For ranked memrefs: constant in-range dynamic dims read the descriptor's
// size field; static dims are materialized (orig. lines ~581-583 missing);
// non-constant indices use the runtime-indexed descriptor size.
563 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
569 MemRefType memRefType = cast<MemRefType>(operandType);
570 Type indexType = getIndexType();
571 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
573 if (i >= 0 && i < memRefType.getRank()) {
574 if (memRefType.isDynamicDim(i)) {
577 return descriptor.
size(rewriter, loc, i);
580 int64_t dimSize = memRefType.getDimSize(i);
584 Value index = adaptor.getIndex();
585 int64_t rank = memRefType.getRank();
587 return memrefDescriptor.
size(rewriter, loc, index, rank);
// Common base for load/store-style lowerings.
594 template <
typename Derived>
598 using Base = LoadStoreOpLowering<Derived>;
// Lowers memref.generic_atomic_rmw to a compare-exchange loop: load the
// current value, run the body, cmpxchg, and retry until the exchange succeeds.
628 struct GenericAtomicRMWOpLowering
629 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
633 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
635 auto loc = atomicOp.getLoc();
636 Type valueType = typeConverter->
convertType(atomicOp.getResult().getType());
648 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
650 rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
// Initial load of the current value before entering the CAS loop.
652 loc, typeConverter->
convertType(memRefType.getElementType()), dataPtr);
653 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
659 auto loopArgument = loopBlock->getArgument(0);
// Map the op's body argument (current value) onto the loop block argument.
661 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
// acq_rel on success / monotonic on failure ordering pair for the cmpxchg.
671 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
672 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
673 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
674 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
// cmpxchg yields {loaded value, success flag}.
676 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
677 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
681 loopBlock, newLoaded);
686 rewriter.
replaceOp(atomicOp, {newLoaded});
// Converts a static-shaped memref type into a nested LLVM array type by
// wrapping the element type once per dimension, innermost-first.
694 convertGlobalMemrefTypeToLLVM(MemRefType type,
702 Type arrayTy = elementType;
704 for (int64_t dim : llvm::reverse(type.getShape()))
// Lowers memref.global to llvm.mlir.global, mapping visibility to linkage and
// materializing the initializer. NOTE(review): fragmentary extract.
710 struct GlobalMemrefOpLowering
715 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
717 MemRefType type = global.getType();
718 if (!isConvertibleAndHasIdentityMaps(type))
721 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
723 LLVM::Linkage linkage =
724 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
727 if (!global.isExternal() && !global.isUninitialized()) {
728 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
729 initialValue = elementsAttr;
// Rank-0 memrefs lower to a scalar global: use the splat value directly.
733 if (type.getRank() == 0)
734 initialValue = elementsAttr.getSplatValue<
Attribute>();
737 uint64_t alignment = global.getAlignment().value_or(0);
738 FailureOr<unsigned> addressSpace =
739 getTypeConverter()->getMemRefAddressSpace(type);
740 if (failed(addressSpace))
741 return global.emitOpError(
742 "memory space cannot be converted to an integer address space");
744 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
745 initialValue, alignment, *addressSpace);
// Uninitialized non-external globals get an initializer region yielding
// undef so LLVM treats the contents as unspecified.
746 if (!global.isExternal() && global.isUninitialized()) {
747 rewriter.
createBlock(&newGlobal.getInitializerRegion());
749 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
750 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
// Lowers memref.get_global: takes the address of the LLVM global and builds a
// memref descriptor around it. NOTE(review): fragmentary extract.
759 struct GetGlobalMemrefOpLowering
766 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
768 auto loc = op.getLoc();
769 MemRefType memRefType = op.getType();
770 if (!isConvertibleAndHasIdentityMaps(memRefType))
780 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
781 rewriter, sizes, strides, sizeBytes, !
false);
783 MemRefType type = cast<MemRefType>(op.getResult().getType());
787 FailureOr<unsigned> maybeAddressSpace =
788 getTypeConverter()->getMemRefAddressSpace(type);
789 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
790 unsigned memSpace = *maybeAddressSpace;
792 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
795 rewriter.
create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());
// GEP to the first element of the (possibly nested) array global.
799 auto gep = rewriter.
create<LLVM::GEPOp>(
800 loc, ptrTy, arrayTy, addressOf,
// Globals must never be freed: poison the "allocated" pointer with a
// recognizable 0xdeadbeef constant so misuse faults loudly.
806 auto intPtrType = getIntPtrType(memSpace);
807 Value deadBeefConst =
810 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
815 auto memRefDescriptor = this->createMemRefDescriptor(
816 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
819 rewriter.
replaceOp(op, {memRefDescriptor});
// memref.load -> llvm.load at the strided element pointer.
826 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
830 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
832 auto type = loadOp.getMemRefType();
// Alignment 0 / non-volatile; nontemporal flag forwarded from the op.
841 loadOp, typeConverter->
convertType(type.getElementType()), dataPtr, 0,
842 false, loadOp.getNontemporal());
// memref.store -> llvm.store at the strided element pointer.
849 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
853 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
855 auto type = op.getMemRefType();
864 0,
false, op.getNontemporal());
// memref.prefetch -> llvm.prefetch with the op's rw/locality/data flags.
871 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
875 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
877 auto type = prefetchOp.getMemRefType();
878 auto loc = prefetchOp.getLoc();
881 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
885 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
889 localityHint, isData);
// Lowers memref.rank: dynamic rank read from the unranked descriptor,
// constant rank materialized for ranked memrefs. NOTE(review): fragmentary.
898 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
901 Type operandType = op.getMemref().getType();
902 if (isa<UnrankedMemRefType>(operandType)) {
907 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
908 Type indexType = getIndexType();
911 rankedMemRefType.getRank())});
// Lowers memref.cast across the four ranked/unranked combinations.
// NOTE(review): fragmentary extract.
922 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
924 Type srcType = memRefCastOp.getOperand().getType();
925 Type dstType = memRefCastOp.getType();
932 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
938 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
941 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
942 auto loc = memRefCastOp.getLoc();
// ranked -> ranked: descriptors share the same layout, so the cast is a
// no-op at the LLVM level.
945 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
946 rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
// ranked -> unranked: spill the ranked descriptor to the stack and wrap it
// in an unranked descriptor of {rank, opaque pointer}.
950 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
955 auto srcMemRefType = cast<MemRefType>(srcType);
956 int64_t rank = srcMemRefType.getRank();
958 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
959 loc, adaptor.getSource(), rewriter);
962 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
968 memRefDesc.setRank(rewriter, loc, rankVal);
970 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
// unranked -> ranked: load the ranked descriptor back out of the
// underlying storage pointer.
973 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
982 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, targetStructType, ptr);
983 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
985 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
// Lowers a contiguous memref.copy to a single llvm.memcpy: total bytes =
// product of sizes * element size; source/target pointers are offset by the
// descriptor offsets. NOTE(review): fragmentary extract.
1001 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1003 auto loc = op.getLoc();
1004 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1009 Value numElements = rewriter.
create<LLVM::ConstantOp>(
1011 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
1012 auto size = srcDesc.
size(rewriter, loc, pos);
1013 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
1017 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1020 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
1022 Type elementType = typeConverter->
convertType(srcType.getElementType());
// GEP past the descriptor offset to the first logical element on each side.
1027 loc, srcBasePtr.
getType(), elementType, srcBasePtr, srcOffset);
1029 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
1030 Value targetOffset = targetDesc.offset(rewriter, loc);
1032 loc, targetBasePtr.
getType(), elementType, targetBasePtr, targetOffset);
1033 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
// General (non-contiguous) memref.copy path: promote both sides to unranked
// descriptors spilled on the stack and call the runtime memref-copy function.
// Stack usage is bracketed with stacksave/stackrestore.
1041 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1043 auto loc = op.getLoc();
1044 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1045 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
// Wraps a ranked memref value into an unranked descriptor {rank, ptr}.
1048 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
1049 auto rank = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1051 auto *typeConverter = getTypeConverter();
1053 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
1058 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank, ptr});
// Save the stack: the promoted descriptors below are allocas.
1063 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
1065 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1066 Value unrankedSource =
1067 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1068 : adaptor.getSource();
1069 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1070 Value unrankedTarget =
1071 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1072 : adaptor.getTarget();
// Spill each unranked descriptor to a one-element alloca so it can be
// passed by pointer to the copy function.
1075 auto one = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1080 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
1081 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
1085 auto sourcePtr =
promote(unrankedSource);
1086 auto targetPtr =
promote(unrankedTarget);
1090 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1092 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1096 rewriter.
create<LLVM::CallOp>(loc, copyFn.value(),
1100 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
// Dispatch for memref.copy: use the memcpy intrinsic when both sides are
// contiguous (identity layout, or static non-empty shape with a contiguous
// stride check — the latter condition is partially missing from this extract);
// otherwise fall back to the runtime copy function.
1108 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1110 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1111 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1114 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1118 return memrefType &&
1119 (memrefType.getLayout().isIdentity() ||
1120 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1124 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1125 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1127 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
// Lowers memref.memory_space_cast. Ranked case: addrspacecast both pointers
// of the descriptor. Unranked case: rebuild the underlying descriptor in the
// new address space and memcpy the size/stride tail across.
// NOTE(review): fragmentary extract; statement order preserved verbatim.
1131 struct MemorySpaceCastOpLowering
1137 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1141 Type resultType = op.getDest().getType();
1142 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1143 auto resultDescType =
1144 cast<LLVM::LLVMStructType>(typeConverter->
convertType(resultTypeR));
1145 Type newPtrType = resultDescType.getBody()[0];
// Cast allocated (descVals[0]) and aligned (descVals[1]) pointers.
1151 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
1153 rewriter.
create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
1155 resultTypeR, descVals);
1159 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1162 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1163 FailureOr<unsigned> maybeSourceAddrSpace =
1164 getTypeConverter()->getMemRefAddressSpace(sourceType);
1165 if (failed(maybeSourceAddrSpace))
1167 "non-integer source address space");
1168 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1169 FailureOr<unsigned> maybeResultAddrSpace =
1170 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1171 if (failed(maybeResultAddrSpace))
1173 "non-integer result address space");
1174 unsigned resultAddrSpace = *maybeResultAddrSpace;
1177 Value rank = sourceDesc.
rank(rewriter, loc);
// Fresh unranked descriptor for the result, same dynamic rank.
1182 rewriter, loc, typeConverter->
convertType(resultTypeU));
1183 result.setRank(rewriter, loc, rank);
1186 result, resultAddrSpace, sizes);
1187 Value resultUnderlyingSize = sizes.front();
// Stack-allocate the new underlying ranked descriptor storage.
1188 Value resultUnderlyingDesc = rewriter.
create<LLVM::AllocaOp>(
1189 loc, getVoidPtrType(), rewriter.
getI8Type(), resultUnderlyingSize);
1190 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1193 auto sourceElemPtrType =
1195 auto resultElemPtrType =
// Extract, addrspacecast, and store the two data pointers.
1199 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1201 sourceDesc.
alignedPtr(rewriter, loc, *getTypeConverter(),
1202 sourceUnderlyingDesc, sourceElemPtrType);
1203 allocatedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
1204 loc, resultElemPtrType, allocatedPtr);
1205 alignedPtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
1206 loc, resultElemPtrType, alignedPtr);
1208 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1209 resultElemPtrType, allocatedPtr);
1210 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1211 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
// Copy offset + sizes + strides tail; skip the two leading pointers,
// whose byte size depends on the result pointer bitwidth.
1214 Value sourceIndexVals =
1215 sourceDesc.
offsetBasePtr(rewriter, loc, *getTypeConverter(),
1216 sourceUnderlyingDesc, sourceElemPtrType);
1217 Value resultIndexVals =
1218 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1219 resultUnderlyingDesc, resultElemPtrType);
1221 int64_t bytesToSkip =
1223 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1224 Value bytesToSkipConst = rewriter.
create<LLVM::ConstantOp>(
1225 loc, getIndexType(), rewriter.
getIndexAttr(bytesToSkip));
1227 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
1228 rewriter.
create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
// Extracts allocated pointer, aligned pointer, and (optionally) offset from a
// converted memref value. Ranked memrefs read the descriptor directly;
// unranked memrefs go through the underlying descriptor pointer. `offset` may
// be null when the caller does not need it.
1241 static void extractPointersAndOffset(
Location loc,
1244 Value originalOperand,
1245 Value convertedOperand,
1247 Value *offset =
nullptr) {
1249 if (isa<MemRefType>(operandType)) {
1252 *alignedPtr = desc.
alignedPtr(rewriter, loc);
1253 if (offset !=
nullptr)
1254 *offset = desc.
offset(rewriter, loc);
// Unranked path: operate on the underlying ranked descriptor storage.
1260 cast<UnrankedMemRefType>(operandType));
1261 auto elementPtrType =
1270 rewriter, loc, underlyingDescPtr, elementPtrType);
1272 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1273 if (offset !=
nullptr) {
1275 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
// Lowers memref.reinterpret_cast by building a fresh descriptor from the
// source's pointers plus the op's static/dynamic offset, sizes, and strides.
1279 struct MemRefReinterpretCastOpLowering
1285 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1287 Type srcType = castOp.getSource().getType();
1290 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1291 adaptor, &descriptor)))
1293 rewriter.
replaceOp(castOp, {descriptor});
// Builds the target descriptor; returns failure when the target memref type
// does not convert to an LLVM struct.
1298 LogicalResult convertSourceMemRefToDescriptor(
1300 memref::ReinterpretCastOp castOp,
1301 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1302 MemRefType targetMemRefType =
1303 cast<MemRefType>(castOp.getResult().getType());
1304 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1306 if (!llvmTargetDescriptorTy)
// Pointers come from the source; offset/sizes/strides from the op.
1314 Value allocatedPtr, alignedPtr;
1315 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1316 castOp.getSource(), adaptor.getSource(),
1317 &allocatedPtr, &alignedPtr);
1318 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1319 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1322 if (castOp.isDynamicOffset(0))
1323 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1325 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
// Dynamic operands are consumed in order, tracked by the two counters.
1328 unsigned dynSizeId = 0;
1329 unsigned dynStrideId = 0;
1330 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1331 if (castOp.isDynamicSize(i))
1332 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1334 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1336 if (castOp.isDynamicStride(i))
1337 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1339 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
// Lowers memref.reshape. Static shape operand: build the target descriptor
// directly, computing strides by a reverse walk over the dimensions. Dynamic
// shape operand: produce an unranked descriptor and fill sizes/strides with a
// runtime loop over the shape memref. NOTE(review): fragmentary extract;
// the intricate block/branch construction is kept verbatim.
1346 struct MemRefReshapeOpLowering
1351 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1353 Type srcType = reshapeOp.getSource().getType();
1356 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1357 adaptor, &descriptor)))
1359 rewriter.
replaceOp(reshapeOp, {descriptor});
1366 Type srcType, memref::ReshapeOp reshapeOp,
1367 memref::ReshapeOp::Adaptor adaptor,
Value *descriptor)
const {
1369 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
// ---- Static shape operand: ranked result descriptor. ----
1370 if (shapeMemRefType.hasStaticShape()) {
1371 MemRefType targetMemRefType =
1372 cast<MemRefType>(reshapeOp.getResult().getType());
1373 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1375 if (!llvmTargetDescriptorTy)
1384 Value allocatedPtr, alignedPtr;
1385 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1386 reshapeOp.getSource(), adaptor.getSource(),
1387 &allocatedPtr, &alignedPtr);
1388 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1389 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1394 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1396 reshapeOp,
"failed to get stride and offset exprs");
1398 if (!isStaticStrideOrOffset(offset))
1400 "dynamic offset is unsupported");
1402 desc.setConstantOffset(rewriter, loc, offset);
1404 assert(targetMemRefType.getLayout().isIdentity() &&
1405 "Identity layout map is a precondition of a valid reshape op");
// Walk dims innermost-first, accumulating the running stride.
1407 Type indexType = getIndexType();
1408 Value stride =
nullptr;
1409 int64_t targetRank = targetMemRefType.getRank();
1410 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1411 if (!ShapedType::isDynamic(strides[i])) {
1416 }
else if (!stride) {
// Dynamic dims load their extent from the shape operand, converted
// to index type when needed.
1426 if (!targetMemRefType.isDynamicDim(i)) {
1428 targetMemRefType.getDimSize(i));
1430 Value shapeOp = reshapeOp.getShape();
1432 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1433 Type indexType = getIndexType();
1434 if (dimSize.
getType() != indexType)
1436 rewriter, loc, indexType, dimSize);
1437 assert(dimSize &&
"Invalid memref element type");
1440 desc.setSize(rewriter, loc, i, dimSize);
1441 desc.setStride(rewriter, loc, i, stride);
1444 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
// ---- Dynamic shape operand: unranked result descriptor. ----
1454 Value resultRank = shapeDesc.
size(rewriter, loc, 0);
1457 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1458 unsigned addressSpace =
1459 *getTypeConverter()->getMemRefAddressSpace(targetType);
1464 rewriter, loc, typeConverter->
convertType(targetType));
1465 targetDesc.setRank(rewriter, loc, resultRank);
1468 targetDesc, addressSpace, sizes);
1469 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1472 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1475 Value allocatedPtr, alignedPtr, offset;
1476 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1477 reshapeOp.getSource(), adaptor.getSource(),
1478 &allocatedPtr, &alignedPtr, &offset);
1481 auto elementPtrType =
1485 elementPtrType, allocatedPtr);
1487 underlyingDescPtr, elementPtrType,
1490 underlyingDescPtr, elementPtrType,
1496 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1498 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
// Runtime loop from dim rank-1 down to 0, carrying {index, stride} as
// block arguments; exits when the index drops below zero.
1501 Value resultRankMinusOne =
1502 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1505 Type indexType = getTypeConverter()->getIndexType();
1509 {indexType, indexType}, {loc, loc});
1512 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1516 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
// Loop condition: index >= 0 (signed compare).
1525 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1534 loc, llvmIndexPtrType,
1535 typeConverter->
convertType(shapeMemRefType.getElementType()),
1536 shapeOperandPtr, indexArg);
1537 Value size = rewriter.
create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1539 targetSizesBase, indexArg, size);
1543 targetStridesBase, indexArg, strideArg);
1544 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1547 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1556 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1557 remainder, std::nullopt);
1562 *descriptor = targetDesc;
// expand_shape/collapse_shape and subview are expected to be rewritten into
// reinterpret_cast-style ops before this pass; these patterns only emit a
// match failure with a diagnostic.
1569 template <
typename ReshapeOp>
1570 class ReassociatingReshapeOpConversion
1574 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1577 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1581 "reassociation operations should have been expanded beforehand");
1591 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1594 subViewOp,
"subview operations should have been expanded beforehand");
// Lowers memref.transpose by permuting the size/stride fields of the source
// descriptor; pointers and offset are carried over unchanged.
1610 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1612 auto loc = transposeOp.getLoc();
// Identity permutation: the transpose is a no-op.
1616 if (transposeOp.getPermutation().isIdentity())
1617 return rewriter.
replaceOp(transposeOp, {viewMemRef}), success();
1621 typeConverter->
convertType(transposeOp.getIn().getType()));
1625 targetMemRef.setAllocatedPtr(rewriter, loc,
1627 targetMemRef.setAlignedPtr(rewriter, loc,
1631 targetMemRef.setOffset(rewriter, loc, viewMemRef.
offset(rewriter, loc));
// Each target dim i takes size/stride from source dim permutation[i].
1637 for (
const auto &en :
1639 int targetPos = en.index();
1640 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1641 targetMemRef.setSize(rewriter, loc, targetPos,
1642 viewMemRef.
size(rewriter, loc, sourcePos));
1643 targetMemRef.setStride(rewriter, loc, targetPos,
1644 viewMemRef.
stride(rewriter, loc, sourcePos));
1647 rewriter.
replaceOp(transposeOp, {targetMemRef});
// getSize: returns the i-th result size — the static constant when known,
// otherwise the matching entry of the dynamic-size operands (indexed by the
// number of dynamic dims preceding idx).
1664 Type indexType)
const {
1665 assert(idx < shape.size());
1666 if (!ShapedType::isDynamic(shape[idx]))
1670 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1671 return dynamicSizes[nDynamic];
// getStride: static stride when known; otherwise runningStride * nextSize
// (innermost stride starts at 1 — that branch is not visible here).
1680 Value runningStride,
unsigned idx,
Type indexType)
const {
1681 assert(idx < strides.size());
1682 if (!ShapedType::isDynamic(strides[idx]))
1685 return runningStride
1686 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1688 assert(!runningStride);
// Lowers memref.view: rebase the source buffer at byte_shift and lay out the
// target shape with contiguous (row-major) strides.
1693 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1695 auto loc = viewOp.getLoc();
1697 auto viewMemRefType = viewOp.getType();
1698 auto targetElementTy =
1699 typeConverter->
convertType(viewMemRefType.getElementType());
1700 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1701 if (!targetDescTy || !targetElementTy ||
1704 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1709 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1710 if (failed(successStrides))
1711 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1712 assert(offset == 0 &&
"expected offset to be 0");
// Only contiguous result layouts are supported (innermost stride 1 or 0).
1716 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1717 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1726 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1727 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
// Apply the byte shift to the aligned pointer via GEP.
1731 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1733 typeConverter->
convertType(srcMemRefType.getElementType()), alignedPtr,
1734 adaptor.getByteShift());
1736 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1738 Type indexType = getIndexType();
1742 targetMemRef.setOffset(
// Rank-0 views carry no sizes/strides.
1747 if (viewMemRefType.getRank() == 0)
1748 return rewriter.
replaceOp(viewOp, {targetMemRef}), success();
// Fill sizes and strides innermost-first, threading the running stride.
1751 Value stride =
nullptr, nextSize =
nullptr;
1752 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1754 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1755 adaptor.getSizes(), i, indexType);
1756 targetMemRef.setSize(rewriter, loc, i, size);
1759 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1760 targetMemRef.setStride(rewriter, loc, i, stride);
1764 rewriter.
replaceOp(viewOp, {targetMemRef});
1775 static std::optional<LLVM::AtomicBinOp>
1776 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1777 switch (atomicOp.getKind()) {
1778 case arith::AtomicRMWKind::addf:
1779 return LLVM::AtomicBinOp::fadd;
1780 case arith::AtomicRMWKind::addi:
1781 return LLVM::AtomicBinOp::add;
1782 case arith::AtomicRMWKind::assign:
1783 return LLVM::AtomicBinOp::xchg;
1784 case arith::AtomicRMWKind::maximumf:
1785 return LLVM::AtomicBinOp::fmax;
1786 case arith::AtomicRMWKind::maxs:
1788 case arith::AtomicRMWKind::maxu:
1789 return LLVM::AtomicBinOp::umax;
1790 case arith::AtomicRMWKind::minimumf:
1791 return LLVM::AtomicBinOp::fmin;
1792 case arith::AtomicRMWKind::mins:
1794 case arith::AtomicRMWKind::minu:
1795 return LLVM::AtomicBinOp::umin;
1796 case arith::AtomicRMWKind::ori:
1797 return LLVM::AtomicBinOp::_or;
1798 case arith::AtomicRMWKind::andi:
1799 return LLVM::AtomicBinOp::_and;
1801 return std::nullopt;
1803 llvm_unreachable(
"Invalid AtomicRMWKind");
1806 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1810 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1812 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1815 auto memRefType = atomicOp.getMemRefType();
1818 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1822 adaptor.getMemref(), adaptor.getIndices());
1824 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1825 LLVM::AtomicOrdering::acq_rel);
1831 class ConvertExtractAlignedPointerAsIndex
1838 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1846 alignedPtr = desc.
alignedPtr(rewriter, extractOp->getLoc());
1855 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1860 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1867 class ExtractStridedMetadataOpLowering
1874 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1883 Location loc = extractStridedMetadataOp.getLoc();
1884 Value source = extractStridedMetadataOp.getSource();
1886 auto sourceMemRefType = cast<MemRefType>(source.
getType());
1887 int64_t rank = sourceMemRefType.getRank();
1889 results.reserve(2 + rank * 2);
1895 rewriter, loc, *getTypeConverter(),
1896 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1897 baseBuffer, alignedBuffer);
1898 results.push_back((
Value)dstMemRef);
1901 results.push_back(sourceMemRef.
offset(rewriter, loc));
1904 for (
unsigned i = 0; i < rank; ++i)
1905 results.push_back(sourceMemRef.
size(rewriter, loc, i));
1907 for (
unsigned i = 0; i < rank; ++i)
1908 results.push_back(sourceMemRef.
stride(rewriter, loc, i));
1910 rewriter.
replaceOp(extractStridedMetadataOp, results);
1922 AllocaScopeOpLowering,
1923 AtomicRMWOpLowering,
1924 AssumeAlignmentOpLowering,
1925 ConvertExtractAlignedPointerAsIndex,
1927 ExtractStridedMetadataOpLowering,
1928 GenericAtomicRMWOpLowering,
1929 GlobalMemrefOpLowering,
1930 GetGlobalMemrefOpLowering,
1932 MemRefCastOpLowering,
1933 MemRefCopyOpLowering,
1934 MemorySpaceCastOpLowering,
1935 MemRefReinterpretCastOpLowering,
1936 MemRefReshapeOpLowering,
1939 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1940 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1944 ViewOpLowering>(converter);
1948 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1950 patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1954 struct FinalizeMemRefToLLVMConversionPass
1955 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
1956 FinalizeMemRefToLLVMConversionPass> {
1957 using FinalizeMemRefToLLVMConversionPassBase::
1958 FinalizeMemRefToLLVMConversionPassBase;
1960 void runOnOperation()
override {
1962 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1964 dataLayoutAnalysis.getAtOrAbove(op));
1969 options.useGenericFunctions = useGenericFunctions;
1972 options.overrideIndexBitwidth(indexBitwidth);
1975 &dataLayoutAnalysis);
1979 target.addLegalOp<func::FuncOp>();
1981 signalPassFailure();
1988 void loadDependentDialects(
MLIRContext *context)
const final {
1989 context->loadDialect<LLVM::LLVMDialect>();
1994 void populateConvertToLLVMConversionPatterns(
2005 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
result_range getResults()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.