27#include "llvm/Support/DebugLog.h"
28#include "llvm/Support/MathExtras.h"
32#define DEBUG_TYPE "memref-to-llvm"
35#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
36#include "mlir/Conversion/Passes.h.inc"
42 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
46static bool isStaticStrideOrOffset(
int64_t strideOrOffset) {
47 return ShapedType::isStatic(strideOrOffset);
50static FailureOr<LLVM::LLVMFuncOp>
61static FailureOr<LLVM::LLVMFuncOp>
73static FailureOr<LLVM::LLVMFuncOp>
89static Value createAligned(ConversionPatternRewriter &rewriter,
Location loc,
91 Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.
getType(),
92 rewriter.getIndexAttr(1));
93 Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
94 Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
95 Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
96 return LLVM::SubOp::create(rewriter, loc, bumped, mod);
106 layout = &analysis->getAbove(op);
108 Type elementType = memRefType.getElementType();
109 if (
auto memRefElementType = dyn_cast<MemRefType>(elementType))
111 if (
auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
117static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
119 MemRefType memRefType,
Type elementPtrType,
121 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.
getType());
122 FailureOr<unsigned> maybeMemrefAddrSpace =
124 assert(succeeded(maybeMemrefAddrSpace) &&
"unsupported address space");
125 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
126 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
127 allocatedPtr = LLVM::AddrSpaceCastOp::create(
129 LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
142 symbolTables(symbolTables) {}
145 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter)
const override {
147 auto loc = op.getLoc();
148 MemRefType memRefType = op.getType();
149 if (!isConvertibleAndHasIdentityMaps(memRefType))
150 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
153 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
154 getNotalignedAllocFn(rewriter, getTypeConverter(),
156 getIndexType(), symbolTables);
157 if (failed(allocFuncOp))
167 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
168 rewriter, sizes, strides, sizeBytes,
true);
170 Value alignment = getAlignment(rewriter, loc, op);
173 sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
177 Type elementPtrType = this->getElementPtrType(memRefType);
178 assert(elementPtrType &&
"could not compute element ptr type");
180 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
183 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
184 elementPtrType, *getTypeConverter());
185 Value alignedPtr = allocatedPtr;
189 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
191 createAligned(rewriter, loc, allocatedInt, alignment);
193 LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
197 auto memRefDescriptor = this->createMemRefDescriptor(
198 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
201 rewriter.replaceOp(op, {memRefDescriptor});
206 template <
typename OpType>
207 Value getAlignment(ConversionPatternRewriter &rewriter,
Location loc,
209 MemRefType memRefType = op.
getType();
211 if (
auto alignmentAttr = op.getAlignment()) {
212 Type indexType = getIndexType();
215 }
else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
220 alignment =
getSizeInBytes(loc, memRefType.getElementType(), rewriter);
234 symbolTables(symbolTables) {}
237 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter)
const override {
239 auto loc = op.getLoc();
240 MemRefType memRefType = op.getType();
241 if (!isConvertibleAndHasIdentityMaps(memRefType))
242 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
245 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
246 getAlignedAllocFn(rewriter, getTypeConverter(),
248 getIndexType(), symbolTables);
249 if (failed(allocFuncOp))
259 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
260 rewriter, sizes, strides, sizeBytes, !
false);
262 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
264 Value allocAlignment =
269 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
270 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
272 Type elementPtrType = this->getElementPtrType(memRefType);
274 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
278 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
279 elementPtrType, *getTypeConverter());
282 auto memRefDescriptor = this->createMemRefDescriptor(
283 loc, memRefType,
ptr,
ptr, sizes, strides, rewriter);
286 rewriter.replaceOp(op, {memRefDescriptor});
291 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
298 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
300 if (std::optional<uint64_t> alignment = op.getAlignment())
306 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
307 getTypeConverter(), op.getType(), op, defaultLayout);
308 return std::max(kMinAlignedAllocAlignment,
309 llvm::PowerOf2Ceil(eltSizeBytes));
314 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
Operation *op,
316 uint64_t sizeDivisor =
317 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
318 for (
unsigned i = 0, e = type.getRank(); i < e; i++) {
319 if (type.isDynamicDim(i))
321 sizeDivisor = sizeDivisor * type.getDimSize(i);
323 return sizeDivisor % factor == 0;
338 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override {
340 auto loc = op.getLoc();
341 MemRefType memRefType = op.getType();
342 if (!isConvertibleAndHasIdentityMaps(memRefType))
343 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
352 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
353 rewriter, sizes, strides, size, !
true);
358 typeConverter->convertType(op.getType().getElementType());
359 FailureOr<unsigned> maybeAddressSpace =
361 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
362 unsigned addrSpace = *maybeAddressSpace;
363 auto elementPtrType =
364 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
366 auto allocatedElementPtr =
367 LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
368 op.getAlignment().value_or(0));
371 auto memRefDescriptor = this->createMemRefDescriptor(
372 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
376 rewriter.replaceOp(op, {memRefDescriptor});
381struct AllocaScopeOpLowering
386 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter)
const override {
389 Location loc = allocaScopeOp.getLoc();
393 auto *currentBlock = rewriter.getInsertionBlock();
394 auto *remainingOpsBlock =
395 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
396 Block *continueBlock;
397 if (allocaScopeOp.getNumResults() == 0) {
398 continueBlock = remainingOpsBlock;
400 continueBlock = rewriter.createBlock(
401 remainingOpsBlock, allocaScopeOp.getResultTypes(),
403 allocaScopeOp.getLoc()));
404 LLVM::BrOp::create(rewriter, loc,
ValueRange(), remainingOpsBlock);
408 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
409 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
410 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
413 rewriter.setInsertionPointToEnd(currentBlock);
414 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
415 LLVM::BrOp::create(rewriter, loc,
ValueRange(), beforeBody);
419 rewriter.setInsertionPointToEnd(afterBody);
421 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
422 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
423 returnOp, returnOp.getResults(), continueBlock);
426 rewriter.setInsertionPoint(branchOp);
427 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
430 rewriter.replaceOp(allocaScopeOp, continueBlock->
getArguments());
436struct AssumeAlignmentOpLowering
444 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter)
const override {
447 unsigned alignment = op.getAlignment();
448 auto loc = op.getLoc();
450 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
451 Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType,
memref,
458 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(
true));
459 Value alignmentConst =
463 rewriter.replaceOp(op,
memref);
468struct DistinctObjectsOpLowering
476 matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter)
const override {
479 if (operands.size() <= 1) {
481 rewriter.replaceOp(op, operands);
487 for (
auto [origOperand, newOperand] :
488 llvm::zip_equal(op.getOperands(), operands)) {
489 auto memrefType = cast<MemRefType>(origOperand.getType());
497 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
499 for (
auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500 for (
auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501 Value ptr1 = ptrs[i];
503 LLVM::AssumeOp::create(rewriter, loc, cond,
508 rewriter.replaceOp(op, operands);
524 symbolTables(symbolTables) {}
527 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter)
const override {
530 FailureOr<LLVM::LLVMFuncOp> freeFunc =
531 getFreeFn(rewriter, getTypeConverter(),
533 if (failed(freeFunc))
536 if (
auto unrankedTy =
537 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
538 auto elementPtrTy = LLVM::LLVMPointerType::get(
539 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
541 rewriter, op.getLoc(),
549 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
561 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter)
const override {
563 Type operandType = dimOp.getSource().getType();
564 if (isa<UnrankedMemRefType>(operandType)) {
565 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
566 operandType, dimOp, adaptor.getOperands(), rewriter);
567 if (failed(extractedSize))
569 rewriter.replaceOp(dimOp, {*extractedSize});
572 if (isa<MemRefType>(operandType)) {
574 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
575 adaptor.getOperands(), rewriter)});
578 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
583 extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
585 ConversionPatternRewriter &rewriter)
const {
588 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
589 auto scalarMemRefType =
590 MemRefType::get({}, unrankedMemRefType.getElementType());
591 FailureOr<unsigned> maybeAddressSpace =
592 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
593 if (failed(maybeAddressSpace)) {
594 dimOp.emitOpError(
"memref memory space must be convertible to an integer "
598 unsigned addressSpace = *maybeAddressSpace;
606 Type elementType = typeConverter->convertType(scalarMemRefType);
610 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
612 LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
617 Value idxPlusOne = LLVM::AddOp::create(
621 Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
622 getTypeConverter()->getIndexType(),
623 offsetPtr, idxPlusOne);
624 return LLVM::LoadOp::create(rewriter, loc,
625 getTypeConverter()->getIndexType(), sizePtr)
629 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp)
const {
630 if (
auto idx = dimOp.getConstantIndex())
633 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
634 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
639 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
641 ConversionPatternRewriter &rewriter)
const {
645 MemRefType memRefType = cast<MemRefType>(operandType);
646 Type indexType = getIndexType();
647 if (std::optional<int64_t>
index = getConstantDimIndex(dimOp)) {
649 if (i >= 0 && i < memRefType.getRank()) {
650 if (memRefType.isDynamicDim(i)) {
653 return descriptor.
size(rewriter, loc, i);
656 int64_t dimSize = memRefType.getDimSize(i);
661 int64_t rank = memRefType.getRank();
663 return memrefDescriptor.
size(rewriter, loc,
index, rank);
670template <
typename Derived>
674 using Base = LoadStoreOpLowering<Derived>;
704struct GenericAtomicRMWOpLowering
705 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
709 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 auto loc = atomicOp.getLoc();
712 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
717 bool needsBitcast = isa<FloatType>(valueType);
718 Type cmpxchgType = valueType;
720 unsigned bitWidth = cast<FloatType>(valueType).getWidth();
721 cmpxchgType = rewriter.getIntegerType(bitWidth);
725 auto *initBlock = rewriter.getInsertionBlock();
726 auto *loopBlock = rewriter.splitBlock(initBlock,
Block::iterator(atomicOp));
727 loopBlock->addArgument(cmpxchgType, loc);
733 rewriter.setInsertionPointToEnd(initBlock);
734 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
735 auto dataPtr = getStridedElementPtr(
736 rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
737 Value init = LLVM::LoadOp::create(
738 rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
741 init = LLVM::BitcastOp::create(rewriter, loc, cmpxchgType, init);
742 LLVM::BrOp::create(rewriter, loc, init, loopBlock);
745 rewriter.setInsertionPointToStart(loopBlock);
748 Value loopArgument = loopBlock->getArgument(0);
749 Value loopArgForBody = loopArgument;
752 LLVM::BitcastOp::create(rewriter, loc, valueType, loopArgument);
754 mapping.
map(atomicOp.getCurrentValue(), loopArgForBody);
758 mapping.
map(nestedOp.getResults(),
clone->getResults());
764 return atomicOp.emitError(
"result not defined in region");
767 result = LLVM::BitcastOp::create(rewriter, loc, cmpxchgType,
result);
771 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
772 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
774 LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
775 result, successOrdering, failureOrdering);
777 Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
778 Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);
782 loopBlock, newLoaded);
788 rewriter.setInsertionPointToStart(endBlock);
789 newLoaded = LLVM::BitcastOp::create(rewriter, loc, valueType, newLoaded);
791 rewriter.setInsertionPointToEnd(endBlock);
792 rewriter.replaceOp(atomicOp, {newLoaded});
800convertGlobalMemrefTypeToLLVM(MemRefType type,
807 Type elementType = typeConverter.convertType(type.getElementType());
808 Type arrayTy = elementType;
810 for (
int64_t dim : llvm::reverse(type.getShape()))
811 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
824 symbolTables(symbolTables) {}
827 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
828 ConversionPatternRewriter &rewriter)
const override {
829 MemRefType type = global.getType();
830 if (!isConvertibleAndHasIdentityMaps(type))
833 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
835 LLVM::Linkage linkage =
836 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
837 bool isExternal = global.isExternal();
838 bool isUninitialized = global.isUninitialized();
841 if (!isExternal && !isUninitialized) {
842 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
843 initialValue = elementsAttr;
847 if (type.getRank() == 0)
848 initialValue = elementsAttr.getSplatValue<
Attribute>();
851 uint64_t alignment = global.getAlignment().value_or(0);
852 FailureOr<unsigned> addressSpace =
853 getTypeConverter()->getMemRefAddressSpace(type);
854 if (failed(addressSpace))
855 return global.emitOpError(
856 "memory space cannot be converted to an integer address space");
864 symbolTable->remove(global);
868 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
869 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
870 initialValue, alignment, *addressSpace);
874 symbolTable->
insert(newGlobal, rewriter.getInsertionPoint());
876 if (!isExternal && isUninitialized) {
877 rewriter.createBlock(&newGlobal.getInitializerRegion());
879 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
880 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
889struct GetGlobalMemrefOpLowering
896 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
897 ConversionPatternRewriter &rewriter)
const override {
898 auto loc = op.getLoc();
899 MemRefType memRefType = op.getType();
900 if (!isConvertibleAndHasIdentityMaps(memRefType))
901 return rewriter.notifyMatchFailure(op,
"incompatible memref type");
910 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
911 rewriter, sizes, strides, sizeBytes, !
false);
913 MemRefType type = cast<MemRefType>(op.getResult().getType());
917 FailureOr<unsigned> maybeAddressSpace =
918 getTypeConverter()->getMemRefAddressSpace(type);
919 assert(succeeded(maybeAddressSpace) &&
"unsupported address space");
920 unsigned memSpace = *maybeAddressSpace;
922 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
923 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
925 LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
930 LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
936 auto intPtrType = getIntPtrType(memSpace);
937 Value deadBeefConst =
940 LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
945 auto memRefDescriptor = this->createMemRefDescriptor(
946 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
949 rewriter.replaceOp(op, {memRefDescriptor});
956struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
960 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter)
const override {
962 auto type = loadOp.getMemRefType();
967 Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
970 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
971 loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
972 loadOp.getAlignment().value_or(0),
false, loadOp.getNontemporal());
979struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
983 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter)
const override {
985 auto type = op.getMemRefType();
991 getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
993 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
994 op.getAlignment().value_or(0),
995 false, op.getNontemporal());
1002struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
1006 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
1007 ConversionPatternRewriter &rewriter)
const override {
1008 auto type = prefetchOp.getMemRefType();
1009 auto loc = prefetchOp.getLoc();
1011 Value dataPtr = getStridedElementPtr(
1012 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
1015 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
1016 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
1017 IntegerAttr isData =
1018 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
1019 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
1020 localityHint, isData);
1029 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
1030 ConversionPatternRewriter &rewriter)
const override {
1032 Type operandType = op.getMemref().getType();
1033 if (isa<UnrankedMemRefType>(operandType)) {
1035 rewriter.replaceOp(op, {desc.
rank(rewriter, loc)});
1038 if (
auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
1039 Type indexType = getIndexType();
1040 rewriter.replaceOp(op,
1042 rankedMemRefType.getRank())});
1053 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter)
const override {
1055 Type srcType = memRefCastOp.getOperand().getType();
1056 Type dstType = memRefCastOp.getType();
1063 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
1064 if (typeConverter->convertType(srcType) !=
1065 typeConverter->convertType(dstType))
1069 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
1072 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
1073 auto loc = memRefCastOp.getLoc();
1076 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
1077 rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
1081 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
1086 auto srcMemRefType = cast<MemRefType>(srcType);
1087 int64_t rank = srcMemRefType.getRank();
1089 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
1090 loc, adaptor.getSource(), rewriter);
1093 auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1094 rewriter.getIndexAttr(rank));
1099 memRefDesc.
setRank(rewriter, loc, rankVal);
1102 rewriter.replaceOp(memRefCastOp, (
Value)memRefDesc);
1104 }
else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
1113 auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType,
ptr);
1114 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
1116 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
1136 symbolTables(symbolTables) {}
1139 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1140 ConversionPatternRewriter &rewriter)
const {
1141 auto loc = op.getLoc();
1142 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1147 Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1148 rewriter.getIndexAttr(1));
1149 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
1150 auto size = srcDesc.
size(rewriter, loc, pos);
1151 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
1155 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1158 LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
1160 Type elementType = typeConverter->convertType(srcType.getElementType());
1164 Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.
getType(),
1165 elementType, srcBasePtr, srcOffset);
1168 Value targetOffset = targetDesc.
offset(rewriter, loc);
1170 LLVM::GEPOp::create(rewriter, loc, targetBasePtr.
getType(), elementType,
1171 targetBasePtr, targetOffset);
1172 LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
1174 rewriter.eraseOp(op);
1180 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1181 ConversionPatternRewriter &rewriter)
const {
1182 auto loc = op.getLoc();
1183 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1184 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1187 auto makeUnranked = [&,
this](
Value ranked, MemRefType type) {
1188 auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1190 auto *typeConverter = getTypeConverter();
1195 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1197 rewriter, loc, *typeConverter, unrankedType,
ValueRange{rank,
ptr});
1201 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
1203 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1204 Value unrankedSource =
1205 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1206 : adaptor.getSource();
1207 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1208 Value unrankedTarget =
1209 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1210 : adaptor.getTarget();
1213 auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1214 rewriter.getIndexAttr(1));
1215 auto promote = [&](
Value desc) {
1216 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1218 LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
1219 LLVM::StoreOp::create(rewriter, loc, desc, allocated);
1223 auto sourcePtr = promote(unrankedSource);
1224 auto targetPtr = promote(unrankedTarget);
1228 auto elemSize =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
1230 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1231 sourcePtr.getType(), symbolTables);
1234 LLVM::CallOp::create(rewriter, loc, copyFn.value(),
1238 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
1240 rewriter.eraseOp(op);
1246 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1247 ConversionPatternRewriter &rewriter)
const override {
1248 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1249 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1252 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1256 return memrefType &&
1257 (memrefType.getLayout().isIdentity() ||
1258 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1262 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1263 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1265 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1269struct MemorySpaceCastOpLowering
1275 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1276 ConversionPatternRewriter &rewriter)
const override {
1279 Type resultType = op.getDest().getType();
1280 if (
auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1281 auto convertedType =
1282 typeConverter->convertType<LLVM::LLVMStructType>(resultTypeR);
1284 return rewriter.notifyMatchFailure(op,
"memref type conversion failed");
1285 Type newPtrType = convertedType.getBody()[0];
1291 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
1293 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
1295 resultTypeR, descVals);
1296 rewriter.replaceOp(op,
result);
1299 if (
auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1302 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1303 FailureOr<unsigned> maybeSourceAddrSpace =
1304 getTypeConverter()->getMemRefAddressSpace(sourceType);
1305 if (failed(maybeSourceAddrSpace))
1306 return rewriter.notifyMatchFailure(loc,
1307 "non-integer source address space");
1308 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1309 FailureOr<unsigned> maybeResultAddrSpace =
1310 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1311 if (failed(maybeResultAddrSpace))
1312 return rewriter.notifyMatchFailure(loc,
1313 "non-integer result address space");
1314 unsigned resultAddrSpace = *maybeResultAddrSpace;
1317 Value rank = sourceDesc.
rank(rewriter, loc);
1322 rewriter, loc, typeConverter->convertType(resultTypeU));
1323 result.setRank(rewriter, loc, rank);
1325 rewriter, loc, *getTypeConverter(),
result, resultAddrSpace);
1326 Value resultUnderlyingDesc =
1327 LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
1328 rewriter.getI8Type(), resultUnderlyingSize);
1329 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1332 auto sourceElemPtrType =
1333 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
1334 auto resultElemPtrType =
1335 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
1338 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1340 sourceDesc.
alignedPtr(rewriter, loc, *getTypeConverter(),
1341 sourceUnderlyingDesc, sourceElemPtrType);
1342 allocatedPtr = LLVM::AddrSpaceCastOp::create(
1343 rewriter, loc, resultElemPtrType, allocatedPtr);
1344 alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
1345 resultElemPtrType, alignedPtr);
1347 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1348 resultElemPtrType, allocatedPtr);
1349 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1350 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1353 Value sourceIndexVals =
1354 sourceDesc.
offsetBasePtr(rewriter, loc, *getTypeConverter(),
1355 sourceUnderlyingDesc, sourceElemPtrType);
1356 Value resultIndexVals =
1357 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1358 resultUnderlyingDesc, resultElemPtrType);
1361 2 * llvm::divideCeil(
1362 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1363 Value bytesToSkipConst = LLVM::ConstantOp::create(
1364 rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1366 LLVM::SubOp::create(rewriter, loc, getIndexType(),
1367 resultUnderlyingSize, bytesToSkipConst);
1368 LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
1374 return rewriter.notifyMatchFailure(loc,
"unexpected memref type");
1381static void extractPointersAndOffset(
Location loc,
1382 ConversionPatternRewriter &rewriter,
1384 Value originalOperand,
1385 Value convertedOperand,
1387 Value *offset =
nullptr) {
1389 if (isa<MemRefType>(operandType)) {
1392 *alignedPtr = desc.
alignedPtr(rewriter, loc);
1393 if (offset !=
nullptr)
1394 *offset = desc.
offset(rewriter, loc);
1400 cast<UnrankedMemRefType>(operandType));
1401 auto elementPtrType =
1402 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1410 rewriter, loc, underlyingDescPtr, elementPtrType);
1412 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1413 if (offset !=
nullptr) {
1415 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1419struct MemRefReinterpretCastOpLowering
1425 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1426 ConversionPatternRewriter &rewriter)
const override {
1427 Type srcType = castOp.getSource().getType();
1430 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1431 adaptor, &descriptor)))
1433 rewriter.replaceOp(castOp, {descriptor});
1438 LogicalResult convertSourceMemRefToDescriptor(
1439 ConversionPatternRewriter &rewriter,
Type srcType,
1440 memref::ReinterpretCastOp castOp,
1441 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1442 MemRefType targetMemRefType =
1443 cast<MemRefType>(castOp.getResult().getType());
1444 auto llvmTargetDescriptorTy =
1445 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1446 if (!llvmTargetDescriptorTy)
1454 Value allocatedPtr, alignedPtr;
1455 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1456 castOp.getSource(), adaptor.getSource(),
1457 &allocatedPtr, &alignedPtr);
1458 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1459 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1462 if (castOp.isDynamicOffset(0))
1463 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1465 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1468 unsigned dynSizeId = 0;
1469 unsigned dynStrideId = 0;
1470 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1471 if (castOp.isDynamicSize(i))
1472 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1474 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1476 if (castOp.isDynamicStride(i))
1477 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1479 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1486struct MemRefReshapeOpLowering
1491 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1492 ConversionPatternRewriter &rewriter)
const override {
1493 Type srcType = reshapeOp.getSource().getType();
1496 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1497 adaptor, &descriptor)))
1499 rewriter.replaceOp(reshapeOp, {descriptor});
1505 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1506 Type srcType, memref::ReshapeOp reshapeOp,
1507 memref::ReshapeOp::Adaptor adaptor,
1508 Value *descriptor)
const {
1509 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1510 if (shapeMemRefType.hasStaticShape()) {
1511 MemRefType targetMemRefType =
1512 cast<MemRefType>(reshapeOp.getResult().getType());
1513 auto llvmTargetDescriptorTy =
1514 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1515 if (!llvmTargetDescriptorTy)
1524 Value allocatedPtr, alignedPtr;
1525 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1526 reshapeOp.getSource(), adaptor.getSource(),
1527 &allocatedPtr, &alignedPtr);
1528 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1529 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1534 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1535 return rewriter.notifyMatchFailure(
1536 reshapeOp,
"failed to get stride and offset exprs");
1538 if (!isStaticStrideOrOffset(offset))
1539 return rewriter.notifyMatchFailure(reshapeOp,
1540 "dynamic offset is unsupported");
1542 desc.setConstantOffset(rewriter, loc, offset);
1544 assert(targetMemRefType.getLayout().isIdentity() &&
1545 "Identity layout map is a precondition of a valid reshape op");
1547 Type indexType = getIndexType();
1548 Value stride =
nullptr;
1549 int64_t targetRank = targetMemRefType.getRank();
1550 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1551 if (ShapedType::isStatic(strides[i])) {
1556 }
else if (!stride) {
1566 if (!targetMemRefType.isDynamicDim(i)) {
1568 targetMemRefType.getDimSize(i));
1570 Value shapeOp = reshapeOp.getShape();
1572 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp,
index);
1573 Type indexType = getIndexType();
1574 if (dimSize.
getType() != indexType)
1575 dimSize = typeConverter->materializeTargetConversion(
1576 rewriter, loc, indexType, dimSize);
1577 assert(dimSize &&
"Invalid memref element type");
1580 desc.setSize(rewriter, loc, i, dimSize);
1581 desc.setStride(rewriter, loc, i, stride);
1584 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1594 Value resultRank = shapeDesc.
size(rewriter, loc, 0);
1597 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1598 unsigned addressSpace =
1599 *getTypeConverter()->getMemRefAddressSpace(targetType);
1604 rewriter, loc, typeConverter->convertType(targetType));
1605 targetDesc.setRank(rewriter, loc, resultRank);
1607 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1608 Value underlyingDescPtr = LLVM::AllocaOp::create(
1609 rewriter, loc, getPtrType(), IntegerType::get(
getContext(), 8),
1611 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1614 Value allocatedPtr, alignedPtr, offset;
1615 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1616 reshapeOp.getSource(), adaptor.getSource(),
1617 &allocatedPtr, &alignedPtr, &offset);
1620 auto elementPtrType =
1621 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1624 elementPtrType, allocatedPtr);
1626 underlyingDescPtr, elementPtrType,
1629 underlyingDescPtr, elementPtrType,
1635 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1637 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1640 Value resultRankMinusOne =
1641 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1643 Block *initBlock = rewriter.getInsertionBlock();
1644 Type indexType = getTypeConverter()->getIndexType();
1645 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1647 Block *condBlock = rewriter.createBlock(initBlock->
getParent(), {},
1648 {indexType, indexType}, {loc, loc});
1651 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1652 rewriter.mergeBlocks(remainingBlock, condBlock,
ValueRange());
1654 rewriter.setInsertionPointToEnd(initBlock);
1655 LLVM::BrOp::create(rewriter, loc,
1656 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1657 rewriter.setInsertionPointToStart(condBlock);
1662 Value pred = LLVM::ICmpOp::create(
1663 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1664 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1667 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1668 rewriter.setInsertionPointToStart(bodyBlock);
1671 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1672 Value sizeLoadGep = LLVM::GEPOp::create(
1673 rewriter, loc, llvmIndexPtrType,
1674 typeConverter->convertType(shapeMemRefType.getElementType()),
1675 shapeOperandPtr, indexArg);
1676 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1678 targetSizesBase, indexArg, size);
1682 targetStridesBase, indexArg, strideArg);
1683 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1686 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1687 LLVM::BrOp::create(rewriter, loc,
ValueRange({decrement, nextStride}),
1691 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1694 rewriter.setInsertionPointToEnd(condBlock);
1695 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock,
ValueRange(),
1699 rewriter.setInsertionPointToStart(remainder);
1701 *descriptor = targetDesc;
1708template <
typename ReshapeOp>
1709class ReassociatingReshapeOpConversion
1713 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1716 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1717 ConversionPatternRewriter &rewriter)
const override {
1718 return rewriter.notifyMatchFailure(
1720 "reassociation operations should have been expanded beforehand");
1730 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1731 ConversionPatternRewriter &rewriter)
const override {
1732 return rewriter.notifyMatchFailure(
1733 subViewOp,
"subview operations should have been expanded beforehand");
1749 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1750 ConversionPatternRewriter &rewriter)
const override {
1751 auto loc = transposeOp.getLoc();
1755 if (transposeOp.getPermutation().isIdentity())
1756 return rewriter.replaceOp(transposeOp, {viewMemRef}),
success();
1760 typeConverter->convertType(transposeOp.getIn().getType()));
1764 targetMemRef.setAllocatedPtr(rewriter, loc,
1766 targetMemRef.setAlignedPtr(rewriter, loc,
1770 targetMemRef.setOffset(rewriter, loc, viewMemRef.
offset(rewriter, loc));
1776 for (
const auto &en :
1777 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1778 int targetPos = en.index();
1779 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1780 targetMemRef.setSize(rewriter, loc, targetPos,
1781 viewMemRef.
size(rewriter, loc, sourcePos));
1782 targetMemRef.setStride(rewriter, loc, targetPos,
1783 viewMemRef.
stride(rewriter, loc, sourcePos));
1786 rewriter.replaceOp(transposeOp, {targetMemRef});
1801 Value getSize(ConversionPatternRewriter &rewriter,
Location loc,
1803 Type indexType)
const {
1804 assert(idx <
shape.size());
1805 if (ShapedType::isStatic(
shape[idx]))
1809 llvm::count_if(
shape.take_front(idx), ShapedType::isDynamic);
1810 return dynamicSizes[nDynamic];
1817 Value getStride(ConversionPatternRewriter &rewriter,
Location loc,
1819 Value runningStride,
unsigned idx,
Type indexType)
const {
1820 assert(idx < strides.size());
1821 if (ShapedType::isStatic(strides[idx]))
1824 return runningStride
1825 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1827 assert(!runningStride);
1832 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1833 ConversionPatternRewriter &rewriter)
const override {
1834 auto loc = viewOp.getLoc();
1836 auto viewMemRefType = viewOp.getType();
1837 auto targetElementTy =
1838 typeConverter->convertType(viewMemRefType.getElementType());
1839 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1840 if (!targetDescTy || !targetElementTy ||
1843 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1848 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1849 if (failed(successStrides))
1850 return viewOp.emitWarning(
"cannot cast to non-strided shape"), failure();
1851 assert(offset == 0 &&
"expected offset to be 0");
1855 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1856 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1865 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1866 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1870 alignedPtr = LLVM::GEPOp::create(
1871 rewriter, loc, alignedPtr.
getType(),
1872 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1873 adaptor.getByteShift());
1875 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1877 Type indexType = getIndexType();
1881 targetMemRef.setOffset(
1886 if (viewMemRefType.getRank() == 0)
1887 return rewriter.replaceOp(viewOp, {targetMemRef}),
success();
1890 Value stride =
nullptr, nextSize =
nullptr;
1891 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1893 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1894 adaptor.getSizes(), i, indexType);
1895 targetMemRef.setSize(rewriter, loc, i, size);
1898 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1899 targetMemRef.setStride(rewriter, loc, i, stride);
1903 rewriter.replaceOp(viewOp, {targetMemRef});
1914static std::optional<LLVM::AtomicBinOp>
1915matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1916 switch (atomicOp.getKind()) {
1917 case arith::AtomicRMWKind::addf:
1918 return LLVM::AtomicBinOp::fadd;
1919 case arith::AtomicRMWKind::addi:
1920 return LLVM::AtomicBinOp::add;
1921 case arith::AtomicRMWKind::assign:
1922 return LLVM::AtomicBinOp::xchg;
1923 case arith::AtomicRMWKind::maximumf:
1925 LDBG() <<
"the lowering of memref.atomicrmw maximumf changed "
1926 "from fmax to fmaximum, expect more NaNs";
1927 return LLVM::AtomicBinOp::fmaximum;
1928 case arith::AtomicRMWKind::maxnumf:
1929 return LLVM::AtomicBinOp::fmax;
1930 case arith::AtomicRMWKind::maxs:
1931 return LLVM::AtomicBinOp::max;
1932 case arith::AtomicRMWKind::maxu:
1933 return LLVM::AtomicBinOp::umax;
1934 case arith::AtomicRMWKind::minimumf:
1936 LDBG() <<
"the lowering of memref.atomicrmw minimum changed "
1937 "from fmin to fminimum, expect more NaNs";
1938 return LLVM::AtomicBinOp::fminimum;
1939 case arith::AtomicRMWKind::minnumf:
1940 return LLVM::AtomicBinOp::fmin;
1941 case arith::AtomicRMWKind::mins:
1942 return LLVM::AtomicBinOp::min;
1943 case arith::AtomicRMWKind::minu:
1944 return LLVM::AtomicBinOp::umin;
1945 case arith::AtomicRMWKind::ori:
1946 return LLVM::AtomicBinOp::_or;
1947 case arith::AtomicRMWKind::xori:
1948 return LLVM::AtomicBinOp::_xor;
1949 case arith::AtomicRMWKind::andi:
1950 return LLVM::AtomicBinOp::_and;
1952 return std::nullopt;
1954 llvm_unreachable(
"Invalid AtomicRMWKind");
1957struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1961 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1962 ConversionPatternRewriter &rewriter)
const override {
1963 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1966 auto memRefType = atomicOp.getMemRefType();
1969 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1972 getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1973 adaptor.getMemref(), adaptor.getIndices());
1974 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1975 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1976 LLVM::AtomicOrdering::acq_rel);
1982class ConvertExtractAlignedPointerAsIndex
1989 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1991 ConversionPatternRewriter &rewriter)
const override {
1997 alignedPtr = desc.
alignedPtr(rewriter, extractOp->getLoc());
1999 auto elementPtrTy = LLVM::LLVMPointerType::get(
2006 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
2010 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
2011 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
2018class ExtractStridedMetadataOpLowering
2025 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
2027 ConversionPatternRewriter &rewriter)
const override {
2034 Location loc = extractStridedMetadataOp.getLoc();
2035 Value source = extractStridedMetadataOp.getSource();
2037 auto sourceMemRefType = cast<MemRefType>(source.
getType());
2038 int64_t rank = sourceMemRefType.getRank();
2040 results.reserve(2 + rank * 2);
2046 rewriter, loc, *getTypeConverter(),
2047 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
2048 baseBuffer, alignedBuffer);
2049 results.push_back((
Value)dstMemRef);
2052 results.push_back(sourceMemRef.
offset(rewriter, loc));
2055 for (
unsigned i = 0; i < rank; ++i)
2056 results.push_back(sourceMemRef.
size(rewriter, loc, i));
2058 for (
unsigned i = 0; i < rank; ++i)
2059 results.push_back(sourceMemRef.
stride(rewriter, loc, i));
2061 rewriter.replaceOp(extractStridedMetadataOp, results);
2074 AllocaScopeOpLowering,
2075 AssumeAlignmentOpLowering,
2076 AtomicRMWOpLowering,
2077 ConvertExtractAlignedPointerAsIndex,
2079 DistinctObjectsOpLowering,
2080 ExtractStridedMetadataOpLowering,
2081 GenericAtomicRMWOpLowering,
2082 GetGlobalMemrefOpLowering,
2084 MemRefCastOpLowering,
2085 MemRefReinterpretCastOpLowering,
2086 MemRefReshapeOpLowering,
2087 MemorySpaceCastOpLowering,
2090 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2091 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2095 ViewOpLowering>(converter);
2097 patterns.
add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2101 patterns.
add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2104 patterns.
add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2108struct FinalizeMemRefToLLVMConversionPass
2110 FinalizeMemRefToLLVMConversionPass> {
2111 using FinalizeMemRefToLLVMConversionPassBase::
2112 FinalizeMemRefToLLVMConversionPassBase;
2114 void runOnOperation()
override {
2116 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2118 dataLayoutAnalysis.getAtOrAbove(op));
2123 options.useGenericFunctions = useGenericFunctions;
2126 options.overrideIndexBitwidth(indexBitwidth);
2129 &dataLayoutAnalysis);
2135 target.addLegalOp<func::FuncOp>();
2136 if (failed(applyPartialConversion(op,
target, std::move(patterns))))
2137 signalPassFailure();
2142struct MemRefToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
2143 MemRefToLLVMDialectInterface(Dialect *dialect)
2144 : ConvertToLLVMPatternInterface(dialect) {}
2146 void loadDependentDialects(MLIRContext *context)
const final {
2147 context->loadDialect<LLVM::LLVMDialect>();
2152 void populateConvertToLLVMConversionPatterns(
2153 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
2154 RewritePatternSet &patterns)
const final {
2163 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLoweringOptions.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at the end.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Stores data layout objects for each operation that specifies the data layout above and below the given operation.
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
const LowerToLLVMOptions & getOptions() const
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Options to control the LLVM lowering.
AllocLowering allocLowering
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref.alignedPtr + memref.offset * sizeof(type.getElementType()).
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and stride information extracted from the type.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as results.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that descriptor.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct into account.
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent values.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true if the memref type has static shapes and represents a contiguous chunk of memory.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.