27#include "llvm/Support/DebugLog.h"
28#include "llvm/Support/MathExtras.h"
32#define DEBUG_TYPE "memref-to-llvm"
35#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
36#include "mlir/Conversion/Passes.h.inc"
42 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
// Returns true if the given stride or offset value is statically known
// (i.e. not ShapedType::kDynamic).
46static bool isStaticStrideOrOffset(
int64_t strideOrOffset) {
47 return ShapedType::isStatic(strideOrOffset);
// NOTE(review): the source view is elided below. These three lines are the
// truncated return types of static helper declarations; judging by the call
// sites later in this file (getNotalignedAllocFn, getAlignedAllocFn,
// getFreeFn), they look up or declare the runtime allocation/free functions —
// confirm against the full source.
50static FailureOr<LLVM::LLVMFuncOp>
61static FailureOr<LLVM::LLVMFuncOp>
73static FailureOr<LLVM::LLVMFuncOp>
// Rounds `input` up to the nearest multiple of `alignment` in IR:
//   bumped = input + (alignment - 1);  result = bumped - (bumped % alignment)
// NOTE(review): the parameter list is partially elided in this view —
// presumably `Value input, Value alignment` follow `loc`; confirm against
// the full source.
89static Value createAligned(ConversionPatternRewriter &rewriter,
Location loc,
91 Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.
getType(),
92 rewriter.getIndexAttr(1));
93 Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
94 Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
95 Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
96 return LLVM::SubOp::create(rewriter, loc, bumped, mod);
// NOTE(review): interior of an elided helper (likely the element-size /
// data-layout query used by the alloc lowerings below — confirm). It takes a
// data layout from `analysis->getAbove(op)` and then branches on the memref's
// element type, with separate (elided) handling for nested ranked and
// unranked memref element types.
106 layout = &analysis->getAbove(op);
108 Type elementType = memRefType.getElementType();
109 if (
auto memRefElementType = dyn_cast<MemRefType>(elementType))
111 if (
auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
// Casts the pointer returned by an allocation function to the address space
// expected by `memRefType`: when the allocated pointer's address space
// differs from the memref's converted address space, inserts an
// LLVM::AddrSpaceCastOp.
117static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
119 MemRefType memRefType,
Type elementPtrType,
121 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.
getType());
// The type converter must map the memref's memory space to an integer
// address space; this is a hard precondition (assert), not a match failure.
122 FailureOr<unsigned> maybeMemrefAddrSpace =
124 assert(succeeded(maybeMemrefAddrSpace) &&
"unsupported address space");
125 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
126 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
127 allocatedPtr = LLVM::AddrSpaceCastOp::create(
129 LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
142 symbolTables(symbolTables) {}
// Lowers memref.alloc to a call to the non-aligned allocation function
// (malloc-style): computes the byte size from the descriptor sizes, pads the
// size by the alignment, calls the alloc function, then manually aligns the
// returned pointer via ptrtoint/createAligned/inttoptr before building the
// memref descriptor.
145 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter)
const override {
147 auto loc = op.getLoc();
148 MemRefType memRefType = op.getType();
149 if (!isConvertibleAndHasIdentityMaps(memRefType))
150 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
153 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
154 getNotalignedAllocFn(rewriter, getTypeConverter(),
156 getIndexType(), symbolTables);
157 if (failed(allocFuncOp))
167 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
168 rewriter, sizes, strides, sizeBytes,
true);
// Over-allocate by `alignment` bytes so the aligned pointer still has the
// full requested size available behind it.
170 Value alignment = getAlignment(rewriter, loc, op);
173 sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
177 Type elementPtrType = this->getElementPtrType(memRefType);
178 assert(elementPtrType &&
"could not compute element ptr type");
180 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
183 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
184 elementPtrType, *getTypeConverter());
185 Value alignedPtr = allocatedPtr;
// Manual alignment: round the raw pointer (as an integer) up to the
// requested alignment, then convert back to a pointer.
189 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
191 createAligned(rewriter, loc, allocatedInt, alignment);
193 LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
197 auto memRefDescriptor = this->createMemRefDescriptor(
198 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
201 rewriter.replaceOp(op, {memRefDescriptor});
// Computes the alignment value to use for the allocation: the op's explicit
// alignment attribute when present; otherwise, for non-scalar element types,
// falls back to the element size in bytes. (Parts of this helper are elided
// in this view.)
206 template <
typename OpType>
207 Value getAlignment(ConversionPatternRewriter &rewriter,
Location loc,
209 MemRefType memRefType = op.
getType();
211 if (
auto alignmentAttr = op.getAlignment()) {
212 Type indexType = getIndexType();
215 }
else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
220 alignment =
getSizeInBytes(loc, memRefType.getElementType(), rewriter);
234 symbolTables(symbolTables) {}
// Lowers memref.alloc to a call to an aligned allocation function
// (aligned_alloc-style): the pointer comes back already aligned, so the
// descriptor uses the same pointer for both allocated and aligned slots.
237 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter)
const override {
239 auto loc = op.getLoc();
240 MemRefType memRefType = op.getType();
241 if (!isConvertibleAndHasIdentityMaps(memRefType))
242 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
245 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
246 getAlignedAllocFn(rewriter, getTypeConverter(),
248 getIndexType(), symbolTables);
249 if (failed(allocFuncOp))
259 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
260 rewriter, sizes, strides, sizeBytes, !
false);
262 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
264 Value allocAlignment =
// aligned_alloc requires the size to be a multiple of the alignment; round
// it up when that cannot be proven statically.
269 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
270 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
272 Type elementPtrType = this->getElementPtrType(memRefType);
274 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
278 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
279 elementPtrType, *getTypeConverter());
282 auto memRefDescriptor = this->createMemRefDescriptor(
283 loc, memRefType,
ptr,
ptr, sizes, strides, rewriter);
286 rewriter.replaceOp(op, {memRefDescriptor});
// Minimum alignment passed to the aligned allocation function.
291 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
// Returns the allocation alignment: the op's explicit alignment if present,
// otherwise max(kMinAlignedAllocAlignment, next power of two of the element
// size in bytes). (Parts elided in this view.)
298 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
300 if (std::optional<uint64_t> alignment = op.getAlignment())
306 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
307 getTypeConverter(), op.getType(), op, defaultLayout);
308 return std::max(kMinAlignedAllocAlignment,
309 llvm::PowerOf2Ceil(eltSizeBytes));
// Returns true if the memref's total static size in bytes is provably a
// multiple of `factor`. Dynamic dimensions are skipped (elided branch), so
// this is a conservative static check on the known dimensions.
314 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
Operation *op,
316 uint64_t sizeDivisor =
317 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
318 for (
unsigned i = 0, e = type.getRank(); i < e; i++) {
319 if (type.isDynamicDim(i))
321 sizeDivisor = sizeDivisor * type.getDimSize(i);
323 return sizeDivisor % factor == 0;
// Lowers memref.alloca to llvm.alloca: computes the element count from the
// descriptor sizes and allocates on the stack in the memref's address space.
// The alloca result is used for both the allocated and aligned pointers.
338 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override {
340 auto loc = op.getLoc();
341 MemRefType memRefType = op.getType();
342 if (!isConvertibleAndHasIdentityMaps(memRefType))
343 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
352 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
353 rewriter, sizes, strides, size, !
true);
358 typeConverter->convertType(op.getType().getElementType());
359 FailureOr<unsigned> maybeAddressSpace =
361 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
362 unsigned addrSpace = *maybeAddressSpace;
363 auto elementPtrType =
364 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
// Alignment 0 lets LLVM pick the natural alignment when the op carries no
// explicit alignment attribute.
366 auto allocatedElementPtr =
367 LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
368 op.getAlignment().value_or(0));
371 auto memRefDescriptor = this->createMemRefDescriptor(
372 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
376 rewriter.replaceOp(op, {memRefDescriptor});
// Lowers memref.alloca_scope by inlining its region between an
// llvm.intr.stacksave / llvm.intr.stackrestore pair, using explicit blocks
// and branches to stitch the region into the surrounding CFG.
381struct AllocaScopeOpLowering
386 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter)
const override {
389 Location loc = allocaScopeOp.getLoc();
// Split the current block; everything after the op continues in
// `remainingOpsBlock`. A dedicated continue block with block arguments is
// only needed when the op produces results.
393 auto *currentBlock = rewriter.getInsertionBlock();
394 auto *remainingOpsBlock =
395 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
396 Block *continueBlock;
397 if (allocaScopeOp.getNumResults() == 0) {
398 continueBlock = remainingOpsBlock;
400 continueBlock = rewriter.createBlock(
401 remainingOpsBlock, allocaScopeOp.getResultTypes(),
403 allocaScopeOp.getLoc()));
404 LLVM::BrOp::create(rewriter, loc,
ValueRange(), remainingOpsBlock);
// Inline the body region before the continue block.
408 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
409 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
410 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
// Save the stack pointer before entering the body.
413 rewriter.setInsertionPointToEnd(currentBlock);
414 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
415 LLVM::BrOp::create(rewriter, loc,
ValueRange(), beforeBody);
// Replace the scope's terminator with a branch to the continue block, and
// restore the stack pointer just before leaving the scope.
419 rewriter.setInsertionPointToEnd(afterBody);
421 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
422 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
423 returnOp, returnOp.getResults(), continueBlock);
426 rewriter.setInsertionPoint(branchOp);
427 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
430 rewriter.replaceOp(allocaScopeOp, continueBlock->
getArguments());
// Lowers memref.assume_alignment. From the visible lines: it computes the
// strided element pointer and builds a true constant plus an alignment
// constant — presumably feeding an llvm.intr.assume with an alignment
// bundle (the assume itself is elided in this view; confirm).
436struct AssumeAlignmentOpLowering
444 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter)
const override {
447 unsigned alignment = op.getAlignment();
448 auto loc = op.getLoc();
450 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
451 Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType,
memref,
458 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(
true));
459 Value alignmentConst =
// The op is replaced by its (converted) memref operand — it only conveys
// an assumption, it produces no new value.
463 rewriter.replaceOp(op,
memref);
// Lowers memref.distinct_objects: for every pair of operand memrefs it emits
// an llvm.intr.assume expressing that their pointers are distinct, then
// forwards the operands unchanged.
468struct DistinctObjectsOpLowering
476 matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter)
const override {
// With zero or one operand there is nothing to compare — forward as-is.
479 if (operands.size() <= 1) {
481 rewriter.replaceOp(op, operands);
487 for (
auto [origOperand, newOperand] :
488 llvm::zip_equal(op.getOperands(), operands)) {
489 auto memrefType = cast<MemRefType>(origOperand.getType());
497 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
// Pairwise loop over all distinct (i, j) pointer pairs.
499 for (
auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500 for (
auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501 Value ptr1 = ptrs[i];
503 LLVM::AssumeOp::create(rewriter, loc, cond,
508 rewriter.replaceOp(op, operands);
524 symbolTables(symbolTables) {}
// Lowers memref.dealloc to a call to the runtime free function. For
// unranked memrefs the pointer is first extracted from the underlying
// descriptor (the ranked path is elided in this view).
527 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter)
const override {
530 FailureOr<LLVM::LLVMFuncOp> freeFunc =
531 getFreeFn(rewriter, getTypeConverter(),
533 if (failed(freeFunc))
536 if (
auto unrankedTy =
537 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
538 auto elementPtrTy = LLVM::LLVMPointerType::get(
539 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt())
541 rewriter, op.getLoc(),
549 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
// Lowers memref.dim, dispatching on whether the source is ranked or
// unranked.
561 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter)
const override {
563 Type operandType = dimOp.getSource().getType();
564 if (isa<UnrankedMemRefType>(operandType)) {
565 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
566 operandType, dimOp, adaptor.getOperands(), rewriter);
567 if (failed(extractedSize))
569 rewriter.replaceOp(dimOp, {*extractedSize});
572 if (isa<MemRefType>(operandType)) {
574 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
575 adaptor.getOperands(), rewriter)});
578 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
// For an unranked memref, loads the requested dimension size directly out
// of the underlying (heap/stack) descriptor memory: GEP past the two
// pointers + offset to the sizes array, index by dim+1 (visible as
// `idxPlusOne`), then load.
583 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
585 ConversionPatternRewriter &rewriter)
const {
588 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
589 auto scalarMemRefType =
590 MemRefType::get({}, unrankedMemRefType.getElementType());
591 FailureOr<unsigned> maybeAddressSpace =
592 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
593 if (failed(maybeAddressSpace)) {
594 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
598 unsigned addressSpace = *maybeAddressSpace;
606 Type elementType = typeConverter->convertType(scalarMemRefType);
610 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
612 LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
617 Value idxPlusOne = LLVM::AddOp::create(
621 Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
622 getTypeConverter()->getIndexType(),
623 offsetPtr, idxPlusOne);
624 return LLVM::LoadOp::create(rewriter, loc,
625 getTypeConverter()->getIndexType(), sizePtr)
// Returns the dim index when it is statically known: either from the op's
// constant-index form or by looking through an LLVM::ConstantOp operand.
629 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
630 if (
auto idx = dimOp.getConstantIndex())
633 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
634 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
// For a ranked memref: constant in-bounds indices read either a static dim
// size or the dynamic size slot of the descriptor; otherwise the size is
// selected dynamically from the descriptor by index.
639 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
641 ConversionPatternRewriter &rewriter)
const {
645 MemRefType memRefType = cast<MemRefType>(operandType);
646 Type indexType = getIndexType();
647 if (std::optional<int64_t>
index = getConstantDimIndex(dimOp)) {
649 if (i >= 0 && i < memRefType.getRank()) {
650 if (memRefType.isDynamicDim(i)) {
653 return descriptor.
size(rewriter, loc, i);
656 int64_t dimSize = memRefType.getDimSize(i);
661 int64_t rank = memRefType.getRank();
663 return memrefDescriptor.
size(rewriter, loc,
index, rank);
// NOTE(review): fragment of the common CRTP base for load/store-style
// lowerings; only the `Base` alias is visible here.
670template <
typename Derived>
674 using Base = LoadStoreOpLowering<Derived>;
// Lowers memref.generic_atomic_rmw to a compare-exchange loop:
//   init:  load the current value, branch to loop
//   loop:  clone the op's body on the loop-carried value, cmpxchg the result
//          back; on failure loop again with the freshly loaded value
//   end:   the final loaded value replaces the op's result.
704struct GenericAtomicRMWOpLowering
705 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
709 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 auto loc = atomicOp.getLoc();
712 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
// The loop block carries the current value as a block argument.
715 auto *initBlock = rewriter.getInsertionBlock();
716 auto *loopBlock = rewriter.splitBlock(initBlock,
Block::iterator(atomicOp));
717 loopBlock->addArgument(valueType, loc);
723 rewriter.setInsertionPointToEnd(initBlock);
724 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
725 auto dataPtr = getStridedElementPtr(
726 rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
727 Value init = LLVM::LoadOp::create(
728 rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
730 LLVM::BrOp::create(rewriter, loc, init, loopBlock);
// Clone the body, mapping the op's current-value block argument to the
// loop-carried value.
733 rewriter.setInsertionPointToStart(loopBlock);
736 auto loopArgument = loopBlock->getArgument(0);
738 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
742 mapping.
map(nestedOp.getResults(),
clone->getResults());
// acq_rel on success / monotonic on failure matches the read-modify-write
// semantics of the loop.
748 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
749 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
751 LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
752 result, successOrdering, failureOrdering);
754 Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
755 Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);
758 loopBlock, newLoaded);
761 rewriter.setInsertionPointToEnd(endBlock);
764 rewriter.replaceOp(atomicOp, {newLoaded});
// Converts a global memref's type to the corresponding LLVM type: the
// converted element type wrapped in nested LLVM array types, innermost
// dimension first (hence the reverse iteration over the shape).
772convertGlobalMemrefTypeToLLVM(MemRefType type,
779 Type elementType = typeConverter.convertType(type.getElementType());
780 Type arrayTy = elementType;
782 for (
int64_t dim : llvm::reverse(type.getShape()))
783 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
796 symbolTables(symbolTables) {}
// Lowers memref.global to llvm.mlir.global, translating linkage
// (public -> External, otherwise Private), the initial value, alignment and
// address space. Uninitialized non-external globals get an initializer
// region that returns undef.
799 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter)
const override {
801 MemRefType type = global.getType();
802 if (!isConvertibleAndHasIdentityMaps(type))
805 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
807 LLVM::Linkage linkage =
808 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
809 bool isExternal = global.isExternal();
810 bool isUninitialized = global.isUninitialized();
813 if (!isExternal && !isUninitialized) {
814 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
815 initialValue = elementsAttr;
// Rank-0 memrefs store a scalar: unwrap the splat value.
819 if (type.getRank() == 0)
820 initialValue = elementsAttr.getSplatValue<
Attribute>();
823 uint64_t alignment = global.getAlignment().value_or(0);
824 FailureOr<unsigned> addressSpace =
825 getTypeConverter()->getMemRefAddressSpace(type);
826 if (failed(addressSpace))
827 return global.emitOpError(
828 "memory space cannot be converted to an integer address space");
// Keep the symbol table consistent across the replacement: remove the old
// symbol, create the new global, then re-insert it.
836 symbolTable->remove(global);
840 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
841 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
842 initialValue, alignment, *addressSpace);
846 symbolTable->
insert(newGlobal, rewriter.getInsertionPoint());
848 if (!isExternal && isUninitialized) {
849 rewriter.createBlock(&newGlobal.getInitializerRegion());
851 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
852 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
// Lowers memref.get_global: takes the address of the LLVM global, GEPs to
// the first element, and builds a memref descriptor. The "allocated"
// pointer slot is filled with a poison-ish 0xdeadbeef constant since a
// global must never be freed.
861struct GetGlobalMemrefOpLowering
868 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
869 ConversionPatternRewriter &rewriter)
const override {
870 auto loc = op.getLoc();
871 MemRefType memRefType = op.getType();
872 if (!isConvertibleAndHasIdentityMaps(memRefType))
873 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
882 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
883 rewriter, sizes, strides, sizeBytes, !
false);
885 MemRefType type = cast<MemRefType>(op.getResult().getType());
889 FailureOr<unsigned> maybeAddressSpace =
890 getTypeConverter()->getMemRefAddressSpace(type);
891 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
892 unsigned memSpace = *maybeAddressSpace;
894 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
895 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
897 LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
902 LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
// Sentinel "allocated" pointer: deallocating a global is UB, so a
// recognizable garbage value is stored instead of a real allocation.
908 auto intPtrType = getIntPtrType(memSpace);
909 Value deadBeefConst =
912 LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
917 auto memRefDescriptor = this->createMemRefDescriptor(
918 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
921 rewriter.replaceOp(op, {memRefDescriptor});
// Lowers memref.load to llvm.load at the strided element pointer, carrying
// over the alignment and nontemporal attributes.
928struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
932 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
933 ConversionPatternRewriter &rewriter)
const override {
934 auto type = loadOp.getMemRefType();
939 Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
// `false` here is the volatile flag (per the llvm.load builder signature
// pattern used for the store below — confirm against LLVM dialect docs).
942 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
943 loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
944 loadOp.getAlignment().value_or(0),
false, loadOp.getNontemporal());
// Lowers memref.store to llvm.store at the strided element pointer,
// carrying over the alignment and nontemporal attributes (the `false`
// argument is the volatile flag).
951struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
955 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
956 ConversionPatternRewriter &rewriter)
const override {
957 auto type = op.getMemRefType();
963 getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
965 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
966 op.getAlignment().value_or(0),
967 false, op.getNontemporal());
// Lowers memref.prefetch to the llvm.prefetch intrinsic, forwarding the
// read/write hint, locality hint and data/instruction-cache flags.
974struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
978 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
979 ConversionPatternRewriter &rewriter)
const override {
980 auto type = prefetchOp.getMemRefType();
981 auto loc = prefetchOp.getLoc();
983 Value dataPtr = getStridedElementPtr(
984 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
987 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
988 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
990 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
991 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
992 localityHint, isData);
// Lowers memref.rank: for unranked memrefs the rank is read from the
// descriptor at runtime; for ranked memrefs it folds to a constant.
1001 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
1002 ConversionPatternRewriter &rewriter)
const override {
1004 Type operandType = op.getMemref().getType();
1005 if (isa<UnrankedMemRefType>(operandType)) {
1007 rewriter.replaceOp(op, {desc.
rank(rewriter, loc)});
1010 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
1011 Type indexType = getIndexType();
1012 rewriter.replaceOp(op,
1014 rankedMemRefType.getRank())});
// Lowers memref.cast. ranked -> ranked casts are no-ops at the LLVM level
// (descriptor layouts must already match); ranked -> unranked packs the
// ranked descriptor behind a pointer plus a rank; unranked -> ranked loads
// the ranked descriptor back out. unranked -> unranked is unsupported.
1025 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
1026 ConversionPatternRewriter &rewriter)
const override {
1027 Type srcType = memRefCastOp.getOperand().getType();
1028 Type dstType = memRefCastOp.getType();
// Bail out when the two ranked types convert to different LLVM struct
// types — the descriptor cannot simply be reused then.
1035 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
1036 if (typeConverter->convertType(srcType) !=
1037 typeConverter->convertType(dstType))
1041 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
1044 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
1045 auto loc = memRefCastOp.getLoc();
1048 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
1049 rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
1053 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
// ranked -> unranked: spill the ranked descriptor to stack memory and
// build an unranked descriptor {rank, ptr}.
1058 auto srcMemRefType = cast<MemRefType>(srcType);
1059 int64_t rank = srcMemRefType.getRank();
1061 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
1062 loc, adaptor.getSource(), rewriter);
1065 auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1066 rewriter.getIndexAttr(rank));
1071 memRefDesc.
setRank(rewriter, loc, rankVal);
1074 rewriter.replaceOp(memRefCastOp, (
Value)memRefDesc);
1076 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
// unranked -> ranked: reload the ranked descriptor from the pointer held
// by the unranked descriptor.
1085 auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType,
ptr);
1086 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
1088 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
1108 symbolTables(symbolTables) {}
// Fast path for memref.copy between contiguous memrefs: compute the total
// byte size (product of sizes * element size), GEP both base pointers by
// their offsets, and emit a single llvm.intr.memcpy.
1111 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1112 ConversionPatternRewriter &rewriter)
const {
1113 auto loc = op.getLoc();
1114 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1119 Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1120 rewriter.getIndexAttr(1));
1121 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
1122 auto size = srcDesc.
size(rewriter, loc, pos);
1123 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
1127 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1130 LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
1132 Type elementType = typeConverter->convertType(srcType.getElementType());
1136 Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.
getType(),
1137 elementType, srcBasePtr, srcOffset);
1140 Value targetOffset = targetDesc.
offset(rewriter, loc);
1142 LLVM::GEPOp::create(rewriter, loc, targetBasePtr.
getType(), elementType,
1143 targetBasePtr, targetOffset);
1144 LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
1146 rewriter.eraseOp(op);
// General path: promote both sides to unranked descriptors on the stack
// (between stacksave/stackrestore so the allocas don't leak) and call the
// runtime memref-copy function.
1152 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1153 ConversionPatternRewriter &rewriter)
const {
1154 auto loc = op.getLoc();
1155 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1156 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
// Wraps a ranked descriptor into an unranked one ({rank, ptr-to-desc}).
1159 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
1160 auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1162 auto *typeConverter = getTypeConverter();
1167 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1169 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank,
ptr});
// Save the stack: the promoted descriptors below are stack-allocated.
1173 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
1175 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1176 Value unrankedSource =
1177 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1178 : adaptor.getSource();
1179 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1180 Value unrankedTarget =
1181 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1182 : adaptor.getTarget();
// Spill each unranked descriptor to an alloca so it can be passed by
// pointer to the runtime function.
1185 auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1186 rewriter.getIndexAttr(1));
1187 auto promote = [&](
Value desc) {
1188 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1190 LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
1191 LLVM::StoreOp::create(rewriter, loc, desc, allocated);
1195 auto sourcePtr = promote(unrankedSource);
1196 auto targetPtr = promote(unrankedTarget);
1200 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1202 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1203 sourcePtr.getType(), symbolTables);
1206 LLVM::CallOp::create(rewriter, loc, copyFn.value(),
1210 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
1212 rewriter.eraseOp(op);
// Entry point: picks the memcpy fast path when both sides are contiguous,
// else falls back to the runtime call.
1218 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1219 ConversionPatternRewriter &rewriter)
const override {
1220 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1221 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1224 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1228 return memrefType &&
1229 (memrefType.getLayout().isIdentity() ||
1230 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1234 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1235 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1237 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
// Lowers memref.memory_space_cast. Ranked results: addrspacecast the two
// pointers in the descriptor and rebuild it. Unranked results: allocate a
// new underlying descriptor, addrspacecast the allocated/aligned pointers
// into it, then memcpy the remaining offset/sizes/strides words over.
1241struct MemorySpaceCastOpLowering
1247 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1248 ConversionPatternRewriter &rewriter)
const override {
1251 Type resultType = op.getDest().getType();
1252 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
// Ranked case: only the two leading pointer fields change address space.
1253 auto resultDescType =
1254 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
1255 Type newPtrType = resultDescType.getBody()[0];
1261 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
1263 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
1265 resultTypeR, descVals);
1266 rewriter.replaceOp(op,
result);
1269 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
// Unranked case: both source and result memory spaces must convert to
// integer address spaces.
1272 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1273 FailureOr<unsigned> maybeSourceAddrSpace =
1274 getTypeConverter()->getMemRefAddressSpace(sourceType);
1275 if (failed(maybeSourceAddrSpace))
1276 return rewriter.notifyMatchFailure(loc,
1277 "non-integer source address space");
1278 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1279 FailureOr<unsigned> maybeResultAddrSpace =
1280 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1281 if (failed(maybeResultAddrSpace))
1282 return rewriter.notifyMatchFailure(loc,
1283 "non-integer result address space");
1284 unsigned resultAddrSpace = *maybeResultAddrSpace;
1287 Value rank = sourceDesc.
rank(rewriter, loc);
// Build the result unranked descriptor and a fresh stack allocation for
// its underlying (rank-erased) descriptor storage.
1292 rewriter, loc, typeConverter->convertType(resultTypeU));
1293 result.setRank(rewriter, loc, rank);
1295 rewriter, loc, *getTypeConverter(),
result, resultAddrSpace);
1296 Value resultUnderlyingDesc =
1297 LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
1298 rewriter.getI8Type(), resultUnderlyingSize);
1299 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
// Cast the allocated and aligned pointers into the result address space
// and store them into the new underlying descriptor.
1302 auto sourceElemPtrType =
1303 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
1304 auto resultElemPtrType =
1305 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
1308 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1310 sourceDesc.
alignedPtr(rewriter, loc, *getTypeConverter(),
1311 sourceUnderlyingDesc, sourceElemPtrType);
1312 allocatedPtr = LLVM::AddrSpaceCastOp::create(
1313 rewriter, loc, resultElemPtrType, allocatedPtr);
1314 alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
1315 resultElemPtrType, alignedPtr);
1317 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1318 resultElemPtrType, allocatedPtr);
1319 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1320 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
// Copy the offset/sizes/strides tail: skip the two leading pointers
// (2 * pointer width in bytes) and memcpy the rest verbatim.
1323 Value sourceIndexVals =
1324 sourceDesc.
offsetBasePtr(rewriter, loc, *getTypeConverter(),
1325 sourceUnderlyingDesc, sourceElemPtrType);
1326 Value resultIndexVals =
1327 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1328 resultUnderlyingDesc, resultElemPtrType);
1331 2 * llvm::divideCeil(
1332 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1333 Value bytesToSkipConst = LLVM::ConstantOp::create(
1334 rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1336 LLVM::SubOp::create(rewriter, loc, getIndexType(),
1337 resultUnderlyingSize, bytesToSkipConst);
1338 LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
1344 return rewriter.notifyMatchFailure(loc,
"unexpected memref type");
// Extracts the allocated pointer, aligned pointer, and (optionally, when
// `offset` is non-null) the offset from a converted memref value — reading
// directly from the struct descriptor for ranked memrefs, or through the
// underlying descriptor pointer for unranked ones.
1351static void extractPointersAndOffset(
Location loc,
1352 ConversionPatternRewriter &rewriter,
1354 Value originalOperand,
1355 Value convertedOperand,
1357 Value *offset =
nullptr) {
1359 if (isa<MemRefType>(operandType)) {
1362 *alignedPtr = desc.
alignedPtr(rewriter, loc);
1363 if (offset !=
nullptr)
1364 *offset = desc.
offset(rewriter, loc);
// Unranked path: chase the underlying descriptor pointer in the memref's
// address space.
1370 cast<UnrankedMemRefType>(operandType));
1371 auto elementPtrType =
1372 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1380 rewriter, loc, underlyingDescPtr, elementPtrType);
1382 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1383 if (offset !=
nullptr) {
1385 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
// Lowers memref.reinterpret_cast by building a fresh descriptor: pointers
// come from the source (via extractPointersAndOffset); offset, sizes and
// strides come from the op's static values or its dynamic operands.
1389struct MemRefReinterpretCastOpLowering
1395 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter)
const override {
1397 Type srcType = castOp.getSource().getType();
1400 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1401 adaptor, &descriptor)))
1403 rewriter.replaceOp(castOp, {descriptor});
// Builds the target descriptor; fails when the target memref type does not
// convert to an LLVM struct type.
1408 LogicalResult convertSourceMemRefToDescriptor(
1409 ConversionPatternRewriter &rewriter,
Type srcType,
1410 memref::ReinterpretCastOp castOp,
1411 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1412 MemRefType targetMemRefType =
1413 cast<MemRefType>(castOp.getResult().getType());
1414 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1415 typeConverter->convertType(targetMemRefType));
1416 if (!llvmTargetDescriptorTy)
1424 Value allocatedPtr, alignedPtr;
1425 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1426 castOp.getSource(), adaptor.getSource(),
1427 &allocatedPtr, &alignedPtr);
1428 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1429 desc.setAlignedPtr(rewriter, loc, alignedPtr);
// Offset: dynamic operand if present, else the static value.
1432 if (castOp.isDynamicOffset(0))
1433 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1435 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
// Sizes and strides: dynamic operands are consumed in order via the two
// running counters.
1438 unsigned dynSizeId = 0;
1439 unsigned dynStrideId = 0;
1440 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1441 if (castOp.isDynamicSize(i))
1442 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1444 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1446 if (castOp.isDynamicStride(i))
1447 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1449 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1456struct MemRefReshapeOpLowering
1461 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter)
const override {
1463 Type srcType = reshapeOp.getSource().getType();
1466 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1467 adaptor, &descriptor)))
1469 rewriter.replaceOp(reshapeOp, {descriptor});
1475 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1476 Type srcType, memref::ReshapeOp reshapeOp,
1477 memref::ReshapeOp::Adaptor adaptor,
1478 Value *descriptor)
const {
1479 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1480 if (shapeMemRefType.hasStaticShape()) {
1481 MemRefType targetMemRefType =
1482 cast<MemRefType>(reshapeOp.getResult().getType());
1483 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1484 typeConverter->convertType(targetMemRefType));
1485 if (!llvmTargetDescriptorTy)
1494 Value allocatedPtr, alignedPtr;
1495 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1496 reshapeOp.getSource(), adaptor.getSource(),
1497 &allocatedPtr, &alignedPtr);
1498 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1499 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1504 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1505 return rewriter.notifyMatchFailure(
1506 reshapeOp,
"failed to get stride and offset exprs");
1508 if (!isStaticStrideOrOffset(offset))
1509 return rewriter.notifyMatchFailure(reshapeOp,
1510 "dynamic offset is unsupported");
1512 desc.setConstantOffset(rewriter, loc, offset);
1514 assert(targetMemRefType.getLayout().isIdentity() &&
1515 "Identity layout map is a precondition of a valid reshape op");
1517 Type indexType = getIndexType();
1518 Value stride =
nullptr;
1519 int64_t targetRank = targetMemRefType.getRank();
1520 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1521 if (ShapedType::isStatic(strides[i])) {
1526 }
else if (!stride) {
1536 if (!targetMemRefType.isDynamicDim(i)) {
1538 targetMemRefType.getDimSize(i));
1540 Value shapeOp = reshapeOp.getShape();
1542 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp,
index);
1543 Type indexType = getIndexType();
1544 if (dimSize.
getType() != indexType)
1545 dimSize = typeConverter->materializeTargetConversion(
1546 rewriter, loc, indexType, dimSize);
1547 assert(dimSize &&
"Invalid memref element type");
1550 desc.setSize(rewriter, loc, i, dimSize);
1551 desc.setStride(rewriter, loc, i, stride);
1554 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1564 Value resultRank = shapeDesc.
size(rewriter, loc, 0);
1567 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1568 unsigned addressSpace =
1569 *getTypeConverter()->getMemRefAddressSpace(targetType);
1574 rewriter, loc, typeConverter->convertType(targetType));
1575 targetDesc.setRank(rewriter, loc, resultRank);
1577 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1578 Value underlyingDescPtr = LLVM::AllocaOp::create(
1579 rewriter, loc, getPtrType(), IntegerType::get(
getContext(), 8),
1581 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1584 Value allocatedPtr, alignedPtr, offset;
1585 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1586 reshapeOp.getSource(), adaptor.getSource(),
1587 &allocatedPtr, &alignedPtr, &offset);
1590 auto elementPtrType =
1591 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1594 elementPtrType, allocatedPtr);
1596 underlyingDescPtr, elementPtrType,
1599 underlyingDescPtr, elementPtrType,
1605 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1607 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1610 Value resultRankMinusOne =
1611 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1613 Block *initBlock = rewriter.getInsertionBlock();
1614 Type indexType = getTypeConverter()->getIndexType();
1615 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1617 Block *condBlock = rewriter.createBlock(initBlock->
getParent(), {},
1618 {indexType, indexType}, {loc, loc});
1621 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1622 rewriter.mergeBlocks(remainingBlock, condBlock,
ValueRange());
1624 rewriter.setInsertionPointToEnd(initBlock);
1625 LLVM::BrOp::create(rewriter, loc,
1626 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1627 rewriter.setInsertionPointToStart(condBlock);
1632 Value pred = LLVM::ICmpOp::create(
1633 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1634 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1637 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1638 rewriter.setInsertionPointToStart(bodyBlock);
1641 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1642 Value sizeLoadGep = LLVM::GEPOp::create(
1643 rewriter, loc, llvmIndexPtrType,
1644 typeConverter->convertType(shapeMemRefType.getElementType()),
1645 shapeOperandPtr, indexArg);
1646 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1648 targetSizesBase, indexArg, size);
1652 targetStridesBase, indexArg, strideArg);
1653 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1656 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1657 LLVM::BrOp::create(rewriter, loc,
ValueRange({decrement, nextStride}),
1661 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1664 rewriter.setInsertionPointToEnd(condBlock);
1665 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock,
ValueRange(),
1669 rewriter.setInsertionPointToStart(remainder);
1671 *descriptor = targetDesc;
1678template <
typename ReshapeOp>
1679class ReassociatingReshapeOpConversion
1683 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1686 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1687 ConversionPatternRewriter &rewriter)
const override {
1688 return rewriter.notifyMatchFailure(
1690 "reassociation operations should have been expanded beforehand");
1700 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1701 ConversionPatternRewriter &rewriter)
const override {
1702 return rewriter.notifyMatchFailure(
1703 subViewOp,
"subview operations should have been expanded beforehand");
1719 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1720 ConversionPatternRewriter &rewriter)
const override {
1721 auto loc = transposeOp.getLoc();
1725 if (transposeOp.getPermutation().isIdentity())
1726 return rewriter.replaceOp(transposeOp, {viewMemRef}),
success();
1730 typeConverter->convertType(transposeOp.getIn().getType()));
1734 targetMemRef.setAllocatedPtr(rewriter, loc,
1736 targetMemRef.setAlignedPtr(rewriter, loc,
1740 targetMemRef.setOffset(rewriter, loc, viewMemRef.
offset(rewriter, loc));
1746 for (
const auto &en :
1747 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1748 int targetPos = en.index();
1749 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1750 targetMemRef.setSize(rewriter, loc, targetPos,
1751 viewMemRef.
size(rewriter, loc, sourcePos));
1752 targetMemRef.setStride(rewriter, loc, targetPos,
1753 viewMemRef.
stride(rewriter, loc, sourcePos));
1756 rewriter.replaceOp(transposeOp, {targetMemRef});
1771 Value getSize(ConversionPatternRewriter &rewriter,
Location loc,
1773 Type indexType)
const {
1774 assert(idx <
shape.size());
1775 if (ShapedType::isStatic(
shape[idx]))
1779 llvm::count_if(
shape.take_front(idx), ShapedType::isDynamic);
1780 return dynamicSizes[nDynamic];
1787 Value getStride(ConversionPatternRewriter &rewriter,
Location loc,
1789 Value runningStride,
unsigned idx,
Type indexType)
const {
1790 assert(idx < strides.size());
1791 if (ShapedType::isStatic(strides[idx]))
1794 return runningStride
1795 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1797 assert(!runningStride);
1802 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1803 ConversionPatternRewriter &rewriter)
const override {
1804 auto loc = viewOp.getLoc();
1806 auto viewMemRefType = viewOp.getType();
1807 auto targetElementTy =
1808 typeConverter->convertType(viewMemRefType.getElementType());
1809 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1810 if (!targetDescTy || !targetElementTy ||
1813 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1818 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1819 if (failed(successStrides))
1820 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1821 assert(offset == 0 &&
"expected offset to be 0");
1825 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1826 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1835 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1836 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1840 alignedPtr = LLVM::GEPOp::create(
1841 rewriter, loc, alignedPtr.
getType(),
1842 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1843 adaptor.getByteShift());
1845 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1847 Type indexType = getIndexType();
1851 targetMemRef.setOffset(
1856 if (viewMemRefType.getRank() == 0)
1857 return rewriter.replaceOp(viewOp, {targetMemRef}),
success();
1860 Value stride =
nullptr, nextSize =
nullptr;
1861 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1863 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1864 adaptor.getSizes(), i, indexType);
1865 targetMemRef.setSize(rewriter, loc, i, size);
1868 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1869 targetMemRef.setStride(rewriter, loc, i, stride);
1873 rewriter.replaceOp(viewOp, {targetMemRef});
1884static std::optional<LLVM::AtomicBinOp>
1885matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1886 switch (atomicOp.getKind()) {
1887 case arith::AtomicRMWKind::addf:
1888 return LLVM::AtomicBinOp::fadd;
1889 case arith::AtomicRMWKind::addi:
1890 return LLVM::AtomicBinOp::add;
1891 case arith::AtomicRMWKind::assign:
1892 return LLVM::AtomicBinOp::xchg;
1893 case arith::AtomicRMWKind::maximumf:
1895 LDBG() <<
"the lowering of memref.atomicrmw maximumf changed "
1896 "from fmax to fmaximum, expect more NaNs";
1897 return LLVM::AtomicBinOp::fmaximum;
1898 case arith::AtomicRMWKind::maxnumf:
1899 return LLVM::AtomicBinOp::fmax;
1900 case arith::AtomicRMWKind::maxs:
1901 return LLVM::AtomicBinOp::max;
1902 case arith::AtomicRMWKind::maxu:
1903 return LLVM::AtomicBinOp::umax;
1904 case arith::AtomicRMWKind::minimumf:
1906 LDBG() <<
"the lowering of memref.atomicrmw minimum changed "
1907 "from fmin to fminimum, expect more NaNs";
1908 return LLVM::AtomicBinOp::fminimum;
1909 case arith::AtomicRMWKind::minnumf:
1910 return LLVM::AtomicBinOp::fmin;
1911 case arith::AtomicRMWKind::mins:
1912 return LLVM::AtomicBinOp::min;
1913 case arith::AtomicRMWKind::minu:
1914 return LLVM::AtomicBinOp::umin;
1915 case arith::AtomicRMWKind::ori:
1916 return LLVM::AtomicBinOp::_or;
1917 case arith::AtomicRMWKind::xori:
1918 return LLVM::AtomicBinOp::_xor;
1919 case arith::AtomicRMWKind::andi:
1920 return LLVM::AtomicBinOp::_and;
1922 return std::nullopt;
1924 llvm_unreachable(
"Invalid AtomicRMWKind");
1927struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1931 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1932 ConversionPatternRewriter &rewriter)
const override {
1933 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1936 auto memRefType = atomicOp.getMemRefType();
1939 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1942 getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1943 adaptor.getMemref(), adaptor.getIndices());
1944 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1945 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1946 LLVM::AtomicOrdering::acq_rel);
1952class ConvertExtractAlignedPointerAsIndex
1959 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1961 ConversionPatternRewriter &rewriter)
const override {
1967 alignedPtr = desc.
alignedPtr(rewriter, extractOp->getLoc());
1969 auto elementPtrTy = LLVM::LLVMPointerType::get(
1976 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1980 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1981 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1988class ExtractStridedMetadataOpLowering
1995 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1997 ConversionPatternRewriter &rewriter)
const override {
2004 Location loc = extractStridedMetadataOp.getLoc();
2005 Value source = extractStridedMetadataOp.getSource();
2007 auto sourceMemRefType = cast<MemRefType>(source.
getType());
2008 int64_t rank = sourceMemRefType.getRank();
2010 results.reserve(2 + rank * 2);
2016 rewriter, loc, *getTypeConverter(),
2017 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
2018 baseBuffer, alignedBuffer);
2019 results.push_back((
Value)dstMemRef);
2022 results.push_back(sourceMemRef.
offset(rewriter, loc));
2025 for (
unsigned i = 0; i < rank; ++i)
2026 results.push_back(sourceMemRef.
size(rewriter, loc, i));
2028 for (
unsigned i = 0; i < rank; ++i)
2029 results.push_back(sourceMemRef.
stride(rewriter, loc, i));
2031 rewriter.replaceOp(extractStridedMetadataOp, results);
2044 AllocaScopeOpLowering,
2045 AssumeAlignmentOpLowering,
2046 AtomicRMWOpLowering,
2047 ConvertExtractAlignedPointerAsIndex,
2049 DistinctObjectsOpLowering,
2050 ExtractStridedMetadataOpLowering,
2051 GenericAtomicRMWOpLowering,
2052 GetGlobalMemrefOpLowering,
2054 MemRefCastOpLowering,
2055 MemRefReinterpretCastOpLowering,
2056 MemRefReshapeOpLowering,
2057 MemorySpaceCastOpLowering,
2060 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2061 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2065 ViewOpLowering>(converter);
2067 patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2071 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2074 patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2078struct FinalizeMemRefToLLVMConversionPass
2080 FinalizeMemRefToLLVMConversionPass> {
2081 using FinalizeMemRefToLLVMConversionPassBase::
2082 FinalizeMemRefToLLVMConversionPassBase;
2084 void runOnOperation()
override {
2086 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2088 dataLayoutAnalysis.getAtOrAbove(op));
2093 options.useGenericFunctions = useGenericFunctions;
2096 options.overrideIndexBitwidth(indexBitwidth);
2099 &dataLayoutAnalysis);
2105 target.addLegalOp<func::FuncOp>();
2106 if (failed(applyPartialConversion(op,
target, std::move(
patterns))))
2107 signalPassFailure();
2114 void loadDependentDialects(MLIRContext *context)
const final {
2115 context->loadDialect<LLVM::LLVMDialect>();
2120 void populateConvertToLLVMConversionPatterns(
2121 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2122 RewritePatternSet &
patterns)
const final {
2131 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
const LowerToLLVMOptions & getOptions() const
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.