enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};
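
/// Marks the given top module as a GPU container module.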
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}
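
/// Constructs code for a new GPU kernel module; an already existing GPU
/// module in the top module is simply reused.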
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}
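
/// Constructs code to create a GPU kernel function inside the given module,
/// using a kernel name that is not taken yet.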
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Pick a kernel name that is not present in the module yet.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then insert a new kernel with the given arguments into the module.
  SmallVector<Type> argsTp;
  for (unsigned i = 0, e = args.size(); i < e; i++)
    argsTp.push_back(args[i].getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}
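
/// Constructs code to launch the GPU kernel: a 1x1x1 grid with `numThreads`
/// threads per block, yielding the async token of the launch.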
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ Value(), args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}
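
/// Maps the provided ranked host buffer into the device address space by
/// casting it to an unranked buffer and registering it with the GPU runtime.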
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}
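
/// Unmaps the provided buffer, expecting the casted buffer.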
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}
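
/// Generates the final, blocking wait in an asynchronous chain.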
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}
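
/// Constructs code to allocate a buffer on the device that shares the shape
/// and element type of the given host buffer, collecting dynamic sizes first.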
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto tp = cast<ShapedType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
    if (shape[r] == ShapedType::kDynamic) {
      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
      dynamicSizes.push_back(dimOp);
    }
  }
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, dynamicSizes, ValueRange());
}
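
/// Allocates a typed buffer on the host with the given size.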
static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
                           Value size) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
}
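
/// Allocates an untyped (i8) buffer on the device with the given size.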
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
}
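
/// Deallocates a device buffer, chained on the given async token.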
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
      .getAsyncToken();
}
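
/// Copies a memref from src to dst, chained on the given async token.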
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
      .getAsyncToken();
}
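
/// Generates a device allocation followed by a host-to-device copy for the
/// given host buffer, recording the copy's async token.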
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value firstToken = genFirstWait(builder, loc);
  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
  Value devMem = alloc.getResult(0);
  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
  return devMem;
}
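
/// Generates a memref from the given tensor operand by bufferizing it.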
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}
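
/// Prepares the outlined arguments: scalars are passed by value, buffers are
/// made visible on the device (the first buffer is assumed to be the output
/// and may use host registration instead of an explicit copy).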
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  for (Value s : scalars)
    args.push_back(s);
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}
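
/// Finalizes the outlined arguments: the output buffer is unregistered or
/// copied back depending on the kernel token, and every device buffer is
/// deallocated afterwards.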
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}
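
/// Constructs the kernel body: constants are re-generated inside the kernel,
/// scalars and buffers are recaptured as block arguments, and the original
/// parallel loop body is cloned into a grid-stride scf.for loop.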
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume a 1-dimensional grid/block configuration (only the x dimension):
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Iterate the computational space with a grid-stride loop:
  //   for (r = row; r < upper; r += inc) { <loop-body> }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  rewriter.eraseBlock(forOp.getBody());
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);

  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}
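
/// Matches the operation add(a, b), in either operand order, where a and b
/// are the two arguments of the given block.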
static bool matchAddOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}
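
/// Matches the operation mul(a, b), in either operand order, where a and b
/// are the two arguments of the given block.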
static bool matchMulOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}
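
/// Matches a linalg.generic body that computes the sum-of-products reduction
/// x += a * b over its block arguments (the SpMV/SpMM style body).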
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
    }
  }
  return false;
}
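
/// Matches the SDDMM-style body c += spy(s_out) x (a * b): a sparse_tensor
/// reduce over the output that consumes a unary (with empty absent region)
/// whose present region yields a multiplication of the block arguments.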
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  // The linalg yields a custom reduce result.
  Value s_out = op.getBlock()->getArguments()[2];
  if (auto redOp =
          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
    // The reduce consumes the output.
    Value other;
    if (s_out == redOp->getOperand(0))
      other = redOp->getOperand(1);
    else if (s_out == redOp->getOperand(1))
      other = redOp->getOperand(0);
    else
      return false;
    // The reduce also consumes a unary which consumes the output and does
    // not define an absent value.
    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
        return false;
      // And the bodies are as expected.
      auto yieldUn = cast<sparse_tensor::YieldOp>(
          unOp.getRegion(0).front().getTerminator());
      auto yieldRed = cast<sparse_tensor::YieldOp>(
          redOp.getRegion().front().getTerminator());
      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
    }
  }
  return false;
}
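
/// Test for a dense tensor: all levels dense and no dimension reordering.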
static bool isDenseTensor(Value v) {
  auto sTp = getSparseTensorType(v);
  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}
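
// Tail of the isAdmissible{COO,CSR,CSC,BSR} format tests: each also requires
// admissible position/coordinate metadata widths.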
         isAdmissibleMetaData(aTp);
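
// Block-size check from the isAdmissibleBSR test: only square blocks of
// size greater than one are admissible.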
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2);
  return dims[0] == dims[1] && dims[0] > 1;
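
/// Returns a suitable cuSPARSE format for the given operand types, or kNone
/// when no admissible format is available.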
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands must have a dense type.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for a suitable format on the main (sparse) operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}
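
/// Generates the first positions/coordinates buffer of a sparse matrix for
/// the chosen format (coordinates for COO, positions for CSR/CSC/BSR).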
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
    return genToCoordinatesBuffer(builder, loc, a);
  }
  // Formats CSR/CSC and BSR use positions at level 1.
  return genToPositions(builder, loc, a, 1);
}
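
/// Generates the second coordinates buffer of a sparse matrix, or an empty
/// Value when the chosen format does not need one.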
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at level 1.
  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
}
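
/// Generates the sparse matrix handle for cuSPARSE, dispatching over the
/// chosen format (COO, CSR, CSC, or BSR).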
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR: derive the uniform block size and the block counts.
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, builder.create<arith::MulIOp>(loc, bSz, bSz));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}
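
/// Rewrites a sparse matrix-vector multiplication (y = A x) into a chain of
/// asynchronous GPU ops: copy the operands to the device, create the sparse
/// and dense cuSPARSE handles, size and allocate the work buffer, run SpMV,
/// destroy the handles, copy the result back, and deallocate, with every step
/// chained on the previous async token.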
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  // Only admissible sparse matrix format (no BSR) and dense vectors.
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create the sparse matrix and dense vector handles.
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffersize for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, vecY);
  return success();
}
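
/// Rewrites a sparse matrix-matrix multiplication (C = A B with sparse A)
/// into asynchronous GPU ops, following the same token-chained structure as
/// the SpMV rewrite but with dense matrix handles and gpu.spmm.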
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  // Only admissible sparse matrix format (no BSR) and dense matrices.
  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create the sparse matrix and dense matrix handles.
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SpMM.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, matC);
  return success();
}
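
/// Rewrites a sparse-times-sparse matrix multiplication (CSR = CSR x CSR)
/// into the multi-stage cuSPARSE SpGEMM flow: work estimation, compute,
/// size query, result allocation, and copy, all chained on async tokens.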
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  // Only CSR <- CSR x CSR is supported.
  auto format = CuSparseFormat::kCSR;
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
  Value amemV = genToValues(rewriter, loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
  Value bmemV = genToValues(rewriter, loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create the sparse matrix handles for A, B, and an initially empty C.
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  Operation *spGenB =
      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
               nseB, rowB, colB, valB, format, enableRT);
  Value spMatB = spGenB->getResult(0);
  token = spGenB->getResult(1);
  Value zero = constantIndex(rewriter, loc, 0);
  Value one = constantIndex(rewriter, loc, 1);
  Type dnCType = cTp.getElementType();
  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
  Value rowC = e1.getResult(0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
  Value colC = e2.getResult(0);
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
  Value valC = e3.getResult(0);
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
               zero, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);

  // Precompute buffersizes for the work-estimation and compute stages,
  // then run both SpGEMM stages.
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz1, buffer1,
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(1);
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);

  // Get the actual nnz of C, allocate its coordinate/value buffers, update
  // the CSR pointers, and copy the SpGEMM result into C.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);

  // Allocate the host-side structures that receive the result.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
  token = genDeallocMemRef(rewriter, loc, colB, token);
  token = genDeallocMemRef(rewriter, loc, valB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Bring the result components back as tensors.
  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
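
/// Rewrites a 2:4 structured-sparsity matrix multiplication (all operands
/// dense, A pruned on the device) into asynchronous GPU ops using the
/// three work buffers reported by the SpMM buffer-size query.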
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
  // All operands must be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  // Start sparse kernel and copy data from host to device.
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create the 2:4 sparse handle for A and dense handles for B and C.
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
      loc, spMatHandleTp, tokenTp, token, szm, szk,
      gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();

  // Precompute the three buffersizes for the 2:4 SpMM.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  token = bufferComp.getAsyncToken();
  Value bufferSz1 = bufferComp.getResult(0);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Value bufferSz2 = bufferComp.getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Value bufferSz3 = bufferComp.getResult(2);
  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
  Value buffer3 = buf3.getResult(0);
  token = buf3.getAsyncToken();

  // Perform the SpMM.
  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, dnCType,
      SmallVector<Value>{buffer1, buffer2, buffer3});
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  token = genDeallocMemRef(rewriter, loc, buffer3, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, matC);
  return success();
}
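
/// Rewrites a sampled dense-dense matrix multiplication (SDDMM) into
/// asynchronous GPU ops: dense handles for A and B, a sparse handle for the
/// sampling matrix C, a buffer-size query, and gpu.sddmm, all token-chained.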
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
  // Only admissible sparse format for the sampling matrix (no COO, no CSC).
  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
      format == CuSparseFormat::kCSC)
    return failure();

  // Start sparse kernel and copy data from host to device.
  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
  Value bufA = genTensorToMemref(rewriter, loc, a);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, c);
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create dense handles for A and B and a sparse handle for C.
  Value token = genFirstWait(rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
               nseC, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SDDMM.
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
                                                 spMatC, dnCType, buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  if (colC)
    token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genCopyMemRef(rewriter, loc, memV, valC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
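
/// ForallRewriter, an OpRewritePattern<scf::ParallelOp>, implements direct
/// GPU code generation: it outlines a sparse-compiler-generated scf.parallel
/// loop into a GPU kernel and replaces the loop with an asynchronous launch.
/// Its matchAndRewrite and helpers follow.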
  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Reject inadmissible loop form: only 1-D loops without reductions,
    // with lower bound zero and unit step, are mapped onto a kernel.
    if (forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();
    // Collect every value that is computed outside the parallel loop.
    SetVector<Value> invariants;
    forallOp.walk([&](Operation *op) {
      for (OpOperand &o : op->getOpOperands()) {
        Value val = o.get();
        Block *block;
        if (auto arg = dyn_cast<BlockArgument>(val))
          block = arg.getOwner();
        else
          block = val.getDefiningOp()->getBlock();
        if (!isNestedIn(block, forallOp))
          invariants.insert(val);
      }
    });
    // Outline the invariants as constants, scalars, and buffers; anything
    // else cannot be shared between host and device.
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value val : invariants) {
      Type tp = val.getType();
      if (val.getDefiningOp<arith::ConstantOp>())
        constants.push_back(val);
      else if (tp.isIntOrIndex())
        scalars.push_back(val);
      else if (isa<MemRefType>(tp))
        buffers.push_back(val);
      else
        return failure(); // don't know how to share
    }
    // Pass the outlined values to the device, generate the GPU module and
    // kernel function, launch the kernel, and finalize the parameters.
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }
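
  /// Returns true if the given block is nested inside the parallel loop.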
  static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
    for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
      if (o == forallOp)
        return true;
    }
    return false;
  }

  unsigned numThreads; // threads per block used for the kernel launch
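
/// LinalgOpRewriter, an OpRewritePattern<linalg::GenericOp>, recognizes
/// library-supported sparse kernels (SpMV, SpMM, SpGEMM, 2:4-SpMM, SDDMM)
/// by their iterator types, indexing maps, and body, and dispatches to the
/// rewrites above. Its matchAndRewrite follows.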
  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr i, j, k;
    bindDims(op.getContext(), i, j, k);

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize a SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }
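
/// Pattern-population entry points: direct GPU code generation uses the
/// ForallRewriter (with the requested number of threads), while the cuSPARSE
/// library path uses the LinalgOpRewriter (with or without runtime support).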
void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}