33 static constexpr uint64_t
loIdx = 0;
34 static constexpr uint64_t
hiIdx = 1;
39 "_sparse_binary_search_";
41 "_sparse_hybrid_qsort_";
43 "_sparse_sort_stable_";
57 nameOstream << namePrefix;
59 nameOstream << cast<AffineDimExpr>(res).getPosition() <<
"_";
62 nameOstream <<
"_coo_" << ny;
64 constexpr uint64_t yBufferOffset = 1;
83 llvm::raw_svector_ostream nameOstream(nameBuffer);
85 operands.drop_back(nTrailingP));
87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
97 func = builder.
create<func::FuncOp>(
98 loc, nameOstream.str(),
101 createFunc(builder, module, func, xPerm, ny, nTrailingP);
114 Value iOffset = builder.
create<arith::MulIOp>(loc, args[0], cstep);
115 Value jOffset = builder.
create<arith::MulIOp>(loc, args[1], cstep);
116 for (
unsigned k = 0, e = xPerm.
getNumResults(); k < e; k++) {
117 unsigned actualK = cast<AffineDimExpr>(xPerm.
getResult(k)).getPosition();
119 Value i = builder.
create<arith::AddIOp>(loc, ak, iOffset);
120 Value j = builder.
create<arith::AddIOp>(loc, ak, jOffset);
123 bodyBuilder(k, i,
j, buffer);
136 for (
unsigned y = 0; y < ny; y++) {
144 constexpr uint64_t numHandledBuffers = 1;
148 for (
const auto &arg :
150 bodyBuilder(arg.index() + xPerm.
getNumResults() + ny, i,
j, arg.value());
168 Value vi = builder.
create<memref::LoadOp>(loc, buffer, i);
169 Value vj = builder.
create<memref::LoadOp>(loc, buffer,
j);
170 builder.
create<memref::StoreOp>(loc, vj, buffer, i);
171 builder.
create<memref::StoreOp>(loc, vi, buffer,
j);
186 bool isFirstDim = (k == 0);
189 compareBuilder(builder, loc, i,
j, buffer, isFirstDim, isLastDim);
192 }
else if (!isLastDim) {
196 builder.
create<scf::YieldOp>(loc, ifOp.getResult(0));
209 Value x,
bool isFirstDim,
bool isLastDim) {
210 Value vi = builder.
create<memref::LoadOp>(loc, x, i);
215 res = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
219 builder.
create<scf::YieldOp>(loc, res);
222 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
228 builder.
create<scf::YieldOp>(loc, f);
233 res = ifOp.getResult(0);
251 uint64_t ny, uint32_t nTrailingP = 0) {
254 assert(nTrailingP == 0);
264 Value vi = builder.
create<memref::LoadOp>(loc, x, i);
269 res = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
273 builder.
create<scf::YieldOp>(loc, res);
276 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
282 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
283 builder.
create<scf::YieldOp>(loc, lt);
288 res = ifOp.getResult(0);
305 uint64_t ny, uint32_t nTrailingP = 0) {
308 assert(nTrailingP == 0);
328 uint64_t ny, uint32_t nTrailingP = 0) {
331 assert(nTrailingP == 0);
333 Block *entryBlock = func.addEntryBlock();
340 scf::WhileOp whileOp = builder.
create<scf::WhileOp>(
345 builder.
createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
347 Value cond1 = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
348 before->getArgument(0),
349 before->getArgument(1));
350 builder.
create<scf::ConditionOp>(loc, cond1, before->getArguments());
354 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
356 Value lo = after->getArgument(0);
357 Value hi = after->getArgument(1);
361 loc, builder.
create<arith::AddIOp>(loc, lo, hi), c1);
362 Value midp1 = builder.
create<arith::AddIOp>(loc, mid, c1);
366 constexpr uint64_t numXBuffers = 1;
367 compareOperands.append(args.begin() +
xStartIdx,
375 Value newLo = builder.
create<arith::SelectOp>(loc, cond2, lo, midp1);
376 Value newHi = builder.
create<arith::SelectOp>(loc, cond2, mid, hi);
380 builder.
create<func::ReturnOp>(loc, whileOp.getResult(0));
393 uint64_t ny,
int step) {
395 scf::WhileOp whileOp =
403 compareOperands.push_back(before->getArgument(0));
404 compareOperands.push_back(p);
407 compareOperands.push_back(p);
408 compareOperands.push_back(before->getArgument(0));
410 compareOperands.append(xs.begin(), xs.end());
412 builder.
create<scf::ConditionOp>(loc, cond, before->getArguments());
418 i = builder.
create<arith::AddIOp>(loc, after->getArgument(0), cs);
420 i = whileOp.getResult(0);
423 compareOperands[0] = i;
424 compareOperands[1] = p;
428 return std::make_pair(whileOp.getResult(0), compareEq);
440 compareOperands[0] = b;
441 compareOperands[1] = a;
443 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, cond,
false);
447 createSwap(builder, loc, swapOperands, xPerm, ny);
457 compareOperands, v1, v2);
470 compareOperands, v0, v1);
474 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
484 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
487 auto insert4th = [&]() {
489 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
490 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
500 compareOperands, v3, v4);
510 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
513 constexpr uint64_t numXBuffers = 1;
514 compareOperands.append(args.begin() +
xStartIdx,
517 swapOperands.append(args.begin() +
xStartIdx, args.end());
520 Value hiP1 = builder.
create<arith::AddIOp>(loc, hi, c1);
521 Value len = builder.
create<arith::SubIOp>(loc, hiP1, lo);
523 Value lenCond = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
525 scf::IfOp lenIf = builder.
create<scf::IfOp>(loc, lenCond,
true);
529 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
534 Value miP1 = builder.
create<arith::AddIOp>(loc, hi, c1);
535 Value a = builder.
create<arith::AddIOp>(loc, lo, miP1);
537 a = builder.
create<arith::ShRUIOp>(loc, a, c1);
538 Value b = builder.
create<arith::AddIOp>(loc, mi, hiP1);
540 b = builder.
create<arith::ShRUIOp>(loc, b, c1);
541 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
578 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
579 uint32_t nTrailingP = 0) {
582 assert(nTrailingP == 0);
585 Block *entryBlock = func.addEntryBlock();
592 Value sum = builder.
create<arith::AddIOp>(loc, lo, hi);
594 Value p = builder.
create<arith::ShRUIOp>(loc, sum, c1);
603 scf::WhileOp whileOp = builder.
create<scf::WhileOp>(loc, types, operands);
607 {loc, loc, loc, loc});
609 builder.
create<scf::ConditionOp>(loc, before->getArgument(3),
610 before->getArguments());
614 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
616 i = after->getArgument(0);
617 j = after->getArgument(1);
618 p = after->getArgument(2);
620 constexpr uint64_t numXBuffers = 1;
621 auto [iresult, iCompareEq] =
625 auto [jresult, jCompareEq] =
627 j, p, xPerm, ny, -1);
632 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i,
j);
633 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, types, cond,
true);
636 swapOperands.append(args.begin() +
xStartIdx, args.end());
637 createSwap(builder, loc, swapOperands, xPerm, ny);
640 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
647 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
j, p);
655 builder.
create<scf::YieldOp>(loc, ifOpJ.getResults());
658 builder.
create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
659 scf::IfOp ifOp2 = builder.
create<scf::IfOp>(
660 loc,
TypeRange{i.getType(),
j.getType()}, compareEqIJ,
true);
662 Value i2 = builder.
create<arith::AddIOp>(loc, i, c1);
668 builder.
create<scf::YieldOp>(
670 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
675 p = builder.
create<arith::AddIOp>(loc,
j,
677 builder.
create<scf::YieldOp>(
682 builder.
create<scf::YieldOp>(loc, ifOp.getResults());
686 builder.
create<func::ReturnOp>(loc, whileOp.getResult(2));
693 Value res = builder.
create<arith::SubIOp>(loc, n, i2);
695 return builder.
create<arith::ShRUIOp>(loc, res, i1);
729 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
730 uint32_t nTrailingP) {
732 assert(nTrailingP == 1);
734 Block *entryBlock = func.addEntryBlock();
746 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
747 scf::IfOp ifN = builder.
create<scf::IfOp>(loc, condN,
false);
749 Value child = builder.
create<arith::SubIOp>(loc, start, first);
754 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
755 scf::IfOp ifNc = builder.
create<scf::IfOp>(loc, condNc,
false);
760 constexpr uint64_t numXBuffers = 1;
761 compareOperands.append(args.begin() +
xStartIdx,
770 auto getLargerChild = [&](
Value r) -> std::pair<Value, Value> {
771 Value lChild = builder.
create<arith::ShLIOp>(loc, r, c1);
772 lChild = builder.
create<arith::AddIOp>(loc, lChild, c1);
773 Value lChildIdx = builder.
create<arith::AddIOp>(loc, lChild, first);
774 Value rChild = builder.
create<arith::AddIOp>(loc, lChild, c1);
775 Value cond1 = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
779 builder.
create<scf::IfOp>(loc, ifTypes, cond1,
true);
781 Value rChildIdx = builder.
create<arith::AddIOp>(loc, rChild, first);
783 compareOperands[0] = lChildIdx;
784 compareOperands[1] = rChildIdx;
788 builder.
create<scf::IfOp>(loc, ifTypes, cond2,
true);
794 builder.
create<scf::YieldOp>(loc, if2.getResults());
798 return std::make_pair(if1.getResult(0), if1.getResult(1));
802 std::tie(child, childIdx) = getLargerChild(child);
806 scf::WhileOp whileOp = builder.
create<scf::WhileOp>(
815 compareOperands[0] = start;
816 compareOperands[1] = childIdx;
822 start = after->getArgument(0);
823 child = after->getArgument(1);
824 childIdx = after->getArgument(2);
826 swapOperands.append(args.begin() +
xStartIdx, args.end());
827 createSwap(builder, loc, swapOperands, xPerm, ny);
830 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
831 scf::IfOp if2 = builder.
create<scf::IfOp>(
834 auto [newChild, newChildIdx] = getLargerChild(child);
839 builder.
create<scf::YieldOp>(
840 loc,
ValueRange{start, if2.getResult(0), if2.getResult(1)});
843 builder.
create<func::ReturnOp>(loc);
860 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
861 uint32_t nTrailingP) {
864 assert(nTrailingP == 0);
866 Block *entryBlock = func.addEntryBlock();
873 Value n = builder.
create<arith::SubIOp>(loc, hi, lo);
879 Value up = builder.
create<arith::AddIOp>(loc, s, c1);
880 scf::ForOp forI = builder.
create<scf::ForOp>(loc, c0, up, c1);
882 Value i = builder.
create<arith::SubIOp>(loc, s, forI.getInductionVar());
883 Value lopi = builder.
create<arith::AddIOp>(loc, lo, i);
885 shiftDownOperands.append(args.begin() +
xStartIdx, args.end());
886 shiftDownOperands.push_back(n);
895 up = builder.
create<arith::SubIOp>(loc, n, c1);
896 scf::ForOp forL = builder.
create<scf::ForOp>(loc, c0, up, c1);
898 Value l = builder.
create<arith::SubIOp>(loc, n, forL.getInductionVar());
899 Value loplm1 = builder.
create<arith::AddIOp>(loc, lo, l);
900 loplm1 = builder.
create<arith::SubIOp>(loc, loplm1, c1);
902 swapOperands.append(args.begin() +
xStartIdx, args.end());
903 createSwap(builder, loc, swapOperands, xPerm, ny);
904 shiftDownOperands[1] = lo;
905 shiftDownOperands[shiftDownOperands.size() - 1] =
906 builder.
create<arith::SubIOp>(loc, l, c1);
911 builder.
create<func::ReturnOp>(loc);
917 static std::pair<Value, Value>
920 uint32_t nTrailingP) {
931 .
create<func::CallOp>(loc, partitionFunc,
933 args.drop_back(nTrailingP))
936 Value lenLow = builder.
create<arith::SubIOp>(loc, p, lo);
937 Value lenHigh = builder.
create<arith::SubIOp>(loc, hi, p);
940 Value len = builder.
create<arith::SubIOp>(loc, hi, lo);
942 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
943 scf::IfOp ifLenGtTwo =
944 builder.
create<scf::IfOp>(loc, types, lenGtTwo,
true);
951 Value cond = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
955 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, types, cond,
true);
959 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
960 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, cond,
false);
963 operands.append(args.begin() +
xStartIdx, args.end());
964 builder.
create<func::CallOp>(loc, func, operands);
971 mayRecursion(lo, p, lenLow);
975 mayRecursion(p, hi, lenHigh);
979 builder.
create<scf::YieldOp>(loc, ifOp.getResults());
982 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
1000 uint64_t ny, uint32_t nTrailingP) {
1003 assert(nTrailingP == 0);
1005 Block *entryBlock = func.addEntryBlock();
1014 Value lop1 = builder.
create<arith::AddIOp>(loc, lo, c1);
1017 scf::ForOp forOpI = builder.
create<scf::ForOp>(loc, lop1, hi, c1);
1019 Value i = forOpI.getInductionVar();
1023 operands.append(args.begin() +
xStartIdx, args.end());
1033 operands[0] = operands[1] = i;
1036 builder, loc, operands, xPerm, ny,
1038 d.push_back(builder.
create<memref::LoadOp>(loc, buffer, i));
1043 Value imp = builder.
create<arith::SubIOp>(loc, i, p);
1045 scf::ForOp forOpJ = builder.
create<scf::ForOp>(loc, c0, imp, c1);
1047 Value j = forOpJ.getInductionVar();
1050 operands[0] = builder.
create<arith::SubIOp>(loc, imj, c1);
1052 builder, loc, operands, xPerm, ny,
1054 Value t = builder.
create<memref::LoadOp>(loc, buffer, imjm1);
1055 builder.
create<memref::StoreOp>(loc, t, buffer, imj);
1060 operands[0] = operands[1] = p;
1062 builder, loc, operands, xPerm, ny,
1064 builder.
create<memref::StoreOp>(loc, d[k], buffer, p);
1068 builder.
create<func::ReturnOp>(loc);
1114 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
1115 uint32_t nTrailingP) {
1116 assert(nTrailingP == 1 || nTrailingP == 0);
1117 bool isHybrid = (nTrailingP == 1);
1119 Block *entryBlock = func.addEntryBlock();
1129 scf::WhileOp whileOp =
1134 builder.
createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1136 lo = before->getArgument(0);
1137 hi = before->getArgument(1);
1141 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1142 builder.
create<scf::ConditionOp>(loc, needSort, before->getArguments());
1146 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1148 lo = after->getArgument(0);
1149 hi = after->getArgument(1);
1154 Value len = builder.
create<arith::SubIOp>(loc, hi, lo);
1157 loc, arith::CmpIPredicate::ule, len, lenLimit);
1159 builder.
create<scf::IfOp>(loc, types, lenCond,
true);
1172 Value depthLimit = args.back();
1173 depthLimit = builder.
create<arith::SubIOp>(loc, depthLimit,
1176 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1179 builder.
create<scf::IfOp>(loc, types, depthCond,
true);
1192 args.back() = depthLimit;
1198 lo = depthIf.getResult(0);
1199 hi = depthIf.getResult(1);
1203 lo = lenIf.getResult(0);
1204 hi = lenIf.getResult(1);
1215 builder.
create<func::ReturnOp>(loc);
1219 template <
typename OpTy>
1226 for (
Value v : xys) {
1228 if (!mtp.isDynamicDim(0)) {
1231 v = rewriter.
create<memref::CastOp>(loc, newMtp, v);
1233 operands.push_back(v);
1236 auto insertPoint = op->template getParentOfType<func::FuncOp>();
1242 uint32_t nTrailingP = 0;
1243 switch (op.getAlgorithm()) {
1244 case SparseTensorSortKind::HybridQuickSort: {
1253 rewriter.
create<arith::SubIOp>(loc, hi, lo));
1254 Value depthLimit = rewriter.
create<arith::SubIOp>(
1256 rewriter.
create<math::CountLeadingZerosOp>(loc, len));
1257 operands.push_back(depthLimit);
1260 case SparseTensorSortKind::QuickSort:
1264 case SparseTensorSortKind::InsertionSortStable:
1268 case SparseTensorSortKind::HeapSort:
1276 xPerm, ny, operands, funcGenerator, nTrailingP);
1290 PushBackRewriter(
MLIRContext *context,
bool enableInit)
1292 LogicalResult matchAndRewrite(PushBackOp op,
1309 Value buffer = op.getInBuffer();
1310 Value capacity = rewriter.
create<memref::DimOp>(loc, buffer, c0);
1311 Value size = op.getCurSize();
1312 Value value = op.getValue();
1315 Value newSize = rewriter.
create<arith::AddIOp>(loc, size, n);
1316 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.
getDefiningOp());
1317 bool nIsOne = (nValue && nValue.value() == 1);
1319 if (!op.getInbounds()) {
1321 loc, arith::CmpIPredicate::ugt, newSize, capacity);
1326 scf::IfOp ifOp = rewriter.
create<scf::IfOp>(loc, bufferType, cond,
1331 capacity = rewriter.
create<arith::MulIOp>(loc, capacity, c2);
1335 scf::WhileOp whileOp =
1336 rewriter.
create<scf::WhileOp>(loc, capacity.
getType(), capacity);
1344 rewriter.
create<arith::MulIOp>(loc, before->getArgument(0), c2);
1345 cond = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1352 rewriter.
create<scf::YieldOp>(loc, after->getArguments());
1355 capacity = whileOp.getResult(0);
1359 rewriter.
create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1360 if (enableBufferInitialization) {
1361 Value fillSize = rewriter.
create<arith::SubIOp>(loc, capacity, newSize);
1363 Value subBuffer = rewriter.
create<memref::SubViewOp>(
1367 rewriter.
create<linalg::FillOp>(loc, fillValue, subBuffer);
1369 rewriter.
create<scf::YieldOp>(loc, newBuffer);
1373 rewriter.
create<scf::YieldOp>(loc, buffer);
1377 buffer = ifOp.getResult(0);
1382 rewriter.
create<memref::StoreOp>(loc, value, buffer, size);
1384 Value subBuffer = rewriter.
create<memref::SubViewOp>(
1387 rewriter.
create<linalg::FillOp>(loc, value, subBuffer);
1391 rewriter.
replaceOp(op, {buffer, newSize});
1396 bool enableBufferInitialization;
1404 LogicalResult matchAndRewrite(SortOp op,
1407 xys.push_back(op.getXy());
1408 xys.append(op.getYs().begin(), op.getYs().end());
1410 auto xPerm = op.getPermMap();
1412 if (
auto nyAttr = op.getNyAttr())
1413 ny = nyAttr.getInt();
1426 bool enableBufferInitialization) {
1428 enableBufferInitialization);
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumpt...
static void forEachIJPairInXs(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static Value createInlinedCompareImplementation(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
Creates code to compare all the (xs[i], xs[j]) pairs.
static constexpr const char kQuickSortFuncNamePrefix[]
static constexpr uint64_t hiIdx
static constexpr const char kHeapSortFuncNamePrefix[]
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static constexpr const char kHybridQuickSortFuncNamePrefix[]
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to insert the 3rd element to a list of two sorted elements.
static constexpr const char kSortStableFuncNamePrefix[]
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
Looks up a function that is appropriate for the given operands being sorted, and creates such a funct...
static constexpr uint64_t loIdx
static void forEachIJPairInAllBuffers(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is less than xs[j].
static std::pair< Value, Value > createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
A helper for generating code to perform quick sort.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to perform quick sort partition on the values in the range of index [lo,...
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
Creates code to sort 5 elements.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n)
Computes (n-2)/2, assuming n has index type.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is equal to xs[j].
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y...
static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo],...
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [...
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to sort 3 elements.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sor...
static constexpr const char kBinarySearchFuncNamePrefix[]
static constexpr const char kPartitionFuncNamePrefix[]
static constexpr uint64_t xStartIdx
static std::pair< Value, Value > createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step ...
static constexpr const char kShiftDownFuncNamePrefix[]
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to heapify the subtree with root start within the full binary tree in the range of...
static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform insertion sort on the values in the range of index [lo,...
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b],...
static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
Creates a code block for swapping the values in index i and j for all the buffers.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
IntegerType getIntegerType(unsigned width)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Value constantI64(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of i64 type.
Include the generated interface declarations.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.