10 #include "../PassDetail.h" 23 #include "llvm/ADT/SmallBitVector.h" 29 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
// Returns true when the given stride/offset is a compile-time constant, i.e.
// it is not ShapedType's dynamic stride-or-offset sentinel value.
// NOTE(review): this chunk is a garbled extraction (embedded original line
// numbers, missing lines) — the helper's closing brace is not visible here.
30 return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
38 LLVM::LLVMFuncOp getAllocFn(ModuleOp module)
const {
39 bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
51 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
52 MemRefType memRefType = allocOp.getType();
55 if (
auto alignmentAttr = allocOp.getAlignment()) {
56 alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
57 }
else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
62 alignment =
getSizeInBytes(loc, memRefType.getElementType(), rewriter);
67 sizeBytes = rewriter.
create<LLVM::AddOp>(loc, sizeBytes, alignment);
73 auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
74 auto results = rewriter.
create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
75 Value allocatedPtr = rewriter.
create<LLVM::BitcastOp>(loc, elementPtrType,
78 Value alignedPtr = allocatedPtr;
82 rewriter.
create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
84 createAligned(rewriter, loc, allocatedInt, alignment);
86 rewriter.
create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
89 return std::make_tuple(allocatedPtr, alignedPtr);
104 getTypeConverter()->getDataLayoutAnalysis()) {
105 layout = &analysis->getAbove(op);
107 Type elementType = memRefType.getElementType();
108 if (
auto memRefElementType = elementType.
dyn_cast<MemRefType>())
109 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
112 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
113 memRefElementType, *layout);
// Returns whether the memref's total size is statically known to be a
// multiple of `factor`: multiplies `sizeDivisor` by every *static* dimension
// size (dynamic dims are skipped via the isDynamic check) and tests
// divisibility. NOTE(review): the initialization of `sizeDivisor` —
// presumably the element size in bytes — falls on lines missing from this
// extraction; confirm against the full file.
119 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
122 for (
unsigned i = 0, e = type.getRank(); i < e; i++) {
123 if (ShapedType::isDynamic(type.getDimSize(i)))
125 sizeDivisor = sizeDivisor * type.getDimSize(i);
127 return sizeDivisor % factor == 0;
// Computes the alignment to request for an aligned allocation: at least
// kMinAlignedAllocAlignment, and at least the next power of two of
// `eltSizeBytes` (llvm::PowerOf2Ceil rounds up to a power of two).
// NOTE(review): the computation of `eltSizeBytes` and any early-exit logic
// sit on lines missing from this extraction.
133 int64_t getAllocationAlignment(memref::AllocOp allocOp)
const {
141 return std::max(kMinAlignedAllocAlignment,
142 llvm::PowerOf2Ceil(eltSizeBytes));
145 LLVM::LLVMFuncOp getAllocFn(ModuleOp module)
const {
146 bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
158 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
159 MemRefType memRefType = allocOp.getType();
160 int64_t alignment = getAllocationAlignment(allocOp);
161 Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
165 if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
166 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
169 auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
170 auto results = rewriter.
create<LLVM::CallOp>(
171 loc, allocFuncOp,
ValueRange({allocAlignment, sizeBytes}));
172 Value allocatedPtr = rewriter.
create<LLVM::BitcastOp>(loc, elementPtrType,
173 results.getResult());
175 return std::make_tuple(allocatedPtr, allocatedPtr);
// Minimum alignment (16 bytes) requested from the aligned allocation path —
// presumably chosen to match common malloc alignment guarantees on 64-bit
// platforms; TODO(review) confirm rationale against the full file's comment.
179 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
// Lowers memref.alloca buffer allocation: emits a single llvm.alloca of
// `sizeBytes` elements of `elementPtrType`'s pointee, carrying the op's
// requested alignment (0 when unspecified, per value_or(0)). The allocated
// and aligned pointers are the same value, as the returned tuple shows.
199 auto allocaOp = cast<memref::AllocaOp>(op);
202 auto allocatedElementPtr = rewriter.
create<LLVM::AllocaOp>(
203 loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));
205 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
209 struct AllocaScopeOpLowering
214 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
217 Location loc = allocaScopeOp.getLoc();
222 auto *remainingOpsBlock =
224 Block *continueBlock;
225 if (allocaScopeOp.getNumResults() == 0) {
226 continueBlock = remainingOpsBlock;
229 remainingOpsBlock, allocaScopeOp.getResultTypes(),
231 allocaScopeOp.getLoc()));
236 Block *beforeBody = &allocaScopeOp.getBodyRegion().
front();
237 Block *afterBody = &allocaScopeOp.getBodyRegion().
back();
243 rewriter.
create<LLVM::StackSaveOp>(loc, getVoidPtrType());
250 cast<memref::AllocaScopeReturnOp>(afterBody->
getTerminator());
252 returnOp, returnOp.getResults(), continueBlock);
256 rewriter.
create<LLVM::StackRestoreOp>(loc, stackSaveOp);
265 struct AssumeAlignmentOpLowering
271 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
273 Value memref = adaptor.getMemref();
274 unsigned alignment = op.getAlignment();
278 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.
getLoc());
289 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
293 Value ptrValue = rewriter.
create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
294 rewriter.
create<LLVM::AssumeOp>(
295 loc, rewriter.
create<LLVM::ICmpOp>(
296 loc, LLVM::ICmpPredicate::eq,
297 rewriter.
create<LLVM::AndOp>(loc, ptrValue, mask), zero));
// Lowering of memref.dealloc: looks up the deallocation function (a generic
// variant when useGenericFunctions is set, judging by the flag read below)
// and calls it on the *allocated* (not aligned) pointer, bitcast to void*.
// NOTE(review): the call-emission line itself is missing from this
// extraction — only its operand lines (329-330) are visible.
313 LLVM::LLVMFuncOp getFreeFn(ModuleOp module)
const {
314 bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
323 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
326 auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
329 op.getLoc(), getVoidPtrType(),
330 memref.allocatedPtr(rewriter, op.getLoc()));
342 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
344 Type operandType = dimOp.getSource().getType();
347 dimOp, {extractSizeOfUnrankedMemRef(
348 operandType, dimOp, adaptor.getOperands(), rewriter)});
352 if (operandType.
isa<MemRefType>()) {
354 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
355 adaptor.getOperands(), rewriter)});
358 llvm_unreachable(
"expected MemRefType or UnrankedMemRefType");
362 Value extractSizeOfUnrankedMemRef(
Type operandType, memref::DimOp dimOp,
368 auto scalarMemRefType =
369 MemRefType::get({}, unrankedMemRefType.getElementType());
370 unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
376 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
377 Value scalarMemRefDescPtr = rewriter.
create<LLVM::BitcastOp>(
381 underlyingRankedDesc);
385 getTypeConverter()->getIndexType(), addressSpace);
392 loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
394 rewriter.
create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr, idxPlusOne);
395 return rewriter.
create<LLVM::LoadOp>(loc, sizePtr);
402 if (
auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
403 return constantOp.getValue()
411 Value extractSizeOfRankedMemRef(
Type operandType, memref::DimOp dimOp,
417 MemRefType memRefType = operandType.
cast<MemRefType>();
420 if (memRefType.isDynamicDim(i)) {
423 return descriptor.
size(rewriter, loc, i);
426 int64_t dimSize = memRefType.getDimSize(i);
427 return createIndexConstant(rewriter, loc, dimSize);
429 Value index = adaptor.getIndex();
430 int64_t rank = memRefType.getRank();
432 return memrefDescriptor.
size(rewriter, loc, index, rank);
// Common base for load/store lowerings (CRTP over `Derived`). The match
// predicate succeeds only when the op's memref type both converts to LLVM
// and has an identity layout map, per isConvertibleAndHasIdentityMaps.
439 template <
typename Derived>
443 using Base = LoadStoreOpLowering<Derived>;
446 MemRefType type = op.getMemRefType();
447 return isConvertibleAndHasIdentityMaps(type) ?
success() :
failure();
478 struct GenericAtomicRMWOpLowering
479 :
public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
483 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
485 auto loc = atomicOp.getLoc();
486 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
497 auto opsToMoveStart = atomicOp->getIterator();
498 auto opsToMoveEnd = initBlock->
back().getIterator();
502 auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
503 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
504 adaptor.getIndices(), rewriter);
505 Value init = rewriter.
create<LLVM::LoadOp>(loc, dataPtr);
506 rewriter.
create<LLVM::BrOp>(loc, init, loopBlock);
512 auto loopArgument = loopBlock->getArgument(0);
514 mapping.
map(atomicOp.getCurrentValue(), loopArgument);
524 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
525 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
526 auto boolType = IntegerType::get(rewriter.
getContext(), 1);
528 {valueType, boolType});
529 auto cmpxchg = rewriter.
create<LLVM::AtomicCmpXchgOp>(
530 loc, pairType, dataPtr, loopArgument, result, successOrdering,
533 Value newLoaded = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
534 Value ok = rewriter.
create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
538 loopBlock, newLoaded);
541 moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
542 std::next(opsToMoveEnd), rewriter);
545 rewriter.
replaceOp(atomicOp, {newLoaded});
556 mapping.
map(oldResult, newResult);
558 for (
auto it = start; it != end; ++it) {
559 rewriter.
clone(*it, mapping);
560 opsToErase.push_back(&*it);
562 for (
auto *it : opsToErase)
// Builds the LLVM storage type for a global memref: starting from
// `elementType` (presumably the converted element type — its definition is
// on lines missing from this extraction), wraps it in one LLVM array type
// per dimension, iterating the shape in reverse so the innermost dimension
// is wrapped first.
568 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
576 Type arrayTy = elementType;
578 for (int64_t dim : llvm::reverse(type.getShape()))
584 struct GlobalMemrefOpLowering
589 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
591 MemRefType type = global.getType();
592 if (!isConvertibleAndHasIdentityMaps(type))
595 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
597 LLVM::Linkage linkage =
598 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
601 if (!global.isExternal() && !global.isUninitialized()) {
602 auto elementsAttr = global.getInitialValue()->
cast<ElementsAttr>();
603 initialValue = elementsAttr;
607 if (type.getRank() == 0)
608 initialValue = elementsAttr.getSplatValue<
Attribute>();
611 uint64_t alignment = global.getAlignment().value_or(0);
614 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
615 initialValue, alignment, type.getMemorySpaceAsInt());
616 if (!global.isExternal() && global.isUninitialized()) {
618 newGlobal.getInitializerRegion().push_back(blk);
621 rewriter.
create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
622 rewriter.
create<LLVM::ReturnOp>(global.getLoc(), undef);
641 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
642 MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
643 unsigned memSpace = type.getMemorySpaceAsInt();
645 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
646 auto addressOf = rewriter.
create<LLVM::AddressOfOp>(
648 getGlobalOp.getName());
655 auto gep = rewriter.
create<LLVM::GEPOp>(
656 loc, elementPtrType, addressOf,
662 auto intPtrType = getIntPtrType(memSpace);
663 Value deadBeefConst =
666 rewriter.
create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
670 return std::make_tuple(deadBeefPtr, gep);
676 struct LoadOpLowering :
public LoadStoreOpLowering<memref::LoadOp> {
680 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
682 auto type = loadOp.getMemRefType();
685 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
686 adaptor.getIndices(), rewriter);
694 struct StoreOpLowering :
public LoadStoreOpLowering<memref::StoreOp> {
698 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
700 auto type = op.getMemRefType();
702 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
703 adaptor.getIndices(), rewriter);
711 struct PrefetchOpLowering :
public LoadStoreOpLowering<memref::PrefetchOp> {
715 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
717 auto type = prefetchOp.getMemRefType();
718 auto loc = prefetchOp.getLoc();
720 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
721 adaptor.getIndices(), rewriter);
725 auto isWrite = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type,
726 prefetchOp.getIsWrite());
727 auto localityHint = rewriter.
create<LLVM::ConstantOp>(
728 loc, llvmI32Type, prefetchOp.getLocalityHint());
729 auto isData = rewriter.
create<LLVM::ConstantOp>(
730 loc, llvmI32Type, prefetchOp.getIsDataCache());
733 localityHint, isData);
742 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
745 Type operandType = op.getMemref().getType();
748 rewriter.
replaceOp(op, {desc.rank(rewriter, loc)});
751 if (
auto rankedMemRefType = operandType.
dyn_cast<MemRefType>()) {
753 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
// Match phase of the memref.cast lowering: inspects source and destination
// types; the ranked->ranked case is tested below. NOTE(review): the result
// returned for that branch, and the unranked cases, fall on lines missing
// from this extraction — do not infer them from here.
763 LogicalResult match(memref::CastOp memRefCastOp)
const {
764 Type srcType = memRefCastOp.getOperand().getType();
765 Type dstType = memRefCastOp.getType();
772 if (srcType.
isa<MemRefType>() && dstType.
isa<MemRefType>())
787 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
789 auto srcType = memRefCastOp.getOperand().getType();
790 auto dstType = memRefCastOp.getType();
791 auto targetStructType = typeConverter->
convertType(memRefCastOp.getType());
792 auto loc = memRefCastOp.getLoc();
795 if (srcType.
isa<MemRefType>() && dstType.
isa<MemRefType>())
796 return rewriter.
replaceOp(memRefCastOp, {adaptor.getSource()});
803 auto srcMemRefType = srcType.
cast<MemRefType>();
804 int64_t rank = srcMemRefType.getRank();
806 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
807 loc, adaptor.getSource(), rewriter);
810 rewriter.
create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
813 auto rankVal = rewriter.
create<LLVM::ConstantOp>(
819 memRefDesc.setRank(rewriter, loc, rankVal);
821 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
830 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
838 auto loadOp = rewriter.
create<LLVM::LoadOp>(loc, castPtr);
839 rewriter.
replaceOp(memRefCastOp, loadOp.getResult());
841 llvm_unreachable(
"Unsupported unranked memref to unranked memref cast");
855 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
857 auto loc = op.getLoc();
858 auto srcType = op.getSource().getType().
dyn_cast<MemRefType>();
863 Value numElements = rewriter.
create<LLVM::ConstantOp>(
865 for (
int pos = 0; pos < srcType.getRank(); ++pos) {
866 auto size = srcDesc.size(rewriter, loc, pos);
867 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
871 auto sizeInBytes =
getSizeInBytes(loc, srcType.getElementType(), rewriter);
874 rewriter.
create<LLVM::MulOp>(loc, numElements, sizeInBytes);
876 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
877 Value srcOffset = srcDesc.offset(rewriter, loc);
878 Value srcPtr = rewriter.
create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
879 srcBasePtr, srcOffset);
881 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
882 Value targetOffset = targetDesc.offset(rewriter, loc);
883 Value targetPtr = rewriter.
create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
884 targetBasePtr, targetOffset);
887 rewriter.
create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
895 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
897 auto loc = op.getLoc();
903 auto rank = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
905 auto *typeConverter = getTypeConverter();
907 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
909 rewriter.
create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
912 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
918 Value unrankedSource = srcType.hasRank()
919 ? makeUnranked(adaptor.getSource(), srcType)
920 : adaptor.getSource();
922 ? makeUnranked(adaptor.getTarget(), targetType)
923 : adaptor.getTarget();
926 auto one = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
928 auto promote = [&](
Value desc) {
932 rewriter.
create<LLVM::StoreOp>(loc, desc, allocated);
936 auto sourcePtr = promote(unrankedSource);
937 auto targetPtr = promote(unrankedTarget);
941 auto elemSize = rewriter.
create<LLVM::ConstantOp>(
944 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
945 rewriter.
create<LLVM::CallOp>(loc, copyFn,
953 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
964 (memrefType.getLayout().isIdentity() ||
965 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
969 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
970 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
972 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
// Extracts the allocated pointer, aligned pointer and (optionally) offset
// from a converted memref operand. `offset` defaults to nullptr and is only
// written when the caller asked for it (the `offset != nullptr` guards).
// Ranked case: read the three fields straight from the MemRefDescriptor.
979 static void extractPointersAndOffset(
Location loc,
982 Value originalOperand,
983 Value convertedOperand,
985 Value *offset =
nullptr) {
987 if (operandType.
isa<MemRefType>()) {
989 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
990 *alignedPtr = desc.alignedPtr(rewriter, loc);
991 if (offset !=
nullptr)
992 *offset = desc.offset(rewriter, loc);
// Unranked case: chase the underlying ranked-descriptor pointer and load the
// fields through it. NOTE(review): the UnrankedMemRefDescriptor helper calls
// whose argument lists appear on lines 1009/1011/1014 are themselves on
// lines missing from this extraction — only their operands are visible.
996 unsigned memorySpace =
1006 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1009 rewriter, loc, underlyingDescPtr, elementPtrPtrType);
1011 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1012 if (offset !=
nullptr) {
1014 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
1018 struct MemRefReinterpretCastOpLowering
1021 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1024 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1026 Type srcType = castOp.getSource().getType();
1029 if (
failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1030 adaptor, &descriptor)))
1032 rewriter.
replaceOp(castOp, {descriptor});
1039 memref::ReinterpretCastOp castOp,
1040 memref::ReinterpretCastOp::Adaptor adaptor,
Value *descriptor)
const {
1041 MemRefType targetMemRefType =
1042 castOp.getResult().getType().cast<MemRefType>();
1043 auto llvmTargetDescriptorTy = typeConverter->
convertType(targetMemRefType)
1045 if (!llvmTargetDescriptorTy)
1053 Value allocatedPtr, alignedPtr;
1054 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1055 castOp.getSource(), adaptor.getSource(),
1056 &allocatedPtr, &alignedPtr);
1057 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1058 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1061 if (castOp.isDynamicOffset(0))
1062 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1064 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1067 unsigned dynSizeId = 0;
1068 unsigned dynStrideId = 0;
1069 for (
unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1070 if (castOp.isDynamicSize(i))
1071 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1073 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1075 if (castOp.isDynamicStride(i))
1076 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1078 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1085 struct MemRefReshapeOpLowering
1090 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1092 Type srcType = reshapeOp.getSource().getType();
1095 if (
failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1096 adaptor, &descriptor)))
1098 rewriter.
replaceOp(reshapeOp, {descriptor});
1105 Type srcType, memref::ReshapeOp reshapeOp,
1106 memref::ReshapeOp::Adaptor adaptor,
1107 Value *descriptor)
const {
1108 auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1109 if (shapeMemRefType.hasStaticShape()) {
1110 MemRefType targetMemRefType =
1111 reshapeOp.getResult().getType().cast<MemRefType>();
1112 auto llvmTargetDescriptorTy =
1115 if (!llvmTargetDescriptorTy)
1124 Value allocatedPtr, alignedPtr;
1125 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1126 reshapeOp.getSource(), adaptor.getSource(),
1127 &allocatedPtr, &alignedPtr);
1128 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1129 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1136 reshapeOp,
"failed to get stride and offset exprs");
1138 if (!isStaticStrideOrOffset(offset))
1140 "dynamic offset is unsupported");
1142 desc.setConstantOffset(rewriter, loc, offset);
1144 assert(targetMemRefType.getLayout().isIdentity() &&
1145 "Identity layout map is a precondition of a valid reshape op");
1147 Value stride =
nullptr;
1148 int64_t targetRank = targetMemRefType.getRank();
1149 for (
auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1150 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1153 stride = createIndexConstant(rewriter, loc, strides[i]);
1154 }
else if (!stride) {
1158 stride = createIndexConstant(rewriter, loc, 1);
1162 int64_t size = targetMemRefType.getDimSize(i);
1165 if (!ShapedType::isDynamic(size)) {
1166 dimSize = createIndexConstant(rewriter, loc, size);
1168 Value shapeOp = reshapeOp.getShape();
1169 Value index = createIndexConstant(rewriter, loc, i);
1170 dimSize = rewriter.
create<memref::LoadOp>(loc, shapeOp, index);
1171 Type indexType = getIndexType();
1172 if (dimSize.
getType() != indexType)
1174 rewriter, loc, indexType, dimSize);
1175 assert(dimSize &&
"Invalid memref element type");
1178 desc.setSize(rewriter, loc, i, dimSize);
1179 desc.setStride(rewriter, loc, i, stride);
1182 stride = rewriter.
create<LLVM::MulOp>(loc, stride, dimSize);
1192 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1203 rewriter, loc, typeConverter->
convertType(targetType));
1204 targetDesc.
setRank(rewriter, loc, resultRank);
1208 Value underlyingDescPtr = rewriter.
create<LLVM::AllocaOp>(
1209 loc, getVoidPtrType(), sizes.front(), llvm::None);
1210 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1213 Value allocatedPtr, alignedPtr, offset;
1214 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1215 reshapeOp.getSource(), adaptor.getSource(),
1216 &allocatedPtr, &alignedPtr, &offset);
1223 elementPtrPtrType, allocatedPtr);
1226 elementPtrPtrType, alignedPtr);
1228 underlyingDescPtr, elementPtrPtrType,
1234 rewriter, loc, *getTypeConverter(), underlyingDescPtr,
1237 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1238 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1239 Value oneIndex = createIndexConstant(rewriter, loc, 1);
1240 Value resultRankMinusOne =
1241 rewriter.
create<LLVM::SubOp>(loc, resultRank, oneIndex);
1244 Type indexType = getTypeConverter()->getIndexType();
1248 {indexType, indexType}, {loc, loc});
1251 Block *remainingBlock = rewriter.
splitBlock(initBlock, remainingOpsIt);
1255 rewriter.
create<LLVM::BrOp>(loc,
ValueRange({resultRankMinusOne, oneIndex}),
1258 Value indexArg = condBlock->getArgument(0);
1259 Value strideArg = condBlock->getArgument(1);
1261 Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1263 loc, IntegerType::get(rewriter.
getContext(), 1),
1264 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1272 Value sizeLoadGep = rewriter.
create<LLVM::GEPOp>(loc, llvmIndexPtrType,
1273 shapeOperandPtr, indexArg);
1274 Value size = rewriter.
create<LLVM::LoadOp>(loc, sizeLoadGep);
1276 targetSizesBase, indexArg, size);
1280 targetStridesBase, indexArg, strideArg);
1281 Value nextStride = rewriter.
create<LLVM::MulOp>(loc, strideArg, size);
1284 Value decrement = rewriter.
create<LLVM::SubOp>(loc, indexArg, oneIndex);
1293 rewriter.
create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
1299 *descriptor = targetDesc;
1307 Type &llvmIndexType,
1309 return llvm::to_vector<4>(
1311 if (
auto attr = value.dyn_cast<
Attribute>())
1312 return b.
create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1313 return value.get<
Value>();
// Builds the inverse of a reassociation map: for each collapsed group
// (index en.index()), every expanded dimension listed in the group
// (en.value()) is mapped back to that group's index, then the map is
// returned. NOTE(review): the enclosing function's signature and the map's
// declaration are on lines missing from this extraction.
1323 for (
auto &en :
enumerate(reassociation)) {
1324 for (
auto dim : en.value())
1325 expandedDimToCollapsedDim[dim] = en.index();
1327 return expandedDimToCollapsedDim;
1337 int64_t outDimSize = outStaticShape[outDimIndex];
1338 if (!ShapedType::isDynamic(outDimSize))
1343 int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1344 int64_t otherDimSizesMul = 1;
1345 for (
auto otherDimIndex : reassocation[inDimIndex]) {
1346 if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1348 int64_t otherDimSize = outStaticShape[otherDimIndex];
1349 assert(!ShapedType::isDynamic(otherDimSize) &&
1350 "single dimension cannot be expanded into multiple dynamic " 1352 otherDimSizesMul *= otherDimSize;
1356 int64_t inDimSize = inStaticShape[inDimIndex];
1357 Value inDimSizeDynamic =
1358 ShapedType::isDynamic(inDimSize)
1359 ? inDesc.
size(b, loc, inDimIndex)
1360 : b.
create<LLVM::ConstantOp>(loc, llvmIndexType,
1363 loc, inDimSizeDynamic,
1364 b.
create<LLVM::ConstantOp>(loc, llvmIndexType,
1366 return outDimSizeDynamic;
1373 if (!ShapedType::isDynamic(outDimSize))
1377 Value outDimSizeDynamic = c1;
1378 for (
auto inDimIndex : reassocation[outDimIndex]) {
1379 int64_t inDimSize = inStaticShape[inDimIndex];
1380 Value inDimSizeDynamic =
1381 ShapedType::isDynamic(inDimSize)
1382 ? inDesc.
size(b, loc, inDimIndex)
1383 : b.
create<LLVM::ConstantOp>(loc, llvmIndexType,
1386 b.
create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1388 return outDimSizeDynamic;
1397 return llvm::to_vector<4>(llvm::map_range(
1398 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1399 return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1400 outStaticShape[outDimIndex],
1401 inStaticShape, inDesc, reassociation);
1413 return llvm::to_vector<4>(llvm::map_range(
1414 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1415 return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1416 outStaticShape, inDesc, inStaticShape,
1417 reassociation, outDimToInDimMap);
1426 return outStaticShape.size() < inStaticShape.size()
1428 getCollapsedOutputShape(b, loc, llvmIndexType,
1429 reassociation, inStaticShape,
1430 inDesc, outStaticShape))
1432 getExpandedOutputShape(b, loc, llvmIndexType,
1433 reassociation, inStaticShape,
1434 inDesc, outStaticShape));
1437 static void fillInStridesForExpandedMemDescriptor(
1443 auto currentStrideToExpand = srcDesc.
stride(b, loc, en.index());
1444 for (
auto dstIndex : llvm::reverse(en.value())) {
1445 dstDesc.
setStride(b, loc, dstIndex, currentStrideToExpand);
1446 Value size = dstDesc.
size(b, loc, dstIndex);
1447 currentStrideToExpand =
1448 b.
create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1453 static void fillInStridesForCollapsedMemDescriptor(
1459 auto srcShape = srcType.getShape();
1462 auto dstIndex = en.index();
1464 while (srcShape[ref.back()] == 1 && ref.size() > 1)
1465 ref = ref.drop_back();
1466 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1467 dstDesc.
setStride(rewriter, loc, dstIndex,
1468 srcDesc.
stride(rewriter, loc, ref.back()));
1501 Block *continueBlock =
1507 Block *curEntryBlock = initBlock;
1508 Block *nextEntryBlock;
1509 for (
auto srcIndex : llvm::reverse(ref)) {
1510 if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1513 Value srcStride = srcDesc.
stride(rewriter, loc, srcIndex);
1514 if (srcIndex == ref.front()) {
1515 rewriter.
create<LLVM::BrOp>(loc, srcStride, continueBlock);
1521 loc, LLVM::ICmpPredicate::ne, srcDesc.
size(rewriter, loc, srcIndex),
1528 rewriter.
create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1529 srcStride, nextEntryBlock, llvm::None);
1530 curEntryBlock = nextEntryBlock;
1536 static void fillInDynamicStridesForMemDescriptor(
1538 TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1541 if (srcType.getRank() > dstType.getRank())
1542 fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1543 srcDesc, dstDesc, reassociation);
1545 fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1552 template <
typename ReshapeOp>
1553 class ReassociatingReshapeOpConversion
1557 using ReshapeOpAdaptor =
typename ReshapeOp::Adaptor;
1560 matchAndRewrite(ReshapeOp reshapeOp,
typename ReshapeOp::Adaptor adaptor,
1562 MemRefType dstType = reshapeOp.getResultType();
1563 MemRefType srcType = reshapeOp.getSrcType();
1569 reshapeOp,
"failed to get stride and offset exprs");
1573 Location loc = reshapeOp->getLoc();
1575 rewriter, loc, this->typeConverter->
convertType(dstType));
1582 Type llvmIndexType =
1585 rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1586 srcStaticShape, srcDesc, dstStaticShape);
1588 dstDesc.
setSize(rewriter, loc, en.index(), en.value());
1590 if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1593 }
else if (srcType.getLayout().isIdentity() &&
1594 dstType.getLayout().isIdentity()) {
1595 Value c1 = rewriter.
create<LLVM::ConstantOp>(loc, llvmIndexType,
1598 for (
auto dimIndex :
1599 llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1600 dstDesc.
setStride(rewriter, loc, dimIndex, stride);
1601 stride = rewriter.
create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1606 fillInDynamicStridesForMemDescriptor(
1607 rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1608 srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1610 rewriter.
replaceOp(reshapeOp, {dstDesc});
1624 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1626 auto loc = subViewOp.getLoc();
1628 auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1629 auto sourceElementTy =
1630 typeConverter->
convertType(sourceMemRefType.getElementType());
1632 auto viewMemRefType = subViewOp.getType();
1634 memref::SubViewOp::inferResultType(
1635 subViewOp.getSourceType(),
1639 .cast<MemRefType>();
1640 auto targetElementTy =
1641 typeConverter->
convertType(viewMemRefType.getElementType());
1642 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1643 if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1653 if (
failed(successStrides))
1663 Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1664 Value bitcastPtr = rewriter.
create<LLVM::BitcastOp>(
1667 viewMemRefType.getMemorySpaceAsInt()),
1669 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1672 extracted = sourceMemRef.alignedPtr(rewriter, loc);
1673 bitcastPtr = rewriter.
create<LLVM::BitcastOp>(
1676 viewMemRefType.getMemorySpaceAsInt()),
1678 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1680 size_t inferredShapeRank = inferredType.getRank();
1681 size_t resultShapeRank = viewMemRefType.getRank();
1685 strideValues.reserve(inferredShapeRank);
1686 for (
unsigned i = 0; i < inferredShapeRank; ++i)
1687 strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1691 if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1692 targetMemRef.setConstantOffset(rewriter, loc, offset);
1694 Value baseOffset = sourceMemRef.offset(rewriter, loc);
1698 for (
unsigned i = 0, e =
std::min(inferredShapeRank,
1699 subViewOp.getMixedOffsets().size());
1703 subViewOp.isDynamicOffset(i)
1704 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1705 : rewriter.
create<LLVM::ConstantOp>(
1708 Value mul = rewriter.
create<LLVM::MulOp>(loc, offset, strideValues[i]);
1709 baseOffset = rewriter.
create<LLVM::AddOp>(loc, baseOffset, mul);
1711 targetMemRef.setOffset(rewriter, loc, baseOffset);
1717 assert(mixedSizes.size() == mixedStrides.size() &&
1718 "expected sizes and strides of equal length");
1719 llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1720 for (
int i = inferredShapeRank - 1,
j = resultShapeRank - 1;
1721 i >= 0 &&
j >= 0; --i) {
1722 if (unusedDims.test(i))
1729 if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1734 int64_t staticSize =
1736 if (staticSize != ShapedType::kDynamicSize) {
1737 size = rewriter.
create<LLVM::ConstantOp>(
1743 rewriter.
create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1744 auto cast = rewriter.
create<UnrealizedConversionCastOp>(
1745 loc, llvmIndexType, dim);
1746 size = cast.getResult(0);
1748 stride = rewriter.
create<LLVM::ConstantOp>(
1753 subViewOp.isDynamicSize(i)
1754 ? adaptor.
getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1755 : rewriter.
create<LLVM::ConstantOp>(
1758 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1759 stride = rewriter.
create<LLVM::ConstantOp>(
1763 subViewOp.isDynamicStride(i)
1764 ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1765 : rewriter.
create<LLVM::ConstantOp>(
1768 subViewOp.getStaticStride(i)));
1769 stride = rewriter.
create<LLVM::MulOp>(loc, stride, strideValues[i]);
1772 targetMemRef.setSize(rewriter, loc,
j, size);
1773 targetMemRef.setStride(rewriter, loc,
j, stride);
1777 rewriter.
replaceOp(subViewOp, {targetMemRef});
1794 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1796 auto loc = transposeOp.getLoc();
1800 if (transposeOp.getPermutation().isIdentity())
1804 rewriter, loc, typeConverter->
convertType(transposeOp.getShapedType()));
1808 targetMemRef.setAllocatedPtr(rewriter, loc,
1809 viewMemRef.allocatedPtr(rewriter, loc));
1810 targetMemRef.setAlignedPtr(rewriter, loc,
1811 viewMemRef.alignedPtr(rewriter, loc));
1814 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1817 for (
const auto &en :
1819 int sourcePos = en.index();
1820 int targetPos = en.value().cast<
AffineDimExpr>().getPosition();
1821 targetMemRef.setSize(rewriter, loc, targetPos,
1822 viewMemRef.size(rewriter, loc, sourcePos));
1823 targetMemRef.setStride(rewriter, loc, targetPos,
1824 viewMemRef.stride(rewriter, loc, sourcePos));
1827 rewriter.
replaceOp(transposeOp, {targetMemRef});
1844 unsigned idx)
const {
1845 assert(idx < shape.size());
1846 if (!ShapedType::isDynamic(shape[idx]))
1847 return createIndexConstant(rewriter, loc, shape[idx]);
1850 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1851 return dynamicSizes[nDynamic];
1860 Value runningStride,
unsigned idx)
const {
1861 assert(idx < strides.size());
1862 if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1863 return createIndexConstant(rewriter, loc, strides[idx]);
1865 return runningStride
1866 ? rewriter.
create<LLVM::MulOp>(loc, runningStride, nextSize)
1868 assert(!runningStride);
1869 return createIndexConstant(rewriter, loc, 1);
1873 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1875 auto loc = viewOp.getLoc();
1877 auto viewMemRefType = viewOp.getType();
1878 auto targetElementTy =
1879 typeConverter->
convertType(viewMemRefType.getElementType());
1880 auto targetDescTy = typeConverter->
convertType(viewMemRefType);
1881 if (!targetDescTy || !targetElementTy ||
1884 return viewOp.emitWarning(
"Target descriptor type not converted to LLVM"),
1890 if (
failed(successStrides))
1891 return viewOp.emitWarning(
"cannot cast to non-strided shape"),
failure();
1892 assert(offset == 0 &&
"expected offset to be 0");
1896 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1897 return viewOp.emitWarning(
"cannot cast to non-contiguous shape"),
1905 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1906 auto srcMemRefType = viewOp.getSource().
getType().
cast<MemRefType>();
1907 Value bitcastPtr = rewriter.
create<LLVM::BitcastOp>(
1910 srcMemRefType.getMemorySpaceAsInt()),
1912 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1915 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1916 alignedPtr = rewriter.
create<LLVM::GEPOp>(
1917 loc, alignedPtr.
getType(), alignedPtr, adaptor.getByteShift());
1918 bitcastPtr = rewriter.
create<LLVM::BitcastOp>(
1921 srcMemRefType.getMemorySpaceAsInt()),
1923 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1928 targetMemRef.setOffset(rewriter, loc,
1929 createIndexConstant(rewriter, loc, offset));
1932 if (viewMemRefType.getRank() == 0)
1936 Value stride =
nullptr, nextSize =
nullptr;
1937 for (
int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1939 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1940 adaptor.getSizes(), i);
1941 targetMemRef.setSize(rewriter, loc, i, size);
1943 stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1944 targetMemRef.setStride(rewriter, loc, i, stride);
1948 rewriter.
replaceOp(viewOp, {targetMemRef});
1960 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1961 switch (atomicOp.getKind()) {
1962 case arith::AtomicRMWKind::addf:
1963 return LLVM::AtomicBinOp::fadd;
1964 case arith::AtomicRMWKind::addi:
1965 return LLVM::AtomicBinOp::add;
1966 case arith::AtomicRMWKind::assign:
1967 return LLVM::AtomicBinOp::xchg;
1968 case arith::AtomicRMWKind::maxs:
1970 case arith::AtomicRMWKind::maxu:
1971 return LLVM::AtomicBinOp::umax;
1972 case arith::AtomicRMWKind::mins:
1974 case arith::AtomicRMWKind::minu:
1975 return LLVM::AtomicBinOp::umin;
1976 case arith::AtomicRMWKind::ori:
1977 return LLVM::AtomicBinOp::_or;
1978 case arith::AtomicRMWKind::andi:
1979 return LLVM::AtomicBinOp::_and;
1983 llvm_unreachable(
"Invalid AtomicRMWKind");
1986 struct AtomicRMWOpLowering :
public LoadStoreOpLowering<memref::AtomicRMWOp> {
1990 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1992 if (
failed(match(atomicOp)))
1994 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1997 auto resultType = adaptor.getValue().getType();
1998 auto memRefType = atomicOp.getMemRefType();
2000 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2001 adaptor.getIndices(), rewriter);
2003 atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2004 LLVM::AtomicOrdering::acq_rel);
2016 AllocaScopeOpLowering,
2017 AtomicRMWOpLowering,
2018 AssumeAlignmentOpLowering,
2020 GenericAtomicRMWOpLowering,
2021 GlobalMemrefOpLowering,
2022 GetGlobalMemrefOpLowering,
2024 MemRefCastOpLowering,
2025 MemRefCopyOpLowering,
2026 MemRefReinterpretCastOpLowering,
2027 MemRefReshapeOpLowering,
2030 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2031 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2035 ViewOpLowering>(converter);
2039 patterns.
add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2041 patterns.
add<AllocOpLowering, DeallocOpLowering>(converter);
2045 struct MemRefToLLVMPass :
public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2046 MemRefToLLVMPass() =
default;
2048 void runOnOperation()
override {
2050 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2052 dataLayoutAnalysis.getAtOrAbove(op));
2057 options.useGenericFunctions = useGenericFunctions;
2060 options.overrideIndexBitwidth(indexBitwidth);
2063 &dataLayoutAnalysis);
2069 signalPassFailure();
2075 return std::make_unique<MemRefToLLVMPass>();
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
static Value offset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, Type elemPtrPtrType)
Builds IR extracting the offset from the descriptor.
MLIRContext * getContext() const
void addLegalOp(OperationName op)
Register the given operations as legal.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
const LowerToLLVMOptions & getOptions() const
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is a basic unit of execution within MLIR.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static void setAlignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, Type elemPtrPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the allocated pointer into the descriptor.
std::unique_ptr< Pass > createMemRefToLLVMPass()
operand_range getOperands()
Returns an iterator on the underlying Value's.
static LLVMArrayType get(Type elementType, unsigned numElements)
Gets or creates an instance of LLVM dialect array type containing numElements of elementType, in the same context as elementType.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Value getOperand(unsigned idx)
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp)
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
BlockListType::iterator iterator
static Value sizeBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType)
Builds IR extracting the pointer to the first element of the size array.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Derived class that automatically populates legalization information for different LLVM ops...
static Value alignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, Type elemPtrPtrType)
Builds IR extracting the aligned pointer from the descriptor.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
T lookup(T from) const
Lookup a mapped value within the map.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
BlockArgument getArgument(unsigned i)
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the aligned pointer into the descriptor.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
void populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
IntegerAttr getI32IntegerAttr(int32_t value)
static llvm::DenseMap< int64_t, int64_t > getExpandedDimToCollapsedDimMap(ArrayRef< AffineMap > reassociation)
Compute a map that for a given dimension of the expanded type gives the dimension in the collapsed ty...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
static void setStride(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
OpListType::iterator iterator
void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override
PatternRewriter hook for merging a block into another.
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride)
Lowering for AllocOp and AllocaOp.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
IntegerAttr getI64IntegerAttr(int64_t value)
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static SmallVector< Value > getAsValues(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > valueOrAttrVec)
Helper function to convert a vector of OpFoldResults into a vector of Values.
Attributes are known-constant values of operations.
static void setSize(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType)
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
IntegerType getIntegerType(unsigned width)
Use malloc for heap allocations.
Use aligned_alloc for heap allocations.
static Value strideBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Location getLoc()
The source location the operation was defined or derived from.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, Type elemPtrPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
BlockArgListType getArguments()
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Location getLoc() const
Return the location of this value.
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, Type indexType)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp, Type indexType)
static llvm::ManagedStatic< PassManagerOptions > options
Operation * getTerminator()
Get the terminator operation of this block.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, Type unrankedDescriptorType)
RAII guard to reset the insertion point of the builder when destroyed.
Type getIndexType()
Returns the type of array element in this descriptor.
void setOffset(OpBuilder &builder, Location loc, Value offset)
Builds IR inserting the offset into the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
Type getType() const
Return the type of this value.
This class provides a shared interface for ranked and unranked memref types.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Type getElementType() const
Returns the element type of this memref type.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
A dimensional identifier appearing in an affine expression.
Conversion from types to the LLVM IR dialect.
BoolAttr getBoolAttr(bool value)
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
Options to control the LLVM lowering.
This class implements a pattern rewriter for use with ConversionPatterns.
static void setOffset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, Type elemPtrPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride)
Builds IR inserting the pos-th stride into the descriptor.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size)
Builds IR inserting the pos-th size into the descriptor.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, Type elemPtrPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
SmallVector< int64_t, 4 > extractFromI64ArrayAttr(Attribute attr)
Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
Stores data layout objects for each operation that specifies the data layout above and below the give...
AllocLowering allocLowering
result_range getResults()
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
static void computeSizes(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType)
Helper determining if a memref is static-shape and contiguous-row-major layout, while still allowing ...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)
The main mechanism for performing data layout queries.