#include "llvm/ADT/SmallBitVector.h"

#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}
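
// AllocOpLowering (malloc path): allocate the buffer with malloc (or the
// generic allocation function) and align the returned pointer manually when
// the op requests an alignment.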
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
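
// AlignedAllocOpLowering: allocate with an aligned allocation function,
// deriving the alignment from the element type and any alignment attribute.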
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    return std::make_tuple(ptr, ptr);
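
// AllocaOpLowering: lowers memref.alloca to llvm.alloca on the stack,
// forwarding the optional alignment attribute.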
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        getTypeConverter()->getPointerType(elementType, addrSpace);

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, elementType, sizeBytes,
        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
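
// ReallocOpLoweringBase: common lowering for memref.realloc. When the
// requested number of elements exceeds the current one, a new buffer is
// allocated, the old contents are copied over with llvm.memcpy, and the old
// buffer is freed; the descriptor is then updated with the new pointers and
// size.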
  using OpAdaptor = typename memref::ReallocOp::Adaptor;

  virtual std::tuple<Value, Value>
  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
                 Value sizeBytes, memref::ReallocOp op) const = 0;

    return matchAndRewrite(cast<memref::ReallocOp>(op),

  LogicalResult matchAndRewrite(memref::ReallocOp op, OpAdaptor adaptor,

    auto computeNumElements =
      int64_t size = type.getShape()[0];
      Value numElements = ((size == ShapedType::kDynamic)
                               : createIndexConstant(rewriter, loc, size));
      if (numElements.getType() != indexType)
            rewriter, loc, indexType, numElements);

    Value oldDesc = desc;

    Value src = op.getSource();
    Value srcNumElements = computeNumElements(
        srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); });
    Value dstNumElements = computeNumElements(
        dstType, [&]() -> Value { return op.getDynamicResultSize(); });
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements);

    Value sizeInBytes = getSizeInBytes(loc, dstType.getElementType(), rewriter);
    Value dstByteSize =
        rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
    Value srcByteSize =
        rewriter.create<LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
    auto [dstRawPtr, dstAlignedPtr] =
        allocateBuffer(rewriter, loc, dstByteSize, op);
    Value srcAlignedPtr = desc.alignedPtr(rewriter, loc);
    auto toVoidPtr = [&](Value ptr) -> Value {
      if (getTypeConverter()->useOpaquePointers())
        return ptr;
      return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
    };
    rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
                                    toVoidPtr(srcAlignedPtr), srcByteSize,
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    rewriter.create<LLVM::CallOp>(loc, freeFunc,
                                  toVoidPtr(desc.allocatedPtr(rewriter, loc)));
    desc.setAllocatedPtr(rewriter, loc, dstRawPtr);
    desc.setAlignedPtr(rewriter, loc, dstAlignedPtr);
    rewriter.create<LLVM::BrOp>(loc, Value(desc), endBlock);

    newDesc.setSize(rewriter, loc, 0, dstNumElements);
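
// Concrete realloc lowerings: ReallocOpLowering allocates with manual
// alignment; AlignedReallocOpLowering uses aligned allocation.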
struct ReallocOpLowering : public ReallocOpLoweringBase {
      : ReallocOpLoweringBase(converter) {}

                 memref::ReallocOp op) const override {
    return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
                                       getAlignment(rewriter, loc, op));

struct AlignedReallocOpLowering : public ReallocOpLoweringBase {
      : ReallocOpLoweringBase(converter) {}

                 memref::ReallocOp op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, op, &defaultLayout));
    return std::make_tuple(ptr, ptr);
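
// AllocaScopeOpLowering: lowers memref.alloca_scope to an llvm.intr.stacksave /
// llvm.intr.stackrestore pair around the inlined body region.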
struct AllocaScopeOpLowering

  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,

    Location loc = allocaScopeOp.getLoc();

    auto *remainingOpsBlock =
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
                      allocaScopeOp.getLoc()));

    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();

    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
        returnOp, returnOp.getResults(), continueBlock);

    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
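
// AssumeAlignmentOpLowering: lowers memref.assume_alignment to an
// llvm.intr.assume asserting that the aligned pointer satisfies
// ptr & (alignment - 1) == 0.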
struct AssumeAlignmentOpLowering

  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,

    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();

    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    Type intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());

    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
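
// DeallocOpLowering: lowers memref.dealloc to a call to the configured free
// function, passing the allocated (not aligned) pointer; unranked memrefs
// first extract it from the underlying descriptor.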
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,

    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      Type elementType = unrankedTy.getElementType();
      Type llvmElementTy = getTypeConverter()->convertType(elementType);
      LLVM::LLVMPointerType elementPtrTy = getTypeConverter()->getPointerType(
          llvmElementTy, unrankedTy.getMemorySpaceAsInt());
          rewriter, op.getLoc(),

    if (!getTypeConverter()->useOpaquePointers())
      allocatedPtr = rewriter.create<LLVM::BitcastOp>(
          op.getLoc(), getVoidPtrType(), allocatedPtr);
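
// DimOpLowering: lowers memref.dim. Ranked memrefs read the size from the
// descriptor (or fold to a constant for static dimensions); unranked memrefs
// load it from the size array of the underlying descriptor.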
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,

    Type operandType = dimOp.getSource().getType();
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
      rewriter.replaceOp(dimOp, {*extractedSize});

    if (operandType.isa<MemRefType>()) {
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");

  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,

    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
    unsigned addressSpace = *maybeAddressSpace;

    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Value scalarMemRefDescPtr;
    if (getTypeConverter()->useOpaquePointers())
      scalarMemRefDescPtr = underlyingRankedDesc;
    else
      scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(elementType, addressSpace),
          underlyingRankedDesc);

    Type indexPtrTy = getTypeConverter()->getPointerType(
        loc, indexPtrTy, elementType, scalarMemRefDescPtr,
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,

    MemRefType memRefType = operandType.cast<MemRefType>();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          return descriptor.size(rewriter, loc, i);

        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexConstant(rewriter, loc, dimSize);

    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    return memrefDescriptor.size(rewriter, loc, index, rank);
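
// Common base for the load/store-style lowerings below; the match succeeds
// only for memref types that convert to LLVM and have an identity layout.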
template <typename Derived>

  using Base = LoadStoreOpLowering<Derived>;

    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
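
// GenericAtomicRMWOpLowering: lowers memref.generic_atomic_rmw into a
// compare-exchange loop: the body is applied to the current value and the
// result is written back with llvm.cmpxchg until the exchange succeeds.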
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {

  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,

    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    auto loopArgument = loopBlock->getArgument(0);
    mapping.map(atomicOp.getCurrentValue(), loopArgument);

    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
                                    loopBlock, newLoaded);

    rewriter.replaceOp(atomicOp, {newLoaded});
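
// Converts a global memref type to an LLVM array type by wrapping the element
// type in one LLVM array level per dimension, innermost dimension first.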
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,

  Type arrayTy = elementType;
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
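
// GlobalMemrefOpLowering: lowers memref.global to llvm.mlir.global, carrying
// over linkage, constness, alignment, address space, and the initial value;
// uninitialized globals get an undef initializer region.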
struct GlobalMemrefOpLowering

  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,

    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      newGlobal.getInitializerRegion().push_back(blk);
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
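
// GetGlobalMemrefOpLowering: lowers memref.get_global by taking the address of
// the llvm.mlir.global and GEPing to its first element; the allocated pointer
// is a 0xdeadbeef placeholder because globals are never freed.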
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();

    unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type);

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, resTy, getGlobalOp.getName());

    Type elementPtrType =
        getTypeConverter()->getPointerType(elementType, memSpace);

    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, elementPtrType, arrayTy, addressOf,

    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    return std::make_tuple(deadBeefPtr, gep);
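
// LoadOpLowering: lowers memref.load to an llvm.load from the strided element
// pointer computed from the descriptor and the indices.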
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {

  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
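
// StoreOpLowering: lowers memref.store to an llvm.store at the strided element
// pointer, preserving the nontemporal hint.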
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {

  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,

    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
                                               0, false, op.getNontemporal());
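
// PrefetchOpLowering: lowers memref.prefetch to llvm.intr.prefetch with the
// read/write, locality, and data-cache flags materialized as i32 constants.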
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {

  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,

    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    auto isWrite = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type,
                                                     prefetchOp.getIsWrite());
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, prefetchOp.getLocalityHint());
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, prefetchOp.getIsDataCache());
                                                localityHint, isData);
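
// RankOpLowering: lowers memref.rank. The rank of an unranked memref is read
// from its descriptor; for ranked memrefs it folds to a constant.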
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,

    Type operandType = op.getMemref().getType();
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});

    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
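
// MemRefCastOpLowering: lowers memref.cast. Ranked-to-ranked casts reuse the
// source descriptor; ranked-to-unranked casts pack the ranked descriptor
// behind a pointer together with the rank; unranked-to-ranked casts load the
// ranked descriptor back from that pointer.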
  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();

      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      if (getTypeConverter()->useOpaquePointers())
        voidPtr = ptr;
      else
        voidPtr = rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);

      auto rankVal = rewriter.create<LLVM::ConstantOp>(

      memRefDesc.setRank(rewriter, loc, rankVal);

      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);

      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      if (getTypeConverter()->useOpaquePointers())
        castPtr = ptr;
      else
        castPtr = rewriter.create<LLVM::BitcastOp>(
            loc, LLVM::LLVMPointerType::get(targetStructType), ptr);

      auto loadOp =
          rewriter.create<LLVM::LoadOp>(loc, targetStructType, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,

    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    Value numElements = rewriter.create<LLVM::ConstantOp>(
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);

    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);

    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
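
// Fallback path: wrap ranked operands into unranked descriptors on the stack
// and call the memrefCopy runtime function.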
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,

    auto loc = op.getLoc();

      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      if (getTypeConverter()->useOpaquePointers())
        voidPtr = ptr;
      else
        voidPtr = rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);

          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());

    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.getSource(), srcType)
                               : adaptor.getSource();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.getTarget(), targetType)
                               : adaptor.getTarget();

      Type ptrType = getTypeConverter()->getPointerType(desc.getType());
      Value allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,

    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
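
// memref.copy dispatch: use the memcpy intrinsic only when both operands are
// contiguous memrefs (identity layout, or static shape with row-major strides).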
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,

    auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
      if (!type.hasStaticShape())

      int64_t runningStride = 1;
      for (unsigned i = strides.size(); i > 0; --i) {
        if (strides[i - 1] != runningStride)
        runningStride *= type.getDimSize(i - 1);

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
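
// MemorySpaceCastOpLowering: lowers memref.memory_space_cast. Ranked memrefs
// have their allocated and aligned pointers addrspacecast in place; unranked
// memrefs allocate a new underlying descriptor, cast the pointers, and memcpy
// the offset/size/stride block across.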
struct MemorySpaceCastOpLowering

  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
      auto resultDescType =
      Type newPtrType = resultDescType.getBody()[0];

      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
                                                   resultTypeR, descVals);

      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
                                              result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      Type llvmElementType =
          typeConverter->convertType(sourceType.getElementType());
      LLVM::LLVMPointerType sourceElemPtrType =
          getTypeConverter()->getPointerType(llvmElementType, sourceAddrSpace);
      auto resultElemPtrType =
          getTypeConverter()->getPointerType(llvmElementType, resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      Type llvmBool = typeConverter->convertType(rewriter.getI1Type());
      Value nonVolatile = rewriter.create<LLVM::ConstantOp>(
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, nonVolatile);
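
// Extracts the allocated pointer, aligned pointer, and (optionally) offset
// from either a ranked memref descriptor or the underlying descriptor of an
// unranked memref.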
static void extractPointersAndOffset(Location loc,

                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {

  if (operandType.isa<MemRefType>()) {
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);

  LLVM::LLVMPointerType elementPtrType =

  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
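
// MemRefReinterpretCastOpLowering: lowers memref.reinterpret_cast by building
// a fresh descriptor from the source pointers plus the static/dynamic offset,
// sizes, and strides of the cast.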
struct MemRefReinterpretCastOpLowering

  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,

    Type srcType = castOp.getSource().getType();

    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
    rewriter.replaceOp(castOp, {descriptor});

                                   memref::ReinterpretCastOp castOp,
                                   memref::ReinterpretCastOp::Adaptor adaptor,
                                   Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
    if (!llvmTargetDescriptorTy)

    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
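
// MemRefReshapeOpLowering: lowers memref.reshape. With a statically shaped
// shape operand the result descriptor is filled directly, computing strides
// from the innermost dimension outwards; otherwise a loop over the shape
// operand fills the size and stride arrays of an unranked result descriptor.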
struct MemRefReshapeOpLowering

  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,

    Type srcType = reshapeOp.getSource().getType();

    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
    rewriter.replaceOp(reshapeOp, {descriptor});

                                               Type srcType,
                                               memref::ReshapeOp reshapeOp,
                                               memref::ReshapeOp::Adaptor adaptor,
                                               Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
      if (!llvmTargetDescriptorTy)

      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          stride = createIndexConstant(rewriter, loc, 1);

        int64_t size = targetMemRefType.getDimSize(i);
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          if (dimSize.getType() != indexType)
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);

    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
                                            targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    LLVM::LLVMPointerType elementPtrType =
        getTypeConverter()->getPointerType(llvmElementType, addressSpace);
                                              elementPtrType, allocatedPtr);
                                            underlyingDescPtr, elementPtrType,
                                         underlyingDescPtr, elementPtrType,

    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Type indexType = getTypeConverter()->getIndexType();
                                       {indexType, indexType}, {loc, loc});
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);

    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Type llvmIndexPtrType = getTypeConverter()->getPointerType(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);

    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    *descriptor = targetDesc;
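
// Expand/collapse_shape and subview are expected to be rewritten into simpler
// memref ops before this pass; these patterns only report a match failure
// asking for prior expansion.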
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion

  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,

        "reassociation operations should have been expanded beforehand");

  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,

        subViewOp, "subview operations should have been expanded beforehand");
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,

    auto loc = transposeOp.getLoc();

    if (transposeOp.getPermutation().isIdentity())

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    for (const auto &en :
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));

    rewriter.replaceOp(transposeOp, {targetMemRef});
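
// ViewOpLowering: lowers memref.view by shifting the aligned pointer by the
// byte offset and rebuilding sizes and strides for the contiguous,
// zero-offset result type.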
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);

    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];

                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);

    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);

  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,

    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),

    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
    unsigned sourceMemorySpace =
        *getTypeConverter()->getMemRefAddressSpace(srcMemRefType);

    if (getTypeConverter()->useOpaquePointers())
      bitcastPtr = allocatedPtr;
    else
      bitcastPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),

    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    if (getTypeConverter()->useOpaquePointers()) {
      bitcastPtr = alignedPtr;
    } else {
      bitcastPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),

    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    if (viewMemRefType.getRank() == 0)

    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {

      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);

      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);

    rewriter.replaceOp(viewOp, {targetMemRef});
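
// Maps arith.atomic_rmw kinds that have a direct LLVM atomicrmw equivalent;
// kinds without one are handled via memref.generic_atomic_rmw and return
// std::nullopt here.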
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {

  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,

    if (failed(match(atomicOp)))
    auto maybeKind = matchSimpleAtomicOp(atomicOp);

    auto memRefType = atomicOp.getMemRefType();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
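
// ConvertExtractAlignedPointerAsIndex: lowers
// memref.extract_aligned_pointer_as_index by converting the descriptor's
// aligned pointer to an index-typed integer.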
class ConvertExtractAlignedPointerAsIndex

  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,

        extractOp, getTypeConverter()->getIndexType(),
        desc.alignedPtr(rewriter, extractOp->getLoc()));
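
// ExtractStridedMetadataOpLowering: lowers memref.extract_strided_metadata by
// unpacking the source descriptor into a rank-0 base-buffer descriptor plus
// the offset, sizes, and strides as individual results.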
class ExtractStridedMetadataOpLowering

  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,

    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = source.getType().cast<MemRefType>();
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    results.push_back(sourceMemRef.offset(rewriter, loc));

    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));

    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
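
// populateFinalizeMemRefToLLVMConversionPatterns: registers the patterns
// above; the alloc/realloc/dealloc patterns are selected by the configured
// AllocLowering (aligned_alloc vs. malloc).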
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      ViewOpLowering>(converter);

  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering,
                 DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
        converter);
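
// FinalizeMemRefToLLVMConversionPass: builds the lowering options from the
// pass options and the data layout analysis, populates the patterns above,
// and applies the conversion, signalling pass failure on error.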
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {

    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;
    options.useOpaquePointers = useOpaquePointers;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

                                   &dataLayoutAnalysis);

    target.addLegalOp<func::FuncOp>();
      signalPassFailure();