33 static constexpr uint64_t
loIdx = 0;
34 static constexpr uint64_t
hiIdx = 1;
39 "_sparse_binary_search_";
41 "_sparse_hybrid_qsort_";
43 "_sparse_sort_stable_";
57 nameOstream << namePrefix;
59 nameOstream << cast<AffineDimExpr>(res).getPosition() <<
"_";
62 nameOstream <<
"_coo_" << ny;
64 constexpr uint64_t yBufferOffset = 1;
83 llvm::raw_svector_ostream nameOstream(nameBuffer);
85 operands.drop_back(nTrailingP));
87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
97 func = func::FuncOp::create(
98 builder, loc, nameOstream.str(),
101 createFunc(builder, module, func, xPerm, ny, nTrailingP);
114 Value iOffset = arith::MulIOp::create(builder, loc, args[0], cstep);
115 Value jOffset = arith::MulIOp::create(builder, loc, args[1], cstep);
116 for (
unsigned k = 0, e = xPerm.
getNumResults(); k < e; k++) {
117 unsigned actualK = cast<AffineDimExpr>(xPerm.
getResult(k)).getPosition();
119 Value i = arith::AddIOp::create(builder, loc, ak, iOffset);
120 Value j = arith::AddIOp::create(builder, loc, ak, jOffset);
123 bodyBuilder(k, i,
j, buffer);
136 for (
unsigned y = 0; y < ny; y++) {
144 constexpr uint64_t numHandledBuffers = 1;
148 for (
const auto &arg :
150 bodyBuilder(arg.index() + xPerm.
getNumResults() + ny, i,
j, arg.value());
168 Value vi = memref::LoadOp::create(builder, loc, buffer, i);
169 Value vj = memref::LoadOp::create(builder, loc, buffer,
j);
170 memref::StoreOp::create(builder, loc, vj, buffer, i);
171 memref::StoreOp::create(builder, loc, vi, buffer,
j);
186 bool isFirstDim = (k == 0);
189 compareBuilder(builder, loc, i,
j, buffer, isFirstDim, isLastDim);
192 }
else if (!isLastDim) {
196 scf::YieldOp::create(builder, loc, ifOp.getResult(0));
209 Value x,
bool isFirstDim,
bool isLastDim) {
210 Value vi = memref::LoadOp::create(builder, loc, x, i);
211 Value vj = memref::LoadOp::create(builder, loc, x,
j);
215 res = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, vi, vj);
219 scf::YieldOp::create(builder, loc, res);
222 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj);
223 scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.
getIntegerType(1),
228 scf::YieldOp::create(builder, loc, f);
233 res = ifOp.getResult(0);
251 uint64_t ny, uint32_t nTrailingP = 0) {
254 assert(nTrailingP == 0);
264 Value vi = memref::LoadOp::create(builder, loc, x, i);
265 Value vj = memref::LoadOp::create(builder, loc, x,
j);
270 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj);
274 scf::YieldOp::create(builder, loc, res);
277 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj);
278 scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.
getIntegerType(1),
283 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj);
284 scf::YieldOp::create(builder, loc, lt);
289 res = ifOp.getResult(0);
306 uint64_t ny, uint32_t nTrailingP = 0) {
309 assert(nTrailingP == 0);
329 uint64_t ny, uint32_t nTrailingP = 0) {
332 assert(nTrailingP == 0);
334 Block *entryBlock = func.addEntryBlock();
341 scf::WhileOp whileOp = scf::WhileOp::create(
346 builder.
createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
349 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
350 before->getArgument(0), before->getArgument(1));
351 scf::ConditionOp::create(builder, loc, cond1, before->getArguments());
355 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
361 Value mid = arith::ShRUIOp::create(
362 builder, loc, arith::AddIOp::create(builder, loc, lo, hi), c1);
363 Value midp1 = arith::AddIOp::create(builder, loc, mid, c1);
367 constexpr uint64_t numXBuffers = 1;
368 compareOperands.append(args.begin() +
xStartIdx,
376 Value newLo = arith::SelectOp::create(builder, loc, cond2, lo, midp1);
377 Value newHi = arith::SelectOp::create(builder, loc, cond2, mid, hi);
378 scf::YieldOp::create(builder, loc,
ValueRange{newLo, newHi});
381 func::ReturnOp::create(builder, loc, whileOp.getResult(0));
394 uint64_t ny,
int step) {
396 scf::WhileOp whileOp =
400 builder.
createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
404 compareOperands.push_back(before->getArgument(0));
405 compareOperands.push_back(p);
408 compareOperands.push_back(p);
409 compareOperands.push_back(before->getArgument(0));
411 compareOperands.append(xs.begin(), xs.end());
413 scf::ConditionOp::create(builder, loc, cond, before->getArguments());
416 builder.
createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
419 i = arith::AddIOp::create(builder, loc, after->
getArgument(0), cs);
420 scf::YieldOp::create(builder, loc,
ValueRange{i});
421 i = whileOp.getResult(0);
424 compareOperands[0] = i;
425 compareOperands[1] = p;
429 return std::make_pair(whileOp.getResult(0), compareEq);
441 compareOperands[0] = b;
442 compareOperands[1] = a;
444 scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond,
false);
448 createSwap(builder, loc, swapOperands, xPerm, ny);
458 compareOperands, v1, v2);
471 compareOperands, v0, v1);
475 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
485 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
488 auto insert4th = [&]() {
490 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
491 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
501 compareOperands, v3, v4);
511 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
514 constexpr uint64_t numXBuffers = 1;
515 compareOperands.append(args.begin() +
xStartIdx,
518 swapOperands.append(args.begin() +
xStartIdx, args.end());
521 Value hiP1 = arith::AddIOp::create(builder, loc, hi, c1);
522 Value len = arith::SubIOp::create(builder, loc, hiP1, lo);
524 Value lenCond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
526 scf::IfOp lenIf = scf::IfOp::create(builder, loc, lenCond,
true);
530 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
535 Value miP1 = arith::AddIOp::create(builder, loc, hi, c1);
536 Value a = arith::AddIOp::create(builder, loc, lo, miP1);
538 a = arith::ShRUIOp::create(builder, loc, a, c1);
539 Value b = arith::AddIOp::create(builder, loc, mi, hiP1);
541 b = arith::ShRUIOp::create(builder, loc, b, c1);
542 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
579 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
580 uint32_t nTrailingP = 0) {
583 assert(nTrailingP == 0);
586 Block *entryBlock = func.addEntryBlock();
593 Value sum = arith::AddIOp::create(builder, loc, lo, hi);
595 Value p = arith::ShRUIOp::create(builder, loc, sum, c1);
598 Value j = arith::SubIOp::create(builder, loc, hi, c1);
604 scf::WhileOp whileOp = scf::WhileOp::create(builder, loc, types, operands);
608 {loc, loc, loc, loc});
610 scf::ConditionOp::create(builder, loc, before->
getArgument(3),
615 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
621 constexpr uint64_t numXBuffers = 1;
622 auto [iresult, iCompareEq] =
626 auto [jresult, jCompareEq] =
628 j, p, xPerm, ny, -1);
633 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, i,
j);
634 scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond,
true);
637 swapOperands.append(args.begin() +
xStartIdx, args.end());
638 createSwap(builder, loc, swapOperands, xPerm, ny);
641 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, i, p);
648 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
j, p);
652 scf::YieldOp::create(builder, loc,
ValueRange{i});
654 scf::YieldOp::create(builder, loc,
ValueRange{p});
656 scf::YieldOp::create(builder, loc, ifOpJ.getResults());
659 arith::AndIOp::create(builder, loc, iCompareEq, jCompareEq);
664 Value i2 = arith::AddIOp::create(builder, loc, i, c1);
665 Value j2 = arith::SubIOp::create(builder, loc,
j, c1);
666 scf::YieldOp::create(builder, loc,
ValueRange{i2, j2});
668 scf::YieldOp::create(builder, loc,
ValueRange{i,
j});
670 scf::YieldOp::create(builder, loc,
671 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1),
677 p = arith::AddIOp::create(builder, loc,
j,
679 scf::YieldOp::create(
685 scf::YieldOp::create(builder, loc, ifOp.getResults());
689 func::ReturnOp::create(builder, loc, whileOp.getResult(2));
696 Value res = arith::SubIOp::create(builder, loc, n, i2);
698 return arith::ShRUIOp::create(builder, loc, res, i1);
732 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
733 uint32_t nTrailingP) {
735 assert(nTrailingP == 1);
737 Block *entryBlock = func.addEntryBlock();
749 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, n, c2);
750 scf::IfOp ifN = scf::IfOp::create(builder, loc, condN,
false);
752 Value child = arith::SubIOp::create(builder, loc, start, first);
757 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, t, child);
758 scf::IfOp ifNc = scf::IfOp::create(builder, loc, condNc,
false);
763 constexpr uint64_t numXBuffers = 1;
764 compareOperands.append(args.begin() +
xStartIdx,
773 auto getLargerChild = [&](
Value r) -> std::pair<Value, Value> {
774 Value lChild = arith::ShLIOp::create(builder, loc, r, c1);
775 lChild = arith::AddIOp::create(builder, loc, lChild, c1);
776 Value lChildIdx = arith::AddIOp::create(builder, loc, lChild, first);
777 Value rChild = arith::AddIOp::create(builder, loc, lChild, c1);
778 Value cond1 = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
782 scf::IfOp::create(builder, loc, ifTypes, cond1,
true);
784 Value rChildIdx = arith::AddIOp::create(builder, loc, rChild, first);
786 compareOperands[0] = lChildIdx;
787 compareOperands[1] = rChildIdx;
791 scf::IfOp::create(builder, loc, ifTypes, cond2,
true);
793 scf::YieldOp::create(builder, loc,
ValueRange{rChild, rChildIdx});
795 scf::YieldOp::create(builder, loc,
ValueRange{lChild, lChildIdx});
797 scf::YieldOp::create(builder, loc, if2.getResults());
799 scf::YieldOp::create(builder, loc,
ValueRange{lChild, lChildIdx});
801 return std::make_pair(if1.getResult(0), if1.getResult(1));
805 std::tie(child, childIdx) = getLargerChild(child);
809 scf::WhileOp whileOp = scf::WhileOp::create(
818 compareOperands[0] = start;
819 compareOperands[1] = childIdx;
821 scf::ConditionOp::create(builder, loc, cond, before->
getArguments());
829 swapOperands.append(args.begin() +
xStartIdx, args.end());
830 createSwap(builder, loc, swapOperands, xPerm, ny);
833 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, t, child);
834 scf::IfOp if2 = scf::IfOp::create(builder, loc,
838 auto [newChild, newChildIdx] = getLargerChild(child);
839 scf::YieldOp::create(builder, loc,
ValueRange{newChild, newChildIdx});
841 scf::YieldOp::create(builder, loc,
ValueRange{child, childIdx});
843 scf::YieldOp::create(builder, loc,
844 ValueRange{start, if2.getResult(0), if2.getResult(1)});
847 func::ReturnOp::create(builder, loc);
864 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
865 uint32_t nTrailingP) {
868 assert(nTrailingP == 0);
870 Block *entryBlock = func.addEntryBlock();
877 Value n = arith::SubIOp::create(builder, loc, hi, lo);
883 Value up = arith::AddIOp::create(builder, loc, s, c1);
884 scf::ForOp forI = scf::ForOp::create(builder, loc, c0, up, c1);
886 Value i = arith::SubIOp::create(builder, loc, s, forI.getInductionVar());
887 Value lopi = arith::AddIOp::create(builder, loc, lo, i);
889 shiftDownOperands.append(args.begin() +
xStartIdx, args.end());
890 shiftDownOperands.push_back(n);
894 func::CallOp::create(builder, loc, shiftDownFunc,
TypeRange(),
899 up = arith::SubIOp::create(builder, loc, n, c1);
900 scf::ForOp forL = scf::ForOp::create(builder, loc, c0, up, c1);
902 Value l = arith::SubIOp::create(builder, loc, n, forL.getInductionVar());
903 Value loplm1 = arith::AddIOp::create(builder, loc, lo, l);
904 loplm1 = arith::SubIOp::create(builder, loc, loplm1, c1);
906 swapOperands.append(args.begin() +
xStartIdx, args.end());
907 createSwap(builder, loc, swapOperands, xPerm, ny);
908 shiftDownOperands[1] = lo;
909 shiftDownOperands[shiftDownOperands.size() - 1] =
910 arith::SubIOp::create(builder, loc, l, c1);
911 func::CallOp::create(builder, loc, shiftDownFunc,
TypeRange(),
915 func::ReturnOp::create(builder, loc);
921 static std::pair<Value, Value>
924 uint32_t nTrailingP) {
934 Value p = func::CallOp::create(builder, loc, partitionFunc,
936 args.drop_back(nTrailingP))
939 Value lenLow = arith::SubIOp::create(builder, loc, p, lo);
940 Value lenHigh = arith::SubIOp::create(builder, loc, hi, p);
943 Value len = arith::SubIOp::create(builder, loc, hi, lo);
945 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ugt, len, c2);
946 scf::IfOp ifLenGtTwo =
947 scf::IfOp::create(builder, loc, types, lenGtTwo,
true);
950 scf::YieldOp::create(builder, loc,
ValueRange{lo, lo});
954 Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule,
958 scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond,
true);
962 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, len, c0);
963 scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond,
false);
966 operands.append(args.begin() +
xStartIdx, args.end());
967 func::CallOp::create(builder, loc, func, operands);
974 mayRecursion(lo, p, lenLow);
975 scf::YieldOp::create(builder, loc,
ValueRange{p, hi});
978 mayRecursion(p, hi, lenHigh);
979 scf::YieldOp::create(builder, loc,
ValueRange{lo, p});
982 scf::YieldOp::create(builder, loc, ifOp.getResults());
985 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
1003 uint64_t ny, uint32_t nTrailingP) {
1006 assert(nTrailingP == 0);
1008 Block *entryBlock = func.addEntryBlock();
1017 Value lop1 = arith::AddIOp::create(builder, loc, lo, c1);
1020 scf::ForOp forOpI = scf::ForOp::create(builder, loc, lop1, hi, c1);
1022 Value i = forOpI.getInductionVar();
1026 operands.append(args.begin() +
xStartIdx, args.end());
1030 Value p = func::CallOp::create(builder, loc, searchFunc,
1035 operands[0] = operands[1] = i;
1038 builder, loc, operands, xPerm, ny,
1040 d.push_back(memref::LoadOp::create(builder, loc, buffer, i));
1045 Value imp = arith::SubIOp::create(builder, loc, i, p);
1047 scf::ForOp forOpJ = scf::ForOp::create(builder, loc, c0, imp, c1);
1049 Value j = forOpJ.getInductionVar();
1050 Value imj = arith::SubIOp::create(builder, loc, i,
j);
1052 operands[0] = arith::SubIOp::create(builder, loc, imj, c1);
1054 builder, loc, operands, xPerm, ny,
1056 Value t = memref::LoadOp::create(builder, loc, buffer, imjm1);
1057 memref::StoreOp::create(builder, loc, t, buffer, imj);
1062 operands[0] = operands[1] = p;
1064 builder, loc, operands, xPerm, ny,
1066 memref::StoreOp::create(builder, loc, d[k], buffer, p);
1070 func::ReturnOp::create(builder, loc);
1116 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
1117 uint32_t nTrailingP) {
1118 assert(nTrailingP == 1 || nTrailingP == 0);
1119 bool isHybrid = (nTrailingP == 1);
1121 Block *entryBlock = func.addEntryBlock();
1131 scf::WhileOp whileOp =
1136 builder.
createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1138 lo = before->getArgument(0);
1139 hi = before->getArgument(1);
1141 arith::AddIOp::create(builder, loc, lo,
constantIndex(builder, loc, 1));
1143 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, loP1, hi);
1144 scf::ConditionOp::create(builder, loc, needSort, before->getArguments());
1148 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1156 Value len = arith::SubIOp::create(builder, loc, hi, lo);
1158 Value lenCond = arith::CmpIOp::create(
1159 builder, loc, arith::CmpIPredicate::ule, len, lenLimit);
1161 scf::IfOp::create(builder, loc, types, lenCond,
true);
1168 func::CallOp::create(builder, loc, insertionSortFunc,
TypeRange(),
1170 scf::YieldOp::create(builder, loc,
ValueRange{lo, lo});
1174 Value depthLimit = args.back();
1175 depthLimit = arith::SubIOp::create(builder, loc, depthLimit,
1178 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule,
1181 scf::IfOp::create(builder, loc, types, depthCond,
true);
1188 func::CallOp::create(builder, loc, heapSortFunc,
TypeRange(),
1190 scf::YieldOp::create(builder, loc,
ValueRange{lo, lo});
1194 args.back() = depthLimit;
1197 scf::YieldOp::create(builder, loc,
ValueRange{lo, hi});
1200 lo = depthIf.getResult(0);
1201 hi = depthIf.getResult(1);
1202 scf::YieldOp::create(builder, loc,
ValueRange{lo, hi});
1205 lo = lenIf.getResult(0);
1206 hi = lenIf.getResult(1);
1213 scf::YieldOp::create(builder, loc,
ValueRange{lo, hi});
1217 func::ReturnOp::create(builder, loc);
1221 template <
typename OpTy>
1229 for (
Value v : xys) {
1231 if (!mtp.isDynamicDim(0)) {
1234 v = memref::CastOp::create(rewriter, loc, newMtp, v);
1236 operands.push_back(v);
1239 auto insertPoint = op->template getParentOfType<func::FuncOp>();
1245 uint32_t nTrailingP = 0;
1246 switch (op.getAlgorithm()) {
1247 case SparseTensorSortKind::HybridQuickSort: {
1254 Value len = arith::IndexCastOp::create(
1256 arith::SubIOp::create(rewriter, loc, hi, lo));
1257 Value depthLimit = arith::SubIOp::create(
1259 math::CountLeadingZerosOp::create(rewriter, loc, len));
1260 operands.push_back(depthLimit);
1263 case SparseTensorSortKind::QuickSort:
1267 case SparseTensorSortKind::InsertionSortStable:
1271 case SparseTensorSortKind::HeapSort:
1279 xPerm, ny, operands, funcGenerator, nTrailingP);
1293 PushBackRewriter(
MLIRContext *context,
bool enableInit)
1295 LogicalResult matchAndRewrite(PushBackOp op,
1312 Value buffer = op.getInBuffer();
1313 Value capacity = memref::DimOp::create(rewriter, loc, buffer, c0);
1314 Value size = op.getCurSize();
1315 Value value = op.getValue();
1318 Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
1320 bool nIsOne = (nValue && nValue.value() == 1);
1322 if (!op.getInbounds()) {
1323 Value cond = arith::CmpIOp::create(
1324 rewriter, loc, arith::CmpIPredicate::ugt, newSize, capacity);
1329 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, bufferType, cond,
1334 capacity = arith::MulIOp::create(rewriter, loc, capacity, c2);
1338 scf::WhileOp whileOp =
1339 scf::WhileOp::create(rewriter, loc, capacity.
getType(), capacity);
1343 {capacity.getType()}, {loc});
1347 arith::MulIOp::create(rewriter, loc, before->
getArgument(0), c2);
1348 cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ugt,
1350 scf::ConditionOp::create(rewriter, loc, cond,
ValueRange{capacity});
1353 {capacity.getType()}, {loc});
1355 scf::YieldOp::create(rewriter, loc, after->getArguments());
1358 capacity = whileOp.getResult(0);
1361 Value newBuffer = memref::ReallocOp::create(rewriter, loc, bufferType,
1363 if (enableBufferInitialization) {
1365 arith::SubIOp::create(rewriter, loc, capacity, newSize);
1367 Value subBuffer = memref::SubViewOp::create(
1368 rewriter, loc, newBuffer,
ValueRange{newSize},
1371 linalg::FillOp::create(rewriter, loc, fillValue, subBuffer);
1373 scf::YieldOp::create(rewriter, loc, newBuffer);
1377 scf::YieldOp::create(rewriter, loc, buffer);
1381 buffer = ifOp.getResult(0);
1386 memref::StoreOp::create(rewriter, loc, value, buffer, size);
1388 Value subBuffer = memref::SubViewOp::create(
1392 linalg::FillOp::create(rewriter, loc, value, subBuffer);
1396 rewriter.
replaceOp(op, {buffer, newSize});
1401 bool enableBufferInitialization;
1409 LogicalResult matchAndRewrite(SortOp op,
1412 xys.push_back(op.getXy());
1413 xys.append(op.getYs().begin(), op.getYs().end());
1415 auto xPerm = op.getPermMap();
1417 if (
auto nyAttr = op.getNyAttr())
1418 ny = nyAttr.getInt();
1431 bool enableBufferInitialization) {
1433 enableBufferInitialization);
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumpt...
static void forEachIJPairInXs(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static Value createInlinedCompareImplementation(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
Creates code to compare all the (xs[i], xs[j]) pairs.
static constexpr const char kQuickSortFuncNamePrefix[]
static constexpr uint64_t hiIdx
static constexpr const char kHeapSortFuncNamePrefix[]
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static constexpr const char kHybridQuickSortFuncNamePrefix[]
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to insert the 3rd element to a list of two sorted elements.
static constexpr const char kSortStableFuncNamePrefix[]
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
Looks up a function that is appropriate for the given operands being sorted, and creates such a funct...
static constexpr uint64_t loIdx
static void forEachIJPairInAllBuffers(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is less than xs[j].
static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
static std::pair< Value, Value > createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
A helper for generating code to perform quick sort.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to perform quick sort partition on the values in the range of index [lo,...
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
Creates code to sort 5 elements.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n)
Computes (n-2)/2, assuming n has index type.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is equal to xs[j].
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y...
static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo],...
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [...
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to sort 3 elements.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sor...
static constexpr const char kBinarySearchFuncNamePrefix[]
static constexpr const char kPartitionFuncNamePrefix[]
static constexpr uint64_t xStartIdx
static std::pair< Value, Value > createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step ...
static constexpr const char kShiftDownFuncNamePrefix[]
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to heapify the subtree with root start within the full binary tree in the range of...
static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform insertion sort on the values in the range of index [lo,...
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b],...
static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
Creates a code block for swapping the values in index i and j for all the buffers.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
IntegerType getIntegerType(unsigned width)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Value constantI64(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of i64 type.
Include the generated interface declarations.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.