31#include "llvm/Support/Casting.h"
39enum class CuSparseFormat {
52static void markAsGPUContainer(ModuleOp topModule) {
53 topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
54 UnitAttr::get(topModule->getContext()));
59static gpu::GPUModuleOp genGPUModule(
OpBuilder &builder, ModuleOp topModule) {
60 for (
auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
62 markAsGPUContainer(topModule);
64 return gpu::GPUModuleOp::create(builder, topModule->getLoc(),
69static gpu::GPUFuncOp genGPUFunc(
OpBuilder &builder, gpu::GPUModuleOp gpuModule,
73 unsigned kernelNumber = 0;
77 (
"kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
78 }
while (gpuModule.lookupSymbol(kernelName));
83 argsTp.push_back(arg.getType());
84 FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
86 gpu::GPUFuncOp::create(builder, gpuModule->getLoc(), kernelName, type);
87 gpuFunc.setKernel(
true);
92static Value genLaunchGPUFunc(
OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
95 unsigned numThreads) {
102 return gpu::LaunchFuncOp::create(builder, loc, gpuFunc, gridSize, blckSize,
116 MemRefType memTp = cast<MemRefType>(mem.
getType());
118 UnrankedMemRefType::get(memTp.getElementType(), 0);
119 Value cast = memref::CastOp::create(builder, loc, resTp, mem);
120 gpu::HostRegisterOp::create(builder, loc, cast);
127 gpu::HostUnregisterOp::create(builder, loc, cast);
133 return gpu::WaitOp::create(builder, loc, tokenType,
ValueRange())
140 gpu::WaitOp::create(builder, loc,
Type(), operands);
149 auto tp = cast<ShapedType>(mem.
getType());
150 auto elemTp = tp.getElementType();
151 auto shape = tp.getShape();
152 auto memTp = MemRefType::get(
shape, elemTp);
154 for (
unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
155 if (
shape[r] == ShapedType::kDynamic) {
157 dynamicSizes.push_back(dimOp);
160 return gpu::AllocOp::create(builder, loc,
TypeRange({memTp, token.
getType()}),
167 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
168 return memref::AllocOp::create(builder, loc, memTp, size).getResult();
174 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
175 return gpu::AllocOp::create(builder, loc,
TypeRange({memTp, token.
getType()}),
182 return genAllocBuffer(builder, loc, builder.
getI8Type(), size, token);
188 return gpu::DeallocOp::create(builder, loc, token.
getType(), token, mem)
195 return gpu::MemcpyOp::create(builder, loc, token.
getType(), token, dst, src)
202 Value firstToken = genFirstWait(builder, loc);
203 auto alloc = genAllocMemRef(builder, loc,
b, firstToken);
204 Value devMem = alloc.getResult(0);
205 Value depToken = alloc.getAsyncToken();
206 tokens.push_back(genCopyMemRef(builder, loc, devMem,
b, depToken));
213 auto tensorType = llvm::cast<ShapedType>(
tensor.getType());
215 MemRefType::get(tensorType.getShape(), tensorType.getElementType());
216 return bufferization::ToBufferOp::create(rewriter, loc, memrefType,
tensor);
228 bool useHostRegistrationForOut) {
231 for (
Value s : scalars)
235 if (useHostRegistrationForOut) {
236 out = genHostRegisterMemref(builder, loc,
b);
238 useHostRegistrationForOut =
false;
241 args.push_back(genAllocCopy(builder, loc,
b, tokens));
259 unsigned base = scalars.size();
264 for (
unsigned i = base, e = args.size(); i < e; i++) {
265 unsigned bufIdx = i - base;
269 if (copyBack[bufIdx]) {
270 if (out && bufIdx == 0) {
271 genHostUnregisterMemref(builder, loc, out);
276 genCopyMemRef(builder, loc, buffers[bufIdx], args[i], kernelToken);
278 firstToken = genFirstWait(builder, loc);
280 tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
285static void genGPUCode(
PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
286 scf::ParallelOp forallOp,
297 for (
Value c : constants)
299 for (
Value s : scalars)
308 Value bid = gpu::BlockIdOp::create(rewriter, loc, gpu::Dimension::x);
309 Value bsz = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x);
310 Value tid = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
311 Value gsz = gpu::GridDimOp::create(rewriter, loc, gpu::Dimension::x);
312 Value mul = arith::MulIOp::create(rewriter, loc, bid, bsz);
313 Value row = arith::AddIOp::create(rewriter, loc,
mul, tid);
314 Value inc = arith::MulIOp::create(rewriter, loc, bsz, gsz);
322 Value upper = irMap.
lookup(forallOp.getUpperBound()[0]);
323 scf::ForOp forOp = scf::ForOp::create(rewriter, loc, row, upper, inc);
329 forOp.getRegion().begin(), irMap);
336 gpu::ReturnOp::create(rewriter, gpuFunc->getLoc());
344static bool matchAddOfArgs(
Block *block,
Value val) {
346 if (isa<arith::AddFOp, arith::AddIOp>(def)) {
349 return (def->getOperand(0) == a && def->getOperand(1) ==
b) ||
350 (def->getOperand(0) ==
b && def->getOperand(1) == a);
357static bool matchMulOfArgs(
Block *block,
Value val) {
359 if (isa<arith::MulFOp, arith::MulIOp>(def)) {
362 return (def->getOperand(0) == a && def->getOperand(1) ==
b) ||
363 (def->getOperand(0) ==
b && def->getOperand(1) == a);
370static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
371 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
372 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
373 if (isa<arith::AddFOp, arith::AddIOp>(def)) {
374 Value x = op.getBlock()->getArguments()[2];
375 return (def->getOperand(0) == x &&
376 matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
377 (def->getOperand(1) == x &&
378 matchMulOfArgs(op.getBlock(), def->getOperand(0)));
385static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
386 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
388 Value s_out = op.getBlock()->getArguments()[2];
390 yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
393 if (s_out == redOp->getOperand(0))
394 other = redOp->getOperand(1);
395 else if (s_out == redOp->getOperand(1))
396 other = redOp->getOperand(0);
401 if (
auto unOp = other.
getDefiningOp<sparse_tensor::UnaryOp>()) {
402 if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
405 auto yieldUn = cast<sparse_tensor::YieldOp>(
406 unOp.getRegion(0).front().getTerminator());
407 auto yieldRed = cast<sparse_tensor::YieldOp>(
408 redOp.getRegion().front().getTerminator());
409 return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
410 matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
417static bool isDenseTensor(
Value v) {
419 return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
433 isAdmissibleMetaData(aTp);
457 assert(dims.size() == 2);
458 return dims[0] == dims[1] && dims[0] > 1;
470static bool isConversionInto24(
Value v) {
472 Value a = cnv.getResult();
473 Value d = cnv.getSource();
475 return isDenseTensor(d) && isAdmissible24(aTp);
488 return CuSparseFormat::kNone;
490 if (isAdmissibleCOO(aTp))
491#ifdef CUSPARSE_COO_AOS
492 return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
494 return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
496 if (isAdmissibleCSR(aTp))
497 return CuSparseFormat::kCSR;
498 if (isAdmissibleCSC(aTp))
499 return CuSparseFormat::kCSC;
500 if (isAdmissibleBSR(aTp))
501 return CuSparseFormat::kBSR;
502 return CuSparseFormat::kNone;
507 CuSparseFormat format,
bool enableRT) {
508 if (format == CuSparseFormat::kCOO) {
511 return ToCoordinatesOp::create(builder, loc, a, 0);
512 return ToCoordinatesBufferOp::create(builder, loc, a);
515 return ToPositionsOp::create(builder, loc, a, 1);
520 CuSparseFormat format,
bool enableRT) {
521 bool isCOO = format == CuSparseFormat::kCOO;
522 if (isCOO && !enableRT)
525 return ToCoordinatesOp::create(builder, loc, a, 1);
533 CuSparseFormat format,
bool enableRT) {
534 if (format == CuSparseFormat::kCOO) {
538 return gpu::CreateCooOp::create(builder, loc, handleTp, tokenTp, token,
539 sz1, sz2, nseA, rowA, colA, valA);
541#ifdef CUSPARSE_COO_AOS
543 return gpu::CreateCooAoSOp::create(builder, loc, handleTp, tokenTp, token,
544 sz1, sz2, nseA, rowA, valA);
546 llvm_unreachable(
"gpu::CreateCooAoSOp is deprecated");
550 if (format == CuSparseFormat::kCSR)
551 return gpu::CreateCsrOp::create(builder, loc, handleTp, tokenTp, token, sz1,
552 sz2, nseA, rowA, colA, valA);
553 if (format == CuSparseFormat::kCSC)
554 return gpu::CreateCscOp::create(builder, loc, handleTp, tokenTp, token, sz1,
555 sz2, nseA, rowA, colA, valA);
559 assert(format == CuSparseFormat::kBSR);
561 assert(dims.size() == 2 && dims[0] == dims[1]);
562 uint64_t
b = dims[0];
564 Value bRows = arith::DivUIOp::create(builder, loc, sz1, bSz);
565 Value bCols = arith::DivUIOp::create(builder, loc, sz2, bSz);
566 Value bNum = arith::DivUIOp::create(builder, loc, nseA,
568 return gpu::CreateBsrOp::create(builder, loc, handleTp, tokenTp, token, bRows,
569 bCols, bNum, bSz, bSz, rowA, colA, valA);
574 linalg::GenericOp op,
bool enableRT) {
576 Value a = op.getOperand(0);
577 Value x = op.getOperand(1);
578 Value y = op.getOperand(2);
585 auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT,
true);
586 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
593 Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
596 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
597 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
598 Value memV = ToValuesOp::create(rewriter, loc, a);
599 Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
600 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) :
Value();
601 Value valA = genAllocCopy(rewriter, loc, memV, tokens);
602 Value memX = genTensorToMemref(rewriter, loc, x);
603 Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
604 Value memY = genTensorToMemref(rewriter, loc, y);
605 Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
606 genBlockingWait(rewriter, loc, tokens);
614 Value token = genFirstWait(rewriter, loc);
616 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
617 nseA, rowA, colA, valA, format, enableRT);
620 auto dvecX = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp,
621 tokenTp, token, vecX, szX);
622 Value dnX = dvecX.getResult(0);
623 token = dvecX.getAsyncToken();
624 auto dvecY = gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp,
625 tokenTp, token, vecY, szY);
626 Value dnY = dvecY.getResult(0);
627 token = dvecY.getAsyncToken();
628 auto dnYType = llvm::cast<ShapedType>(y.
getType()).getElementType();
631 auto bufferComp = gpu::SpMVBufferSizeOp::create(
632 rewriter, loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
634 Value bufferSz = bufferComp.getResult(0);
635 token = bufferComp.getAsyncToken();
636 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
637 Value buffer = buf.getResult(0);
638 token = buf.getAsyncToken();
642 gpu::SpMVOp::create(rewriter, loc, tokenTp, token, spMatA, dnX, dnY,
644 token = spmvComp.getAsyncToken();
647 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
649 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnX)
651 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnY)
653 token = genDeallocMemRef(rewriter, loc, rowA, token);
655 token = genDeallocMemRef(rewriter, loc, colA, token);
656 token = genDeallocMemRef(rewriter, loc, valA, token);
657 token = genDeallocMemRef(rewriter, loc, buffer, token);
658 token = genDeallocMemRef(rewriter, loc, vecX, token);
659 token = genCopyMemRef(rewriter, loc, memY, vecY, token);
660 token = genDeallocMemRef(rewriter, loc, vecY, token);
661 tokens.push_back(token);
662 genBlockingWait(rewriter, loc, tokens);
672 linalg::GenericOp op,
bool enableRT) {
674 Value a = op.getOperand(0);
675 Value b = op.getOperand(1);
676 Value c = op.getOperand(2);
683 auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT,
false);
684 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
691 Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
695 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
696 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
697 Value memV = ToValuesOp::create(rewriter, loc, a);
698 Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
699 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) :
Value();
700 Value valA = genAllocCopy(rewriter, loc, memV, tokens);
701 Value bufB = genTensorToMemref(rewriter, loc,
b);
702 Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
703 Value bufC = genTensorToMemref(rewriter, loc, c);
704 Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
705 genBlockingWait(rewriter, loc, tokens);
713 Value token = genFirstWait(rewriter, loc);
715 genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
716 nseA, rowA, colA, valA, format, enableRT);
720 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
722 Value dnB = dmatB.getResult(0);
723 token = dmatB.getAsyncToken();
725 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
727 Value dnC = dmatC.getResult(0);
728 token = dmatC.getAsyncToken();
729 auto dmatCType = llvm::cast<ShapedType>(c.
getType()).getElementType();
732 auto bufferComp = gpu::SpMMBufferSizeOp::create(
733 rewriter, loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
735 Value bufferSz = bufferComp.getResult(0);
736 token = bufferComp.getAsyncToken();
737 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
738 Value buffer = buf.getResult(0);
739 token = buf.getAsyncToken();
740 auto dnCType = llvm::cast<ShapedType>(c.
getType()).getElementType();
744 gpu::SpMMOp::create(rewriter, loc, tokenTp, token, spMatA, dnB, dnC,
746 token = spmmComp.getAsyncToken();
749 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
751 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
753 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC)
755 token = genDeallocMemRef(rewriter, loc, rowA, token);
757 token = genDeallocMemRef(rewriter, loc, colA, token);
758 token = genDeallocMemRef(rewriter, loc, valA, token);
759 token = genDeallocMemRef(rewriter, loc, buffer, token);
760 token = genDeallocMemRef(rewriter, loc, matB, token);
761 token = genCopyMemRef(rewriter, loc, bufC, matC, token);
762 token = genDeallocMemRef(rewriter, loc, matC, token);
763 tokens.push_back(token);
764 genBlockingWait(rewriter, loc, tokens);
774 linalg::GenericOp op,
bool enableRT) {
776 Value a = op.getOperand(0);
777 Value b = op.getOperand(1);
778 Value c = op.getOperand(2);
782 auto format = CuSparseFormat::kCSR;
786 if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
794 Value nseA = NumberOfEntriesOp::create(rewriter, loc, a);
795 Value nseB = NumberOfEntriesOp::create(rewriter, loc,
b);
799 Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
800 Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
801 Value amemV = ToValuesOp::create(rewriter, loc, a);
802 Value bmemR = genFirstPosOrCrds(rewriter, loc,
b, format, enableRT);
803 Value bmemC = genSecondCrds(rewriter, loc,
b, format, enableRT);
804 Value bmemV = ToValuesOp::create(rewriter, loc,
b);
805 Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
806 Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
807 Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
808 Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
809 Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
810 Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
811 genBlockingWait(rewriter, loc, tokens);
819 Value token = genFirstWait(rewriter, loc);
821 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
822 nseA, rowA, colA, valA, format, enableRT);
826 genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
827 nseB, rowB, colB, valB, format, enableRT);
834 Value mplus1 = arith::AddIOp::create(rewriter, loc, szm, one);
835 auto e1 = genAllocBuffer(rewriter, loc, cTp.
getPosType(), mplus1, token);
836 Value rowC = e1.getResult(0);
837 token = e1.getAsyncToken();
838 auto e2 = genAllocBuffer(rewriter, loc, cTp.
getCrdType(), zero, token);
839 Value colC = e2.getResult(0);
840 token = e2.getAsyncToken();
841 auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
842 Value valC = e3.getResult(0);
843 token = e3.getAsyncToken();
845 genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
846 zero, rowC, colC, valC, format, enableRT);
852 gpu::SpGEMMCreateDescrOp::create(rewriter, loc, descTp, tokenTp, token);
855 Operation *work1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
856 rewriter, loc, indexTp, tokenTp, token, desc,
857 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
858 spMatA, spMatB, spMatC, dnCType, zero, valC,
859 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
862 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
863 Value buffer1 = buf1.getResult(0);
864 token = buf1.getAsyncToken();
865 Operation *work2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
866 rewriter, loc, indexTp, tokenTp, token, desc,
867 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
868 spMatA, spMatB, spMatC, dnCType, bufferSz1, buffer1,
869 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
873 Operation *compute1 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
874 rewriter, loc, indexTp, tokenTp, token, desc,
875 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
876 spMatA, spMatB, spMatC, dnCType, zero, valC,
877 gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
880 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
881 Value buffer2 = buf2.getResult(0);
882 token = buf2.getAsyncToken();
883 Operation *compute2 = gpu::SpGEMMWorkEstimationOrComputeOp::create(
884 rewriter, loc, indexTp, tokenTp, token, desc,
885 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
886 spMatA, spMatB, spMatC, dnCType, bufferSz2, buffer2,
887 gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
891 Operation *sizes = gpu::SpMatGetSizeOp::create(
892 rewriter, loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
895 auto a2 = genAllocBuffer(rewriter, loc, cTp.
getCrdType(), nnz, token);
896 colC = a2.getResult(0);
897 token = a2.getAsyncToken();
898 auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
899 valC = a3.getResult(0);
900 token = a3.getAsyncToken();
903 Operation *update = gpu::SetCsrPointersOp::create(
904 rewriter, loc, tokenTp, token, spMatC, rowC, colC, valC);
907 rewriter, loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
908 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
909 token =
copy->getResult(0);
914 Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
917 token = gpu::SpGEMMDestroyDescrOp::create(rewriter, loc, tokenTp, token, desc)
919 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
921 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatB)
923 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC)
925 token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
926 token = genCopyMemRef(rewriter, loc, colH, colC, token);
927 token = genCopyMemRef(rewriter, loc, valH, valC, token);
928 token = genDeallocMemRef(rewriter, loc, rowA, token);
929 token = genDeallocMemRef(rewriter, loc, colA, token);
930 token = genDeallocMemRef(rewriter, loc, valA, token);
931 token = genDeallocMemRef(rewriter, loc, rowB, token);
932 token = genDeallocMemRef(rewriter, loc, colB, token);
933 token = genDeallocMemRef(rewriter, loc, valB, token);
934 token = genDeallocMemRef(rewriter, loc, rowC, token);
935 token = genDeallocMemRef(rewriter, loc, colC, token);
936 token = genDeallocMemRef(rewriter, loc, valC, token);
937 token = genDeallocMemRef(rewriter, loc, buffer1, token);
938 token = genDeallocMemRef(rewriter, loc, buffer2, token);
939 tokens.push_back(token);
940 genBlockingWait(rewriter, loc, tokens);
944 Value vt = bufferization::ToTensorOp::create(
946 Value rt = bufferization::ToTensorOp::create(
948 Value ct = bufferization::ToTensorOp::create(
957 linalg::GenericOp op) {
959 Value A = op.getOperand(0);
960 Value B = op.getOperand(1);
961 Value C = op.getOperand(2);
969 auto cnv =
A.getDefiningOp<ConvertOp>();
974 if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
981 Value bufA = genTensorToMemref(rewriter, loc, A);
982 Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
983 Value bufB = genTensorToMemref(rewriter, loc, B);
984 Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
985 Value bufC = genTensorToMemref(rewriter, loc, C);
986 Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
987 genBlockingWait(rewriter, loc, tokens);
998 Value token = genFirstWait(rewriter, loc);
999 Operation *spGenA = gpu::Create2To4SpMatOp::create(
1000 rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk,
1001 gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
1005 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
1007 Value dnB = dmatB.getResult(0);
1008 token = dmatB.getAsyncToken();
1010 gpu::CreateDnTensorOp::create(rewriter, loc, dnTensorHandleTp, tokenTp,
1012 Value dnC = dmatC.getResult(0);
1013 token = dmatC.getAsyncToken();
1014 auto dmatCType = llvm::cast<ShapedType>(matC.
getType()).getElementType();
1019 auto bufferComp = gpu::SpMMBufferSizeOp::create(
1020 rewriter, loc, bufferTypes, tokenTp, token,
1021 gpu::TransposeMode::NON_TRANSPOSE, gpu::TransposeMode::NON_TRANSPOSE,
1024 token = bufferComp.getAsyncToken();
1027 Value bufferSz1 = bufferComp.getResult(0);
1028 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
1029 Value buffer1 = buf1.getResult(0);
1030 token = buf1.getAsyncToken();
1031 Value bufferSz2 = bufferComp.getResult(1);
1032 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
1033 Value buffer2 = buf2.getResult(0);
1034 token = buf2.getAsyncToken();
1035 Value bufferSz3 = bufferComp.getResult(2);
1036 auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
1037 Value buffer3 = buf3.getResult(0);
1038 token = buf3.getAsyncToken();
1041 auto dnCType = llvm::cast<ShapedType>(matC.
getType()).getElementType();
1042 auto spmmComp = gpu::SpMMOp::create(
1043 rewriter, loc, tokenTp, token, spMatA, dnB, dnC, dnCType,
1045 token = spmmComp.getAsyncToken();
1048 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatA)
1050 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
1052 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnC)
1054 token = genDeallocMemRef(rewriter, loc, buffer1, token);
1055 token = genDeallocMemRef(rewriter, loc, buffer2, token);
1056 token = genDeallocMemRef(rewriter, loc, buffer3, token);
1057 token = genDeallocMemRef(rewriter, loc, matA, token);
1058 token = genDeallocMemRef(rewriter, loc, matB, token);
1059 token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1060 token = genDeallocMemRef(rewriter, loc, matC, token);
1061 tokens.push_back(token);
1062 genBlockingWait(rewriter, loc, tokens);
1072 linalg::GenericOp op,
bool enableRT) {
1074 Value a = op.getOperand(0);
1075 Value b = op.getOperand(1);
1076 Value c = op.getOperand(2);
1083 auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT,
false);
1084 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1085 format == CuSparseFormat::kCSC)
1093 Value nseC = NumberOfEntriesOp::create(rewriter, loc, c);
1097 Value bufA = genTensorToMemref(rewriter, loc, a);
1098 Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
1099 Value bufB = genTensorToMemref(rewriter, loc,
b);
1100 Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
1101 Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1102 Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
1103 Value memV = ToValuesOp::create(rewriter, loc, c);
1104 Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1105 Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) :
Value();
1106 Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1107 genBlockingWait(rewriter, loc, tokens);
1115 Value token = genFirstWait(rewriter, loc);
1117 gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp,
1119 Value dnA = dmatA.getResult(0);
1120 token = dmatA.getAsyncToken();
1122 gpu::CreateDnTensorOp::create(rewriter, loc, dnMatHandleTp, tokenTp,
1124 Value dnB = dmatB.getResult(0);
1125 token = dmatB.getAsyncToken();
1127 genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1128 nseC, rowC, colC, valC, format, enableRT);
1131 auto dnCType = llvm::cast<ShapedType>(c.
getType()).getElementType();
1134 auto bufferComp = gpu::SDDMMBufferSizeOp::create(
1135 rewriter, loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1136 Value bufferSz = bufferComp.getResult(0);
1137 token = bufferComp.getAsyncToken();
1138 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1139 Value buffer = buf.getResult(0);
1140 token = buf.getAsyncToken();
1143 auto sddmmComp = gpu::SDDMMOp::create(rewriter, loc, tokenTp, token, dnA, dnB,
1144 spMatC, dnCType, buffer);
1145 token = sddmmComp.getAsyncToken();
1148 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnA)
1150 token = gpu::DestroyDnTensorOp::create(rewriter, loc, tokenTp, token, dnB)
1152 token = gpu::DestroySpMatOp::create(rewriter, loc, tokenTp, token, spMatC)
1154 token = genDeallocMemRef(rewriter, loc, buffer, token);
1155 token = genDeallocMemRef(rewriter, loc, matA, token);
1156 token = genDeallocMemRef(rewriter, loc, matB, token);
1157 token = genDeallocMemRef(rewriter, loc, rowC, token);
1159 token = genDeallocMemRef(rewriter, loc, colC, token);
1160 token = genCopyMemRef(rewriter, loc, memV, valC, token);
1161 token = genDeallocMemRef(rewriter, loc, valC, token);
1162 tokens.push_back(token);
1163 genBlockingWait(rewriter, loc, tokens);
1180 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1182 ForallRewriter(MLIRContext *context,
unsigned nT)
1183 : OpRewritePattern(context), numThreads(nT) {};
1185 LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1186 PatternRewriter &rewriter)
const override {
1193 forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1199 forallOp->walk([&](Operation *op) {
1202 Value val = o.get();
1204 if (
auto arg = dyn_cast<BlockArgument>(val))
1205 block = arg.getOwner();
1208 if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
1209 invariants.insert(val);
1214 SmallVector<Value> constants;
1215 SmallVector<Value> scalars;
1216 SmallVector<Value> buffers;
1220 SmallVector<bool> copyBack;
1221 for (Value val : invariants) {
1224 constants.push_back(val);
1226 scalars.push_back(val);
1227 else if (isa<MemRefType>(tp)) {
1228 buffers.push_back(val);
1232 bool isWrite =
false;
1233 for (Operation *user : val.
getUsers()) {
1234 if (isa<memref::StoreOp>(user)) {
1238 if (
auto memInterface = dyn_cast<MemoryEffectOpInterface>(user)) {
1239 if (memInterface.getEffectOnValue<MemoryEffects::Write>(val)) {
1245 copyBack.push_back(isWrite);
1253 Location loc = forallOp->getLoc();
1254 SmallVector<Value> args;
1255 SmallVector<Value> tokens;
1256 Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1260 ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1261 auto gpuModule = genGPUModule(rewriter, topModule);
1262 auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1263 genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1269 genBlockingWait(rewriter, loc, tokens);
1272 genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1274 genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1276 genBlockingWait(rewriter, loc, tokens);
1282 unsigned numThreads;
1292 using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1294 LinalgOpRewriter(MLIRContext *context,
bool rt)
1295 : OpRewritePattern(context), enableRT(rt) {}
1297 LogicalResult matchAndRewrite(linalg::GenericOp op,
1298 PatternRewriter &rewriter)
const override {
1299 if (op.getNumDpsInits() != 1)
1302 const unsigned numLoops = op.getNumLoops();
1303 const unsigned numTensors = op->getNumOperands();
1304 const auto iteratorTypes = op.getIteratorTypesArray();
1305 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1307 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1308 auto infer = [&](MapList m) {
1318 if (numLoops == 2 && numTensors == 3 &&
1321 maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1322 return rewriteSpMV(rewriter, op, enableRT);
1326 if (numLoops == 3 && numTensors == 3 &&
1330 maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1331 if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1332 return rewriteSpGEMM(rewriter, op, enableRT);
1333 if (isConversionInto24(op.getOperand(0)))
1334 return rewrite2To4SpMM(rewriter, op);
1335 return rewriteSpMM(rewriter, op, enableRT);
1339 if (numLoops == 3 && numTensors == 3 &&
1343 maps == infer({{i, k}, {k, j}, {i, j}}) &&
1344 matchSumReductionOfMulUnary(op)) {
1345 return rewriteSDDMM(rewriter, op, enableRT);
1368 unsigned numThreads) {
1369 patterns.
add<ForallRewriter>(patterns.
getContext(), numThreads);
1374 patterns.
add<LinalgOpRewriter>(patterns.
getContext(), enableRT);
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static constexpr llvm::StringLiteral getLoopEmitterLoopAttrName()
A wrapper around RankedTensorType, which has three goals:
bool isSingletonLvl(Level l) const
Type getElementType() const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
bool isNOutOfMLvl(Level l) const
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool isCompressedLvl(Level l) const
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
bool isPermutation() const
Returns true if the dimToLvl mapping is a permutation.
bool isDenseLvl(Level l) const
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isOrderedLvl(Level l) const
bool isUniqueLvl(Level l) const
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, bool enableRT)
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, unsigned numThreads)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Utility class for the GPU dialect to represent triples of Values accessible through ....