27#include "llvm/Support/DebugLog.h"
28#include "llvm/Support/MathExtras.h"
32#define DEBUG_TYPE "memref-to-llvm"
35#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
36#include "mlir/Conversion/Passes.h.inc"
42 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
// Returns true when `strideOrOffset` is a static (non-dynamic) value, as
// reported by ShapedType::isStatic. Used by the reshape lowering to reject
// layouts whose offset is dynamic.
46static bool isStaticStrideOrOffset(
int64_t strideOrOffset) {
47 return ShapedType::isStatic(strideOrOffset);
50static FailureOr<LLVM::LLVMFuncOp>
61static FailureOr<LLVM::LLVMFuncOp>
73static FailureOr<LLVM::LLVMFuncOp>
// Emits LLVM dialect arithmetic that rounds `input` up to the nearest multiple
// of `alignment`:
//   bumped = input + (alignment - 1)
//   result = bumped - (bumped urem alignment)
// The constant 1 is built with alignment's type, so `input` and `alignment`
// presumably share one integer type (TODO confirm at call sites).
// NOTE(review): `alignment` is assumed to be non-zero; a urem by zero would
// not yield a meaningful result.
89static Value createAligned(ConversionPatternRewriter &rewriter,
Location loc,
91 Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.
getType(),
92 rewriter.getIndexAttr(1));
// bump = alignment - 1, so adding it pushes any non-multiple past the next
// boundary.
93 Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
94 Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
// Subtracting the remainder snaps `bumped` down onto the alignment boundary.
95 Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
96 return LLVM::SubOp::create(rewriter, loc, bumped, mod);
106 layout = &analysis->getAbove(op);
108 Type elementType = memRefType.getElementType();
109 if (
auto memRefElementType = dyn_cast<MemRefType>(elementType))
111 if (
auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
117static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
119 MemRefType memRefType,
Type elementPtrType,
121 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.
getType());
122 FailureOr<unsigned> maybeMemrefAddrSpace =
124 assert(succeeded(maybeMemrefAddrSpace) &&
"unsupported address space");
125 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
126 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
127 allocatedPtr = LLVM::AddrSpaceCastOp::create(
129 LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
142 symbolTables(symbolTables) {}
145 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter)
const override {
147 auto loc = op.getLoc();
148 MemRefType memRefType = op.getType();
149 if (!isConvertibleAndHasIdentityMaps(memRefType))
150 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
153 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
154 getNotalignedAllocFn(rewriter, getTypeConverter(),
156 getIndexType(), symbolTables);
157 if (failed(allocFuncOp))
167 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
168 rewriter, sizes, strides, sizeBytes,
true);
170 Value alignment = getAlignment(rewriter, loc, op);
173 sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
177 Type elementPtrType = this->getElementPtrType(memRefType);
178 assert(elementPtrType &&
"could not compute element ptr type");
180 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
183 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
184 elementPtrType, *getTypeConverter());
185 Value alignedPtr = allocatedPtr;
189 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
191 createAligned(rewriter, loc, allocatedInt, alignment);
193 LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
197 auto memRefDescriptor = this->createMemRefDescriptor(
198 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
201 rewriter.replaceOp(op, {memRefDescriptor});
206 template <
typename OpType>
207 Value getAlignment(ConversionPatternRewriter &rewriter,
Location loc,
209 MemRefType memRefType = op.
getType();
211 if (
auto alignmentAttr = op.getAlignment()) {
212 Type indexType = getIndexType();
215 }
else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
220 alignment =
getSizeInBytes(loc, memRefType.getElementType(), rewriter);
234 symbolTables(symbolTables) {}
237 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter)
const override {
239 auto loc = op.getLoc();
240 MemRefType memRefType = op.getType();
241 if (!isConvertibleAndHasIdentityMaps(memRefType))
242 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
245 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
246 getAlignedAllocFn(rewriter, getTypeConverter(),
248 getIndexType(), symbolTables);
249 if (failed(allocFuncOp))
259 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
260 rewriter, sizes, strides, sizeBytes, !
false);
262 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
264 Value allocAlignment =
269 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
270 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
272 Type elementPtrType = this->getElementPtrType(memRefType);
274 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
278 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
279 elementPtrType, *getTypeConverter());
282 auto memRefDescriptor = this->createMemRefDescriptor(
283 loc, memRefType,
ptr,
ptr, sizes, strides, rewriter);
286 rewriter.replaceOp(op, {memRefDescriptor});
// Lower bound (in bytes) on the alignment computed for aligned allocations.
// alignedAllocationGetAlignment returns
// max(kMinAlignedAllocAlignment, PowerOf2Ceil(element size in bytes)),
// presumably only when the op carries no explicit alignment (TODO confirm;
// that branch is not fully visible here).
291 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
298 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
300 if (std::optional<uint64_t> alignment = op.getAlignment())
306 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
307 getTypeConverter(), op.getType(), op, defaultLayout);
308 return std::max(kMinAlignedAllocAlignment,
309 llvm::PowerOf2Ceil(eltSizeBytes));
314 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
Operation *op,
316 uint64_t sizeDivisor =
317 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
318 for (
unsigned i = 0, e = type.getRank(); i < e; i++) {
319 if (type.isDynamicDim(i))
321 sizeDivisor = sizeDivisor * type.getDimSize(i);
323 return sizeDivisor % factor == 0;
338 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override {
340 auto loc = op.getLoc();
341 MemRefType memRefType = op.getType();
342 if (!isConvertibleAndHasIdentityMaps(memRefType))
343 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
352 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
353 rewriter, sizes, strides, size, !
true);
358 typeConverter->convertType(op.getType().getElementType());
359 FailureOr<unsigned> maybeAddressSpace =
361 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
362 unsigned addrSpace = *maybeAddressSpace;
363 auto elementPtrType =
364 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
366 auto allocatedElementPtr =
367 LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
368 op.getAlignment().value_or(0));
371 auto memRefDescriptor = this->createMemRefDescriptor(
372 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
376 rewriter.replaceOp(op, {memRefDescriptor});
381struct AllocaScopeOpLowering
386 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter)
const override {
389 Location loc = allocaScopeOp.getLoc();
393 auto *currentBlock = rewriter.getInsertionBlock();
394 auto *remainingOpsBlock =
395 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
396 Block *continueBlock;
397 if (allocaScopeOp.getNumResults() == 0) {
398 continueBlock = remainingOpsBlock;
400 continueBlock = rewriter.createBlock(
401 remainingOpsBlock, allocaScopeOp.getResultTypes(),
403 allocaScopeOp.getLoc()));
404 LLVM::BrOp::create(rewriter, loc,
ValueRange(), remainingOpsBlock);
408 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
409 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
410 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
413 rewriter.setInsertionPointToEnd(currentBlock);
414 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
415 LLVM::BrOp::create(rewriter, loc,
ValueRange(), beforeBody);
419 rewriter.setInsertionPointToEnd(afterBody);
421 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
422 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
423 returnOp, returnOp.getResults(), continueBlock);
426 rewriter.setInsertionPoint(branchOp);
427 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
430 rewriter.replaceOp(allocaScopeOp, continueBlock->
getArguments());
436struct AssumeAlignmentOpLowering
444 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter)
const override {
447 unsigned alignment = op.getAlignment();
448 auto loc = op.getLoc();
450 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
451 Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType,
memref,
458 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(
true));
459 Value alignmentConst =
463 rewriter.replaceOp(op,
memref);
468struct DistinctObjectsOpLowering
476 matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter)
const override {
479 if (operands.size() <= 1) {
481 rewriter.replaceOp(op, operands);
487 for (
auto [origOperand, newOperand] :
488 llvm::zip_equal(op.getOperands(), operands)) {
489 auto memrefType = cast<MemRefType>(origOperand.getType());
497 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
499 for (
auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500 for (
auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501 Value ptr1 = ptrs[i];
503 LLVM::AssumeOp::create(rewriter, loc, cond,
508 rewriter.replaceOp(op, operands);
524 symbolTables(symbolTables) {}
527 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter)
const override {
530 FailureOr<LLVM::LLVMFuncOp> freeFunc =
531 getFreeFn(rewriter, getTypeConverter(),
533 if (failed(freeFunc))
536 if (
auto unrankedTy =
537 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
538 auto elementPtrTy = LLVM::LLVMPointerType::get(
539 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
541 rewriter, op.getLoc(),
549 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
561 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter)
const override {
563 Type operandType = dimOp.getSource().getType();
564 if (isa<UnrankedMemRefType>(operandType)) {
565 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
566 operandType, dimOp, adaptor.getOperands(), rewriter);
567 if (failed(extractedSize))
569 rewriter.replaceOp(dimOp, {*extractedSize});
572 if (isa<MemRefType>(operandType)) {
574 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
575 adaptor.getOperands(), rewriter)});
578 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
583 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
585 ConversionPatternRewriter &rewriter)
const {
588 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
589 auto scalarMemRefType =
590 MemRefType::get({}, unrankedMemRefType.getElementType());
591 FailureOr<unsigned> maybeAddressSpace =
592 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
593 if (failed(maybeAddressSpace)) {
594 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
598 unsigned addressSpace = *maybeAddressSpace;
606 Type elementType = typeConverter->convertType(scalarMemRefType);
610 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
612 LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
617 Value idxPlusOne = LLVM::AddOp::create(
621 Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
622 getTypeConverter()->getIndexType(),
623 offsetPtr, idxPlusOne);
624 return LLVM::LoadOp::create(rewriter, loc,
625 getTypeConverter()->getIndexType(), sizePtr)
629 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
630 if (
auto idx = dimOp.getConstantIndex())
633 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
634 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
639 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
641 ConversionPatternRewriter &rewriter)
const {
645 MemRefType memRefType = cast<MemRefType>(operandType);
646 Type indexType = getIndexType();
647 if (std::optional<int64_t>
index = getConstantDimIndex(dimOp)) {
649 if (i >= 0 && i < memRefType.getRank()) {
650 if (memRefType.isDynamicDim(i)) {
653 return descriptor.
size(rewriter, loc, i);
656 int64_t dimSize = memRefType.getDimSize(i);
661 int64_t rank = memRefType.getRank();
663 return memrefDescriptor.
size(rewriter, loc,
index, rank);
670template <
typename Derived>
674 using Base = LoadStoreOpLowering<Derived>;
704struct GenericAtomicRMWOpLowering
705 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
709 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 auto loc = atomicOp.getLoc();
712 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
715 auto *initBlock = rewriter.getInsertionBlock();
716 auto *loopBlock = rewriter.splitBlock(initBlock,
Block::iterator(atomicOp));
717 loopBlock->addArgument(valueType, loc);
723 rewriter.setInsertionPointToEnd(initBlock);
724 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
725 auto dataPtr = getStridedElementPtr(
726 rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
727 Value init = LLVM::LoadOp::create(
728 rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
730 LLVM::BrOp::create(rewriter, loc, init, loopBlock);
733 rewriter.setInsertionPointToStart(loopBlock);
736 auto loopArgument = loopBlock->getArgument(0);
738 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
742 mapping.
map(nestedOp.getResults(),
clone->getResults());
748 return atomicOp.emitError(
"result not defined in region");
753 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
754 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
756 LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
757 result, successOrdering, failureOrdering);
759 Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
760 Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);
764 loopBlock, newLoaded);
766 rewriter.setInsertionPointToEnd(endBlock);
769 rewriter.replaceOp(atomicOp, {newLoaded});
777convertGlobalMemrefTypeToLLVM(MemRefType type,
784 Type elementType = typeConverter.convertType(type.getElementType());
785 Type arrayTy = elementType;
787 for (
int64_t dim : llvm::reverse(type.getShape()))
788 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
801 symbolTables(symbolTables) {}
804 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
805 ConversionPatternRewriter &rewriter)
const override {
806 MemRefType type = global.getType();
807 if (!isConvertibleAndHasIdentityMaps(type))
810 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
812 LLVM::Linkage linkage =
813 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
814 bool isExternal = global.isExternal();
815 bool isUninitialized = global.isUninitialized();
818 if (!isExternal && !isUninitialized) {
819 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
820 initialValue = elementsAttr;
824 if (type.getRank() == 0)
825 initialValue = elementsAttr.getSplatValue<
Attribute>();
828 uint64_t alignment = global.getAlignment().value_or(0);
829 FailureOr<unsigned> addressSpace =
830 getTypeConverter()->getMemRefAddressSpace(type);
831 if (failed(addressSpace))
832 return global.emitOpError(
833 "memory space cannot be converted to an integer address space");
841 symbolTable->remove(global);
845 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
846 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
847 initialValue, alignment, *addressSpace);
851 symbolTable->
insert(newGlobal, rewriter.getInsertionPoint());
853 if (!isExternal && isUninitialized) {
854 rewriter.createBlock(&newGlobal.getInitializerRegion());
856 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
857 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
866struct GetGlobalMemrefOpLowering
873 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
874 ConversionPatternRewriter &rewriter)
const override {
875 auto loc = op.getLoc();
876 MemRefType memRefType = op.getType();
877 if (!isConvertibleAndHasIdentityMaps(memRefType))
878 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
887 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
888 rewriter, sizes, strides, sizeBytes, !
false);
890 MemRefType type = cast<MemRefType>(op.getResult().getType());
894 FailureOr<unsigned> maybeAddressSpace =
895 getTypeConverter()->getMemRefAddressSpace(type);
896 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
897 unsigned memSpace = *maybeAddressSpace;
899 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
900 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
902 LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
907 LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
913 auto intPtrType = getIntPtrType(memSpace);
914 Value deadBeefConst =
917 LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
922 auto memRefDescriptor = this->createMemRefDescriptor(
923 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
926 rewriter.replaceOp(op, {memRefDescriptor});
933struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
937 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
938 ConversionPatternRewriter &rewriter)
const override {
939 auto type = loadOp.getMemRefType();
944 Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
947 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
948 loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
949 loadOp.getAlignment().value_or(0),
false, loadOp.getNontemporal());
956struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
960 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter)
const override {
962 auto type = op.getMemRefType();
968 getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
970 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
971 op.getAlignment().value_or(0),
972 false, op.getNontemporal());
979struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
983 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter)
const override {
985 auto type = prefetchOp.getMemRefType();
986 auto loc = prefetchOp.getLoc();
988 Value dataPtr = getStridedElementPtr(
989 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
992 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
993 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
995 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
996 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
997 localityHint, isData);
1006 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
1007 ConversionPatternRewriter &rewriter)
const override {
1009 Type operandType = op.getMemref().getType();
1010 if (isa<UnrankedMemRefType>(operandType)) {
1012 rewriter.replaceOp(op, {desc.
rank(rewriter, loc)});
1015 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
1016 Type indexType = getIndexType();
1017 rewriter.replaceOp(op,
1019 rankedMemRefType.getRank())});
1030 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
1031 ConversionPatternRewriter &rewriter)
const override {
1032 Type srcType = memRefCastOp.getOperand().getType();
1033 Type dstType = memRefCastOp.getType();
1040 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
1041 if (typeConverter->convertType(srcType) !=
1042 typeConverter->convertType(dstType))
1046 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
1049 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
1050 auto loc = memRefCastOp.getLoc();
1053 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
1054 rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
1058 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
1063 auto srcMemRefType = cast<MemRefType>(srcType);
1064 int64_t rank = srcMemRefType.getRank();
1066 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
1067 loc, adaptor.getSource(), rewriter);
1070 auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1071 rewriter.getIndexAttr(rank));
1076 memRefDesc.
setRank(rewriter, loc, rankVal);
1079 rewriter.replaceOp(memRefCastOp, (
Value)memRefDesc);
1081 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
1090 auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType,
ptr);
1091 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
1093 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
1113 symbolTables(symbolTables) {}
1116 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1117 ConversionPatternRewriter &rewriter)
const {
1118 auto loc = op.getLoc();
1119 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1124 Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1125 rewriter.getIndexAttr(1));
1126 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
1127 auto size = srcDesc.
size(rewriter, loc, pos);
1128 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
1132 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1135 LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
1137 Type elementType = typeConverter->convertType(srcType.getElementType());
1141 Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.
getType(),
1142 elementType, srcBasePtr, srcOffset);
1145 Value targetOffset = targetDesc.
offset(rewriter, loc);
1147 LLVM::GEPOp::create(rewriter, loc, targetBasePtr.
getType(), elementType,
1148 targetBasePtr, targetOffset);
1149 LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
1151 rewriter.eraseOp(op);
1157 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1158 ConversionPatternRewriter &rewriter)
const {
1159 auto loc = op.getLoc();
1160 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1161 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1164 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
1165 auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1167 auto *typeConverter = getTypeConverter();
1172 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1174 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank,
ptr});
1178 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
1180 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1181 Value unrankedSource =
1182 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1183 : adaptor.getSource();
1184 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1185 Value unrankedTarget =
1186 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1187 : adaptor.getTarget();
1190 auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1191 rewriter.getIndexAttr(1));
1192 auto promote = [&](
Value desc) {
1193 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1195 LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
1196 LLVM::StoreOp::create(rewriter, loc, desc, allocated);
1200 auto sourcePtr = promote(unrankedSource);
1201 auto targetPtr = promote(unrankedTarget);
1205 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1207 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1208 sourcePtr.getType(), symbolTables);
1211 LLVM::CallOp::create(rewriter, loc, copyFn.value(),
1215 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
1217 rewriter.eraseOp(op);
1223 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1224 ConversionPatternRewriter &rewriter)
const override {
1225 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1226 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1229 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1233 return memrefType &&
1234 (memrefType.getLayout().isIdentity() ||
1235 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1239 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1240 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1242 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1246struct MemorySpaceCastOpLowering
1252 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1253 ConversionPatternRewriter &rewriter)
const override {
1256 Type resultType = op.getDest().getType();
1257 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1258 auto convertedType =
1259 typeConverter->convertType<LLVM::LLVMStructType>(resultTypeR);
1261 return rewriter.notifyMatchFailure(op,
"memref type conversion failed");
1262 Type newPtrType = convertedType.getBody()[0];
1268 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
1270 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
1272 resultTypeR, descVals);
1273 rewriter.replaceOp(op,
result);
1276 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1279 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1280 FailureOr<unsigned> maybeSourceAddrSpace =
1281 getTypeConverter()->getMemRefAddressSpace(sourceType);
1282 if (failed(maybeSourceAddrSpace))
1283 return rewriter.notifyMatchFailure(loc,
1284 "non-integer source address space");
1285 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1286 FailureOr<unsigned> maybeResultAddrSpace =
1287 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1288 if (failed(maybeResultAddrSpace))
1289 return rewriter.notifyMatchFailure(loc,
1290 "non-integer result address space");
1291 unsigned resultAddrSpace = *maybeResultAddrSpace;
1294 Value rank = sourceDesc.
rank(rewriter, loc);
1299 rewriter, loc, typeConverter->convertType(resultTypeU));
1300 result.setRank(rewriter, loc, rank);
1302 rewriter, loc, *getTypeConverter(),
result, resultAddrSpace);
1303 Value resultUnderlyingDesc =
1304 LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
1305 rewriter.getI8Type(), resultUnderlyingSize);
1306 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1309 auto sourceElemPtrType =
1310 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
1311 auto resultElemPtrType =
1312 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
1315 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1317 sourceDesc.
alignedPtr(rewriter, loc, *getTypeConverter(),
1318 sourceUnderlyingDesc, sourceElemPtrType);
1319 allocatedPtr = LLVM::AddrSpaceCastOp::create(
1320 rewriter, loc, resultElemPtrType, allocatedPtr);
1321 alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
1322 resultElemPtrType, alignedPtr);
1324 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1325 resultElemPtrType, allocatedPtr);
1326 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1327 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1330 Value sourceIndexVals =
1331 sourceDesc.
offsetBasePtr(rewriter, loc, *getTypeConverter(),
1332 sourceUnderlyingDesc, sourceElemPtrType);
1333 Value resultIndexVals =
1334 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1335 resultUnderlyingDesc, resultElemPtrType);
1338 2 * llvm::divideCeil(
1339 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1340 Value bytesToSkipConst = LLVM::ConstantOp::create(
1341 rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1343 LLVM::SubOp::create(rewriter, loc, getIndexType(),
1344 resultUnderlyingSize, bytesToSkipConst);
1345 LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
1351 return rewriter.notifyMatchFailure(loc,
"unexpected memref type");
1358static void extractPointersAndOffset(
Location loc,
1359 ConversionPatternRewriter &rewriter,
1361 Value originalOperand,
1362 Value convertedOperand,
1364 Value *offset =
nullptr) {
1366 if (isa<MemRefType>(operandType)) {
1369 *alignedPtr = desc.
alignedPtr(rewriter, loc);
1370 if (offset !=
nullptr)
1371 *offset = desc.
offset(rewriter, loc);
1377 cast<UnrankedMemRefType>(operandType));
1378 auto elementPtrType =
1379 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1387 rewriter, loc, underlyingDescPtr, elementPtrType);
1389 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1390 if (offset !=
nullptr) {
1392 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1396struct MemRefReinterpretCastOpLowering
1402 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1403 ConversionPatternRewriter &rewriter)
const override {
1404 Type srcType = castOp.getSource().getType();
1407 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1408 adaptor, &descriptor)))
1410 rewriter.replaceOp(castOp, {descriptor});
1415 LogicalResult convertSourceMemRefToDescriptor(
1416 ConversionPatternRewriter &rewriter,
Type srcType,
1417 memref::ReinterpretCastOp castOp,
1418 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1419 MemRefType targetMemRefType =
1420 cast<MemRefType>(castOp.getResult().getType());
1421 auto llvmTargetDescriptorTy =
1422 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1423 if (!llvmTargetDescriptorTy)
1431 Value allocatedPtr, alignedPtr;
1432 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1433 castOp.getSource(), adaptor.getSource(),
1434 &allocatedPtr, &alignedPtr);
1435 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1436 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1439 if (castOp.isDynamicOffset(0))
1440 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1442 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1445 unsigned dynSizeId = 0;
1446 unsigned dynStrideId = 0;
1447 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1448 if (castOp.isDynamicSize(i))
1449 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1451 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1453 if (castOp.isDynamicStride(i))
1454 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1456 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1463struct MemRefReshapeOpLowering
1468 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1469 ConversionPatternRewriter &rewriter)
const override {
1470 Type srcType = reshapeOp.getSource().getType();
1473 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1474 adaptor, &descriptor)))
1476 rewriter.replaceOp(reshapeOp, {descriptor});
// Builds the LLVM descriptor for the result of `reshapeOp` and writes it to
// `*descriptor`.
//
// Two strategies are used, depending on the type of the shape operand:
//  * Static shape operand (ranked MemRefType result): a ranked descriptor is
//    populated with the source's allocated/aligned pointers, a constant
//    offset, and per-dimension sizes/strides computed innermost-first.
//  * Otherwise (UnrankedMemRefType result): an unranked descriptor is built,
//    its underlying ranked descriptor is stack-allocated with llvm.alloca, and
//    an LLVM-dialect loop fills sizes and strides from the shape buffer at
//    runtime.
//
// Reports a match failure if the result's strides/offset cannot be computed
// or if the offset is dynamic.
1482 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1483 Type srcType, memref::ReshapeOp reshapeOp,
1484 memref::ReshapeOp::Adaptor adaptor,
1485 Value *descriptor)
const {
1486 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1487 if (shapeMemRefType.hasStaticShape()) {
1488 MemRefType targetMemRefType =
1489 cast<MemRefType>(reshapeOp.getResult().getType());
1490 auto llvmTargetDescriptorTy =
1491 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1492 if (!llvmTargetDescriptorTy)
// The result aliases the source buffer: copy its allocated and aligned
// pointers into the new descriptor.
1501 Value allocatedPtr, alignedPtr;
1502 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1503 reshapeOp.getSource(), adaptor.getSource(),
1504 &allocatedPtr, &alignedPtr);
1505 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1506 desc.setAlignedPtr(rewriter, loc, alignedPtr);
// The result layout must be strided with a static offset; anything else is
// rejected with a match failure.
1511 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1512 return rewriter.notifyMatchFailure(
1513 reshapeOp,
"failed to get stride and offset exprs");
1515 if (!isStaticStrideOrOffset(offset))
1516 return rewriter.notifyMatchFailure(reshapeOp,
1517 "dynamic offset is unsupported");
1519 desc.setConstantOffset(rewriter, loc, offset);
// The stride computation below relies on the result having an identity layout.
1521 assert(targetMemRefType.getLayout().isIdentity() &&
1522 "Identity layout map is a precondition of a valid reshape op");
1524 Type indexType = getIndexType();
1525 Value stride =
nullptr;
// Walk dimensions from innermost to outermost. `stride` is updated by
// multiplying it by each dimension size, giving the running stride of the
// next outer dimension.
1526 int64_t targetRank = targetMemRefType.getRank();
1527 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1528 if (ShapedType::isStatic(strides[i])) {
1533 }
else if (!stride) {
// Size of dimension `i`: static dimensions come from the result type. Dynamic
// ones are loaded from the shape operand and converted to the index type when
// the element types differ.
1543 if (!targetMemRefType.isDynamicDim(i)) {
1545 targetMemRefType.getDimSize(i));
1547 Value shapeOp = reshapeOp.getShape();
1549 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp,
index);
1550 Type indexType = getIndexType();
1551 if (dimSize.
getType() != indexType)
1552 dimSize = typeConverter->materializeTargetConversion(
1553 rewriter, loc, indexType, dimSize);
1554 assert(dimSize &&
"Invalid memref element type");
1557 desc.setSize(rewriter, loc, i, dimSize);
1558 desc.setStride(rewriter, loc, i, stride);
1561 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
// Dynamic shape operand: the result rank is the runtime size of dimension 0
// of the shape operand, and the result is an unranked memref.
1571 Value resultRank = shapeDesc.
size(rewriter, loc, 0);
1574 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1575 unsigned addressSpace =
1576 *getTypeConverter()->getMemRefAddressSpace(targetType);
1581 rewriter, loc, typeConverter->convertType(targetType));
1582 targetDesc.setRank(rewriter, loc, resultRank);
1584 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
// Stack-allocate storage (i8 elements) for the underlying ranked descriptor
// and attach it to the unranked descriptor.
1585 Value underlyingDescPtr = LLVM::AllocaOp::create(
1586 rewriter, loc, getPtrType(), IntegerType::get(
getContext(), 8),
1588 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Copy the source's pointers and offset into the underlying descriptor.
1591 Value allocatedPtr, alignedPtr, offset;
1592 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1593 reshapeOp.getSource(), adaptor.getSource(),
1594 &allocatedPtr, &alignedPtr, &offset);
1597 auto elementPtrType =
1598 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1601 elementPtrType, allocatedPtr);
1603 underlyingDescPtr, elementPtrType,
1606 underlyingDescPtr, elementPtrType,
1612 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1614 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
// The loop below counts the dimension index down from rank - 1.
1617 Value resultRankMinusOne =
1618 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
// Control flow: split the current block after the insertion point, create a
// condition block taking (index, stride) block arguments, and branch to it
// from the initial block with (rank - 1, oneIndex).
1620 Block *initBlock = rewriter.getInsertionBlock();
1621 Type indexType = getTypeConverter()->getIndexType();
1622 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1624 Block *condBlock = rewriter.createBlock(initBlock->
getParent(), {},
1625 {indexType, indexType}, {loc, loc});
1628 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1629 rewriter.mergeBlocks(remainingBlock, condBlock,
ValueRange());
1631 rewriter.setInsertionPointToEnd(initBlock);
1632 LLVM::BrOp::create(rewriter, loc,
1633 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1634 rewriter.setInsertionPointToStart(condBlock);
// Loop condition: keep iterating while index >= 0 (signed comparison).
1639 Value pred = LLVM::ICmpOp::create(
1640 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1641 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1644 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1645 rewriter.setInsertionPointToStart(bodyBlock);
// Loop body: read shape[index], store it as size[index], store the running
// stride as stride[index], then multiply the stride by the size and
// decrement the index.
// NOTE(review): the GEP steps by the converted shape element type, but the
// load reads `indexType`. If the shape element type can be narrower than
// the index type (e.g. i32 vs i64), this reads the wrong width. The static
// path above converts explicitly. Confirm the element type is index here.
1648 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1649 Value sizeLoadGep = LLVM::GEPOp::create(
1650 rewriter, loc, llvmIndexPtrType,
1651 typeConverter->convertType(shapeMemRefType.getElementType()),
1652 shapeOperandPtr, indexArg);
1653 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1655 targetSizesBase, indexArg, size);
1659 targetStridesBase, indexArg, strideArg);
1660 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
// Back-edge to the condition block with the updated (index, stride).
1663 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1664 LLVM::BrOp::create(rewriter, loc,
ValueRange({decrement, nextStride}),
1668 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
// Terminate the condition block: enter the body while `pred` holds, otherwise
// continue in the remainder block, where further emission resumes.
1671 rewriter.setInsertionPointToEnd(condBlock);
1672 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock,
ValueRange(),
1676 rewriter.setInsertionPointToStart(remainder);
1678 *descriptor = targetDesc;
// Placeholder pattern for reassociating reshapes. It is instantiated for
// memref::CollapseShapeOp and memref::ExpandShapeOp in the pattern list. It
// never rewrites; it only reports a match failure, because these ops are
// expected to have been expanded before this conversion runs.
1685template <
typename ReshapeOp>
1686class ReassociatingReshapeOpConversion
1690 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1693 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1694 ConversionPatternRewriter &rewriter)
const override {
// Intentionally unconditional failure: no direct LLVM lowering is provided.
1695 return rewriter.notifyMatchFailure(
1697 "reassociation operations should have been expanded beforehand");
// memref.subview has no direct lowering here. The pattern always reports a
// match failure, because subviews are expected to have been expanded before
// this conversion runs.
1707 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1708 ConversionPatternRewriter &rewriter)
const override {
1709 return rewriter.notifyMatchFailure(
1710 subViewOp,
"subview operations should have been expanded beforehand");
// Lowers memref.transpose by building a new descriptor whose sizes and strides
// are a permutation of those in `viewMemRef`. `viewMemRef` is presumably the
// descriptor of the converted input — confirm. Pointers and offset are copied
// unchanged. An identity permutation simply forwards `viewMemRef`.
1726 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1727 ConversionPatternRewriter &rewriter)
const override {
1728 auto loc = transposeOp.getLoc();
// Identity permutation: nothing to permute, reuse the input descriptor.
1732 if (transposeOp.getPermutation().isIdentity())
1733 return rewriter.replaceOp(transposeOp, {viewMemRef}),
success();
1737 typeConverter->convertType(transposeOp.getIn().getType()));
// Pointers and offset are shared with the input; only the shape metadata
// is permuted.
1741 targetMemRef.setAllocatedPtr(rewriter, loc,
1743 targetMemRef.setAlignedPtr(rewriter, loc,
1747 targetMemRef.setOffset(rewriter, loc, viewMemRef.
offset(rewriter, loc));
// Result dimension `targetPos` takes the size and stride of input dimension
// `sourcePos`, as named by the permutation map's result expression.
1753 for (
const auto &en :
1754 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1755 int targetPos = en.index();
1756 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1757 targetMemRef.setSize(rewriter, loc, targetPos,
1758 viewMemRef.
size(rewriter, loc, sourcePos));
1759 targetMemRef.setStride(rewriter, loc, targetPos,
1760 viewMemRef.
stride(rewriter, loc, sourcePos));
1763 rewriter.replaceOp(transposeOp, {targetMemRef});
// Returns the size of dimension `idx` of the view result. Static dimensions
// are handled by the `isStatic` branch. A dynamic dimension maps to
// `dynamicSizes[k]`, where k is the number of dynamic dimensions before `idx`
// in `shape`.
1778 Value getSize(ConversionPatternRewriter &rewriter,
Location loc,
1780 Type indexType)
const {
1781 assert(idx <
shape.size());
1782 if (ShapedType::isStatic(
shape[idx]))
// Dynamic sizes are supplied in order, so count the dynamic dims before `idx`.
1786 llvm::count_if(
shape.take_front(idx), ShapedType::isDynamic);
1787 return dynamicSizes[nDynamic];
// Returns the stride of dimension `idx` of the view result. Static strides are
// handled by the `isStatic` branch. Otherwise the stride is derived from the
// enclosing loop's state: `runningStride * nextSize` when a running stride
// exists. The trailing assertion requires that no running stride is set when
// control reaches the final case.
1794 Value getStride(ConversionPatternRewriter &rewriter,
Location loc,
1796 Value runningStride,
unsigned idx,
Type indexType)
const {
1797 assert(idx < strides.size());
1798 if (ShapedType::isStatic(strides[idx]))
1801 return runningStride
1802 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1804 assert(!runningStride);
// Lowers memref.view by building a descriptor with the following contents:
//  * the source's allocated pointer;
//  * the aligned pointer advanced by the `byteShift` operand (a GEP over the
//    source element type — presumably i8 for view sources, so the shift is in
//    bytes; confirm);
//  * sizes and strides filled innermost-first from the result type and the
//    dynamic size operands.
// Emits a warning and bails out when the result type cannot be converted, is
// not strided, or has an innermost stride other than 0 or 1.
1809 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1810 ConversionPatternRewriter &rewriter)
const override {
1811 auto loc = viewOp.getLoc();
1813 auto viewMemRefType = viewOp.getType();
1814 auto targetElementTy =
1815 typeConverter->convertType(viewMemRefType.getElementType());
1816 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1817 if (!targetDescTy || !targetElementTy ||
1820 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
// The result layout must be strided; its offset is expected to be 0.
1825 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1826 if (failed(successStrides))
1827 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1828 assert(offset == 0 &&
"expected offset to be 0");
// Only contiguous innermost dimensions (stride 1, or 0) are supported.
1832 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1833 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1842 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1843 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
// Apply the byte shift to the aligned pointer only; the allocated pointer is
// kept as-is.
1847 alignedPtr = LLVM::GEPOp::create(
1848 rewriter, loc, alignedPtr.
getType(),
1849 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1850 adaptor.getByteShift());
1852 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1854 Type indexType = getIndexType();
1858 targetMemRef.setOffset(
// Rank-0 results carry no sizes or strides.
1863 if (viewMemRefType.getRank() == 0)
1864 return rewriter.replaceOp(viewOp, {targetMemRef}),
success();
// Fill sizes/strides from the innermost dimension outward; getStride() uses
// the previous iteration's `stride` and `nextSize`.
1867 Value stride =
nullptr, nextSize =
nullptr;
1868 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1870 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1871 adaptor.getSizes(), i, indexType);
1872 targetMemRef.setSize(rewriter, loc, i, size);
1875 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1876 targetMemRef.setStride(rewriter, loc, i, stride);
1880 rewriter.replaceOp(viewOp, {targetMemRef});
// Maps a memref.atomic_rmw kind to the LLVM atomicrmw binary operation that
// implements it directly. Returns std::nullopt for kinds without such a
// counterpart, so the caller can decline to lower the op.
1891static std::optional<LLVM::AtomicBinOp>
1892matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1893 switch (atomicOp.getKind()) {
1894 case arith::AtomicRMWKind::addf:
1895 return LLVM::AtomicBinOp::fadd;
1896 case arith::AtomicRMWKind::addi:
1897 return LLVM::AtomicBinOp::add;
1898 case arith::AtomicRMWKind::assign:
1899 return LLVM::AtomicBinOp::xchg;
1900 case arith::AtomicRMWKind::maximumf:
// maximumf is lowered to fmaximum rather than fmax. The debug message flags
// this change because it can produce more NaNs. maxnumf keeps fmax.
1902 LDBG() <<
"the lowering of memref.atomicrmw maximumf changed "
1903 "from fmax to fmaximum, expect more NaNs";
1904 return LLVM::AtomicBinOp::fmaximum;
1905 case arith::AtomicRMWKind::maxnumf:
1906 return LLVM::AtomicBinOp::fmax;
1907 case arith::AtomicRMWKind::maxs:
1908 return LLVM::AtomicBinOp::max;
1909 case arith::AtomicRMWKind::maxu:
1910 return LLVM::AtomicBinOp::umax;
1911 case arith::AtomicRMWKind::minimumf:
// Same change for the minimum: minimumf -> fminimum, minnumf -> fmin.
1913 LDBG() <<
"the lowering of memref.atomicrmw minimum changed "
1914 "from fmin to fminimum, expect more NaNs";
1915 return LLVM::AtomicBinOp::fminimum;
1916 case arith::AtomicRMWKind::minnumf:
1917 return LLVM::AtomicBinOp::fmin;
1918 case arith::AtomicRMWKind::mins:
1919 return LLVM::AtomicBinOp::min;
1920 case arith::AtomicRMWKind::minu:
1921 return LLVM::AtomicBinOp::umin;
1922 case arith::AtomicRMWKind::ori:
1923 return LLVM::AtomicBinOp::_or;
1924 case arith::AtomicRMWKind::xori:
1925 return LLVM::AtomicBinOp::_xor;
1926 case arith::AtomicRMWKind::andi:
1927 return LLVM::AtomicBinOp::_and;
// Kinds not listed above have no single-instruction LLVM equivalent.
1929 return std::nullopt;
// Not reached as long as every switch path returns.
1931 llvm_unreachable(
"Invalid AtomicRMWKind");
// Lowers memref.atomic_rmw to a single llvm.atomicrmw with acq_rel ordering.
// The operation is applied to the strided element address computed from the
// memref descriptor and indices. Only kinds that matchSimpleAtomicOp()
// recognizes, on memrefs with a strided layout, are supported.
1934struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1938 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1939 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): `*maybeKind` is dereferenced below; confirm that an empty
// optional is rejected before that point.
1940 auto maybeKind = matchSimpleAtomicOp(atomicOp);
// Element addressing requires a strided layout.
1943 auto memRefType = atomicOp.getMemRefType();
1946 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1949 getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1950 adaptor.getMemref(), adaptor.getIndices());
1951 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1952 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1953 LLVM::AtomicOrdering::acq_rel);
// Lowers memref.extract_aligned_pointer_as_index. The aligned pointer is read
// from the source descriptor, either directly from a ranked descriptor or
// (presumably for unranked sources) through the underlying descriptor pointer.
// It is then converted to the converter's index type with llvm.ptrtoint.
1959class ConvertExtractAlignedPointerAsIndex
1966 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1968 ConversionPatternRewriter &rewriter)
const override {
1974 alignedPtr = desc.
alignedPtr(rewriter, extractOp->getLoc());
1976 auto elementPtrTy = LLVM::LLVMPointerType::get(
1983 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
// Expose the pointer value as an integer of index type.
1987 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1988 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
// Lowers memref.extract_strided_metadata. The op's results are produced in
// this order:
//   base buffer descriptor, offset, sizes[0..rank), strides[0..rank)
// which is why 2 + 2 * rank values are reserved.
1995class ExtractStridedMetadataOpLowering
2002 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
2004 ConversionPatternRewriter &rewriter)
const override {
2011 Location loc = extractStridedMetadataOp.getLoc();
2012 Value source = extractStridedMetadataOp.getSource();
2014 auto sourceMemRefType = cast<MemRefType>(source.
getType());
2015 int64_t rank = sourceMemRefType.getRank();
2017 results.reserve(2 + rank * 2);
// The base buffer result is a descriptor of the base-buffer type, built from
// the source's allocated and aligned pointers.
2023 rewriter, loc, *getTypeConverter(),
2024 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
2025 baseBuffer, alignedBuffer);
2026 results.push_back((
Value)dstMemRef);
2029 results.push_back(sourceMemRef.
offset(rewriter, loc));
2032 for (
unsigned i = 0; i < rank; ++i)
2033 results.push_back(sourceMemRef.
size(rewriter, loc, i));
2035 for (
unsigned i = 0; i < rank; ++i)
2036 results.push_back(sourceMemRef.
stride(rewriter, loc, i));
2038 rewriter.replaceOp(extractStridedMetadataOp, results);
// Patterns constructed from the type converter alone.
2051 AllocaScopeOpLowering,
2052 AssumeAlignmentOpLowering,
2053 AtomicRMWOpLowering,
2054 ConvertExtractAlignedPointerAsIndex,
2056 DistinctObjectsOpLowering,
2057 ExtractStridedMetadataOpLowering,
2058 GenericAtomicRMWOpLowering,
2059 GetGlobalMemrefOpLowering,
2061 MemRefCastOpLowering,
2062 MemRefReinterpretCastOpLowering,
2063 MemRefReshapeOpLowering,
2064 MemorySpaceCastOpLowering,
2067 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2068 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2072 ViewOpLowering>(converter);
// Patterns that take extra constructor arguments beyond the converter. The
// allocation patterns, AlignedAllocOpLowering or AllocOpLowering, are each
// paired with DeallocOpLowering. They are presumably selected by the
// configured allocation lowering (AlignedAlloc vs Malloc) — confirm.
2074 patterns.
add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2078 patterns.
add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2081 patterns.
add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
// Pass that runs the memref-to-LLVM conversion patterns as a partial
// conversion. The lowering options come from the data layout in effect at the
// root op and from the pass options (useGenericFunctions, indexBitwidth).
2085struct FinalizeMemRefToLLVMConversionPass
2086 :
public impl::FinalizeMemRefToLLVMConversionPassBase<
2087 FinalizeMemRefToLLVMConversionPass> {
2088 using FinalizeMemRefToLLVMConversionPassBase::
2089 FinalizeMemRefToLLVMConversionPassBase;
2091 void runOnOperation()
override {
// Query the data layout at or above the root op to seed the lowering options.
2093 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2095 dataLayoutAnalysis.getAtOrAbove(op));
// Forward pass options into the lowering options.
2100 options.useGenericFunctions = useGenericFunctions;
2103 options.overrideIndexBitwidth(indexBitwidth);
2106 &dataLayoutAnalysis);
// func.func is explicitly marked legal, so this conversion does not rewrite it.
2112 target.addLegalOp<func::FuncOp>();
// Partial conversion: ops without a successful pattern may remain; a failed
// conversion fails the pass.
2113 if (failed(applyPartialConversion(op,
target, std::move(patterns))))
2114 signalPassFailure();
// ConvertToLLVMPatternInterface implementation that lets the generic
// convert-to-llvm infrastructure use these patterns. It declares the LLVM
// dialect as a dependency and contributes the conversion patterns for a given
// target and type converter.
2119struct MemRefToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
2120 MemRefToLLVMDialectInterface(Dialect *dialect)
2121 : ConvertToLLVMPatternInterface(dialect) {}
// The produced IR uses the LLVM dialect, so it must be loaded first.
2123 void loadDependentDialects(MLIRContext *context)
const final {
2124 context->loadDialect<LLVM::LLVMDialect>();
2129 void populateConvertToLLVMConversionPatterns(
2130 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2131 RewritePatternSet &patterns)
const final {
// Attach the convert-to-llvm pattern interface to `dialect` (presumably the
// memref dialect, once it is loaded).
2140 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
const LowerToLLVMOptions & getOptions() const
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.