#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"

#define DEBUG_TYPE "memref-to-llvm"

#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
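// Every GEP emitted by these patterns indexes within a single live memref
// allocation, so attaching `inbounds | nuw` via kNoWrapFlags is safe; it
// licenses more aggressive address-arithmetic folding in LLVM.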
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return ShapedType::isStatic(strideOrOffset);
}
static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType,
                     SymbolTableCollection *symbolTables) {
  // ... (returns the malloc or generic-alloc runtime function)
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType,
                  SymbolTableCollection *symbolTables) {
  // ... (returns the aligned_alloc or generic aligned-alloc runtime function)
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
          SymbolTableCollection *symbolTables) {
  // ... (returns the free or generic-free runtime function)
}
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
                                       rewriter.getIndexAttr(1));
  Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
  Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
  Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
  return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}
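// createAligned rounds `input` up to the nearest multiple of `alignment`:
// bumped = input + alignment - 1, result = bumped - bumped % alignment.
// For example, input = 13 and alignment = 8 give bumped = 20, mod = 4, and
// an aligned result of 16.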
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  // ... (select the DataLayout for `op`, falling back to `defaultLayout`)
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          layout);
  return layout.getTypeSize(elementType);
}
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}
struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  explicit AllocOpLowering(const LLVMTypeConverter &converter,
                           SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the malloc-like allocation function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType(),
        symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Compute sizes, strides, and the allocation size in bytes.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    // If an alignment is requested, over-allocate so the aligned pointer
    // still leaves room for the full buffer.
    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment)
      sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer inside the over-allocated buffer.
      Value allocatedInt =
          LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor and replace the op with it.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
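  // Illustrative result (not verbatim output): `%m = memref.alloc() :
  // memref<4xf32>` becomes roughly
  //   %sz  = llvm.mlir.constant(16 : index)   // 4 elements x 4 bytes
  //   %raw = llvm.call @malloc(%sz)
  // followed by construction of the descriptor struct holding the allocated
  // and aligned pointers, offset 0, size 4, and stride 1.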
  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // Aggregate elements are only guaranteed element-size alignment by the
      // allocator, so request the natural alignment of the element type.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }

  SymbolTableCollection *symbolTables = nullptr;
};
struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &converter,
                                  SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the aligned_alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType(),
        symbolTables);
    if (failed(allocFuncOp))
      return failure();

    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // aligned_alloc requires the size to be a multiple of the alignment;
    // round up at runtime when this cannot be shown statically.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
                             ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // aligned_alloc already returns an aligned pointer, so it serves as both
    // the allocated and the aligned pointer in the descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
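  // Illustrative result (not verbatim output): `memref.alloc() {alignment =
  // 32} : memref<5xf32>` needs 20 bytes, which is not a multiple of 32, so
  // the emitted call is roughly `llvm.call @aligned_alloc(%c32, %c32)`.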
  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc: the requested alignment if
  /// present, otherwise one derived from the element size.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Use the element size rounded up to a power of two, but no less than
    // kMinAlignedAllocAlignment.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }
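  // For example, an f64 element (8 bytes) with no explicit alignment yields
  // max(16, PowerOf2Ceil(8)) = 16, while a 64-byte element yields 64.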
  /// Returns true if the statically-known size in bytes of the memref is a
  /// multiple of factor (dynamic dimensions are ignored).
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Default layout to use in the absence of a data layout analysis.
  DataLayout defaultLayout;
  SymbolTableCollection *symbolTables = nullptr;
};
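// Example for isMemRefSizeMultipleOf: memref<4x?xf32> against factor 16 has a
// static contribution of 4 elements x 4 bytes = 16 (the dynamic dimension is
// skipped), and 16 % 16 == 0, so no runtime size rounding is needed.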
struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Compute the total number of elements; llvm.alloca takes an element
    // count, not a size in bytes.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    Type elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
                               op.getAlignment().value_or(0));

    // The stack pointer is naturally both the allocated and aligned pointer.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
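// Illustrative result (not verbatim output): `memref.alloca() : memref<8xi32>`
// becomes roughly
//   %c8  = llvm.mlir.constant(8 : index)
//   %ptr = llvm.alloca %c8 x i32 : (i64) -> !llvm.ptr
// plus the usual descriptor construction; note the element count, not bytes.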
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the op to create the continuation point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region between the split blocks.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack, then branch into the body.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope.return with a branch to the continuation.
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    rewriter.setInsertionPointToEnd(afterBody);
    rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Restore the stack on exit from the scope.
    rewriter.setInsertionPointToStart(continueBlock);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
    return success();
  }
};
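// The net effect: the scope body executes bracketed by llvm.intr.stacksave /
// llvm.intr.stackrestore, so any llvm.alloca issued inside the scope is
// reclaimed once control reaches the continuation block.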
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume with an alignment operand bundle on the pointer.
    Value trueCond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                           alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};
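// The assume carries an "align" operand bundle on the pointer; LLVM can then,
// e.g., vectorize or widen subsequent loads and stores through this memref
// without having to re-derive alignment facts.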
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  explicit DeallocOpLowering(const LLVMTypeConverter &converter,
                             SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
                  symbolTables);
    if (failed(freeFunc))
      return failure();

    // Extract the allocated pointer, unwrapping the unranked descriptor if
    // necessary, and pass it to free.
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }

  SymbolTableCollection *symbolTables = nullptr;
};
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }
private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract a pointer to the underlying ranked descriptor.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get a pointer to the offset field of the descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
                            underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});

    // The requested size lives `index + 1` index-typed slots past the offset.
    Value idxPlusOne = LLVM::AddOp::create(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                        getTypeConverter()->getIndexType(),
                                        offsetPtr, idxPlusOne);
    return LLVM::LoadOp::create(rewriter, loc,
                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }
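  // A ranked descriptor is laid out as
  //   { allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank] },
  // so field 2 is the offset and the size of dimension `idx` sits idx + 1
  // index-typed slots after it, which is exactly what the GEPs above compute.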
  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }
  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
/// Common base for load and store operations on MemRefs; provides the `Base`
/// alias used by the derived patterns below.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using Base = LoadStoreOpLowering<Derived>;
};
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the current block into init, loop, and end blocks; the loop block
    // carries the current value as a block argument.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Load the initial value and branch to the loop.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Clone the body of the op into the loop block, mapping the current value
    // to the loop argument.
    rewriter.setInsertionPointToStart(loopBlock);
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp->getRegion(0).front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the new loaded value and the success flag from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ValueRange(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    rewriter.replaceOp(atomicOp, {newLoaded});
    return success();
  }
};
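// This expands memref.generic_atomic_rmw into a classic compare-and-swap
// loop, sketched as:
//   init:        %0 = load %ptr; br ^loop(%0)
//   ^loop(%cur): %new = <cloned body>; %pair = cmpxchg %ptr, %cur, %new
//                cond_br %pair#1, ^end, ^loop(%pair#0)
// i.e. the body re-executes until no other thread raced the update.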
/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
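// For example, memref<4x8xf32> becomes !llvm.array<4 x array<8 x f32>>: the
// reversed walk first wraps f32 into array<8 x f32>, then wraps that into the
// outer array<4 x ...>, restoring the original dimension order.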
struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &converter,
                                  SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Keep the symbol table (if any) up to date before replacing the op.
    if (symbolTables) {
      SymbolTable *symbolTable = &symbolTables->getSymbolTable(
          global->getParentWithTrait<OpTrait::SymbolTable>());
      symbolTable->remove(global);
    }

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);

    // For uninitialized globals, emit an initializer returning undef.
    if (!isExternal && isUninitialized) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
      LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
    }
    return success();
  }

  SymbolTableCollection *symbolTables = nullptr;
};
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for memref.get_global is simply taking the address
  /// of the referenced global variable.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if the
    // address space was not convertible.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());

    // Get the address of the first element by GEP'ing with rank + 1 zero
    // indices.
    auto gep =
        LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
                            SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // A memref obtained from get_global is never deallocated; stash a known
    // bad value as the allocated pointer to ease debugging.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, loc, intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);

    // Both the allocated and aligned pointers refer to the global's storage.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
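// Illustrative result (not verbatim output): `memref.get_global @g :
// memref<2x3xf32>` yields a descriptor whose aligned pointer is
// `gep @g[0, 0, 0]` and whose allocated pointer is the 0xdeadbeef sentinel,
// since globals are never freed.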
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
        loadOp.getAlignment().value_or(0),
        /*isVolatile=*/false, loadOp.getNontemporal());
    return success();
  }
};
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(rewriter, op.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
        op, adaptor.getValue(), dataPtr, op.getAlignment().value_or(0),
        /*isVolatile=*/false, op.getNontemporal());
    return success();
  }
};
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};
struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      // The rank is stored in the unranked descriptor itself.
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // Ranked-to-ranked casts must convert to the same descriptor type.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();
    // Unranked-to-unranked casts are disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked-to-ranked case, keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked: spill the ranked descriptor to the stack
      // and wrap {rank, pointer} into an unranked descriptor.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                              rewriter.getIndexAttr(rank));
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      memRefDesc.setRank(rewriter, loc, rankVal);
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting unranked to ranked: load the ranked descriptor back through
      // the pointer stored in the unranked descriptor.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
    return success();
  }
};
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &converter,
                                SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements as the product of the sizes.
    Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
    }

    // Total size = element count x element size in bytes.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    Value totalSize =
        LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
                                       elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr =
        LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
                            targetBasePtr, targetOffset);
    LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
                           /*isVolatile=*/false);
    rewriter.eraseOp(op);
    return success();
  }
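  // For a contiguous memref<4x5xf32> this path emits a single
  // llvm.intr.memcpy of 4 * 5 * 4 = 80 bytes between the two aligned
  // pointers (each shifted by its descriptor offset).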
  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting descriptors.
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
      LLVM::StoreOp::create(rewriter, loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    LLVM::CallOp::create(rewriter, loc, copyFn.value(),
                         ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.eraseOp(op);
    return success();
  }
  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // memcpy is usable if the memref has an identity layout or is
      // contiguous row-major with a static shape.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }

  SymbolTableCollection *symbolTables = nullptr;
};
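// Dispatch rationale: a flat memcpy is only correct when both sides are a
// dense row-major block of memory; any other layout falls back to the runtime
// memrefCopy function, which interprets the full unranked descriptors element
// by element.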
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
      descVals[1] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // The type converter won't resolve this for us; get the address spaces.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new underlying descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
      Value resultUnderlyingDesc =
          LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                 rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, allocatedPtr);
      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
                                                 resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued fields (offset, sizes, strides) with a
      // single memcpy, skipping the two leading pointer fields.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize =
          LLVM::SubOp::create(rewriter, loc, getIndexType(),
                              resultUnderlyingSize, bytesToSkipConst);
      LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
                             copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};
/// Extracts the allocated pointer, aligned pointer, and (optionally) offset
/// from a ranked or unranked memref. In the unranked case, the fields are
/// read from the underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract the pointer to the underlying ranked memref descriptor.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc =
        MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set the allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set the offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set the sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
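// For example, `memref.reinterpret_cast %src to offset: [0], sizes: [%n, 10],
// strides: [10, 1]` builds a fresh descriptor that reuses the pointers of
// %src, reads %n from the adaptor for size 0, and encodes 0, 10, and 1 as
// constants.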
struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReshapeOp reshapeOp, memref::ReshapeOp::Adaptor adaptor,
      Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create the descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set the allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the target type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (ShapedType::isStatic(strides[i])) {
          // Static stride: materialize it as a constant.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first (innermost) iteration.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          // Load the dynamic size from the shape operand at runtime.
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // The next (outer) stride is the running product of the sizes.
        stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }
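    // Note on the loop above: walking dimensions innermost to outermost
    // yields row-major strides as running products of the sizes; e.g. target
    // sizes [2, 3, 4] produce strides [12, 4, 1].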
    // The shape is a rank-1 memref whose single dimension gives the result
    // rank, which is only known at runtime.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // The target is an unranked memref; allocate stack space for the
    // underlying ranked descriptor it points to.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    auto targetDesc = UnrankedMemRefDescriptor::poison(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    Value descriptorSize = UnrankedMemRefDescriptor::computeSize(
        rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
    Value underlyingDescPtr = LLVM::AllocaOp::create(
        rewriter, loc, getPtrType(), rewriter.getI8Type(), descriptorSize);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set the pointers and offset in the underlying descriptor.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Copy the new shape and compute strides with a loop that runs from
    // rank - 1 down to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining ops of initBlock into condBlock and branch there.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    LLVM::BrOp::create(rewriter, loc,
                       ValueRange({resultRankMinusOne, oneIndex}), condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = LLVM::ICmpOp::create(
        rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy the size from the shape operand into the descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = LLVM::GEPOp::create(
        rewriter, loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write the current stride and compute the next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);

    // Decrement the loop counter and branch back to the condition.
    Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
    LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
                       condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the condition's exit edge to the remainder block.
    rewriter.setInsertionPointToEnd(condBlock);
    LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
                           remainder, ValueRange());

    rewriter.setInsertionPointToStart(remainder);
    *descriptor = targetDesc;
    return success();
  }
};
/// Reassociating reshapes must be expanded before this pass.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before this pass.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};
struct TransposeOpLowering
    : public ConvertOpToLLVMPattern<memref::TransposeOp> {
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // Identity permutation: nothing to do.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::poison(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply the size/stride permutation: the
    // enumeration index is the target position, the dim expression names the
    // source position.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either as a
  // constant or by picking the matching dynamic size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (ShapedType::isStatic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }
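  // E.g. for shape [4, ?, 8, ?] and idx = 3, one dynamic dimension precedes
  // idx, so the call returns dynamicSizes[1]; dynamic size operands are
  // ordered by their appearance in the shape.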
  // Build and return the idx^th stride, either as a constant or as the
  // running product of the already-processed inner sizes.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (ShapedType::isStatic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // The target memref must be contiguous in memory (innermost stride 1),
    // or empty (at least one dimension of extent 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

    // Field 1: copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: copy the aligned pointer, shifted by byte_shift.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = LLVM::GEPOp::create(
        rewriter, loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());
    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: the offset in the resulting type must be 0.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for the 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: update the sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update the size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update the stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
/// Try to match the kind of a memref.atomic_rmw to determine whether to use
/// llvm.atomicrmw directly or fall back to the cmpxchg-based lowering.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    LDBG() << "the lowering of memref.atomicrmw maximumf changed "
              "from fmax to fmaximum, expect more NaNs";
    return LLVM::AtomicBinOp::fmaximum;
  case arith::AtomicRMWKind::maxnumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    LDBG() << "the lowering of memref.atomicrmw minimum changed "
              "from fmin to fminimum, expect more NaNs";
    return LLVM::AtomicBinOp::fminimum;
  case arith::AtomicRMWKind::minnumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::xori:
    return LLVM::AtomicBinOp::_xor;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(memRefType.getStridesAndOffset(strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      // Unranked: read the aligned pointer out of the underlying descriptor.
      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          LLVM::LLVMPointerType::get(rewriter.getContext(),
                                     sourceTy.getMemorySpaceAsInt()));
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    SymbolTableCollection *symbolTables) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
                                                             symbolTables);
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
                                                            symbolTables);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
}
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for the derived dialect interface to provide conversion patterns
  /// and mark the dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}