33 static constexpr uint64_t
loIdx = 0;
34 static constexpr uint64_t
hiIdx = 1;
39 "_sparse_binary_search_";
41 "_sparse_hybrid_qsort_";
43 "_sparse_sort_stable_";
57 nameOstream << namePrefix;
59 nameOstream << cast<AffineDimExpr>(res).getPosition() <<
"_";
62 nameOstream <<
"_coo_" << ny;
64 constexpr uint64_t yBufferOffset = 1;
83 llvm::raw_svector_ostream nameOstream(nameBuffer);
85 operands.drop_back(nTrailingP));
87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
97 func = builder.
create<func::FuncOp>(
98 loc, nameOstream.str(),
101 createFunc(builder, module, func, xPerm, ny, nTrailingP);
114 Value iOffset = builder.
create<arith::MulIOp>(loc, args[0], cstep);
115 Value jOffset = builder.
create<arith::MulIOp>(loc, args[1], cstep);
116 for (
unsigned k = 0, e = xPerm.
getNumResults(); k < e; k++) {
117 unsigned actualK = cast<AffineDimExpr>(xPerm.
getResult(k)).getPosition();
119 Value i = builder.
create<arith::AddIOp>(loc, ak, iOffset);
120 Value j = builder.
create<arith::AddIOp>(loc, ak, jOffset);
123 bodyBuilder(k, i,
j, buffer);
137 for (
unsigned y = 0; y < ny; y++) {
145 constexpr uint64_t numHandledBuffers = 1;
149 for (
const auto &arg :
151 bodyBuilder(arg.index() + xPerm.
getNumResults() + ny, i,
j, arg.value());
169 Value vi = builder.
create<memref::LoadOp>(loc, buffer, i);
170 Value vj = builder.
create<memref::LoadOp>(loc, buffer,
j);
171 builder.
create<memref::StoreOp>(loc, vj, buffer, i);
172 builder.
create<memref::StoreOp>(loc, vi, buffer,
j);
187 bool isFirstDim = (k == 0);
190 compareBuilder(builder, loc, i,
j, buffer, isFirstDim, isLastDim);
193 }
else if (!isLastDim) {
197 builder.
create<scf::YieldOp>(loc, ifOp.getResult(0));
210 Value x,
bool isFirstDim,
bool isLastDim) {
211 Value vi = builder.
create<memref::LoadOp>(loc, x, i);
216 res = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
220 builder.
create<scf::YieldOp>(loc, res);
223 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
229 builder.
create<scf::YieldOp>(loc, f);
234 res = ifOp.getResult(0);
252 uint64_t ny, uint32_t nTrailingP = 0) {
255 assert(nTrailingP == 0);
265 Value vi = builder.
create<memref::LoadOp>(loc, x, i);
270 res = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
274 builder.
create<scf::YieldOp>(loc, res);
277 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
283 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
284 builder.
create<scf::YieldOp>(loc, lt);
289 res = ifOp.getResult(0);
306 uint64_t ny, uint32_t nTrailingP = 0) {
309 assert(nTrailingP == 0);
329 uint64_t ny, uint32_t nTrailingP = 0) {
332 assert(nTrailingP == 0);
334 Block *entryBlock = func.addEntryBlock();
341 scf::WhileOp whileOp = builder.
create<scf::WhileOp>(
346 builder.
createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
348 Value cond1 = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
349 before->getArgument(0),
350 before->getArgument(1));
351 builder.
create<scf::ConditionOp>(loc, cond1, before->getArguments());
355 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
357 Value lo = after->getArgument(0);
358 Value hi = after->getArgument(1);
362 loc, builder.
create<arith::AddIOp>(loc, lo, hi), c1);
363 Value midp1 = builder.
create<arith::AddIOp>(loc, mid, c1);
367 constexpr uint64_t numXBuffers = 1;
368 compareOperands.append(args.begin() +
xStartIdx,
376 Value newLo = builder.
create<arith::SelectOp>(loc, cond2, lo, midp1);
377 Value newHi = builder.
create<arith::SelectOp>(loc, cond2, mid, hi);
381 builder.
create<func::ReturnOp>(loc, whileOp.getResult(0));
394 uint64_t ny,
int step) {
396 scf::WhileOp whileOp =
404 compareOperands.push_back(before->getArgument(0));
405 compareOperands.push_back(p);
408 compareOperands.push_back(p);
409 compareOperands.push_back(before->getArgument(0));
411 compareOperands.append(xs.begin(), xs.end());
413 builder.
create<scf::ConditionOp>(loc, cond, before->getArguments());
419 i = builder.
create<arith::AddIOp>(loc, after->getArgument(0), cs);
421 i = whileOp.getResult(0);
424 compareOperands[0] = i;
425 compareOperands[1] = p;
429 return std::make_pair(whileOp.getResult(0), compareEq);
441 compareOperands[0] = b;
442 compareOperands[1] = a;
444 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, cond,
false);
448 createSwap(builder, loc, swapOperands, xPerm, ny);
458 compareOperands, v1, v2);
471 compareOperands, v0, v1);
475 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
485 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
488 auto insert4th = [&]() {
490 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
491 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
501 compareOperands, v3, v4);
511 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
514 constexpr uint64_t numXBuffers = 1;
515 compareOperands.append(args.begin() +
xStartIdx,
518 swapOperands.append(args.begin() +
xStartIdx, args.end());
521 Value hiP1 = builder.
create<arith::AddIOp>(loc, hi, c1);
522 Value len = builder.
create<arith::SubIOp>(loc, hiP1, lo);
524 Value lenCond = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
526 scf::IfOp lenIf = builder.
create<scf::IfOp>(loc, lenCond,
true);
530 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
535 Value miP1 = builder.
create<arith::AddIOp>(loc, hi, c1);
536 Value a = builder.
create<arith::AddIOp>(loc, lo, miP1);
538 a = builder.
create<arith::ShRUIOp>(loc, a, c1);
539 Value b = builder.
create<arith::AddIOp>(loc, mi, hiP1);
541 b = builder.
create<arith::ShRUIOp>(loc, b, c1);
542 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
579 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
580 uint32_t nTrailingP = 0) {
583 assert(nTrailingP == 0);
586 Block *entryBlock = func.addEntryBlock();
593 Value sum = builder.
create<arith::AddIOp>(loc, lo, hi);
595 Value p = builder.
create<arith::ShRUIOp>(loc, sum, c1);
604 scf::WhileOp whileOp = builder.
create<scf::WhileOp>(loc, types, operands);
608 {loc, loc, loc, loc});
610 builder.
create<scf::ConditionOp>(loc, before->getArgument(3),
611 before->getArguments());
615 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
617 i = after->getArgument(0);
618 j = after->getArgument(1);
619 p = after->getArgument(2);
621 constexpr uint64_t numXBuffers = 1;
622 auto [iresult, iCompareEq] =
626 auto [jresult, jCompareEq] =
628 j, p, xPerm, ny, -1);
633 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i,
j);
634 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, types, cond,
true);
637 swapOperands.append(args.begin() +
xStartIdx, args.end());
638 createSwap(builder, loc, swapOperands, xPerm, ny);
641 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
648 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
j, p);
656 builder.
create<scf::YieldOp>(loc, ifOpJ.getResults());
659 builder.
create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
660 scf::IfOp ifOp2 = builder.
create<scf::IfOp>(
661 loc,
TypeRange{i.getType(),
j.getType()}, compareEqIJ,
true);
663 Value i2 = builder.
create<arith::AddIOp>(loc, i, c1);
669 builder.
create<scf::YieldOp>(
671 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
676 p = builder.
create<arith::AddIOp>(loc,
j,
678 builder.
create<scf::YieldOp>(
683 builder.
create<scf::YieldOp>(loc, ifOp.getResults());
687 builder.
create<func::ReturnOp>(loc, whileOp.getResult(2));
694 Value res = builder.
create<arith::SubIOp>(loc, n, i2);
696 return builder.
create<arith::ShRUIOp>(loc, res, i1);
730 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
731 uint32_t nTrailingP) {
733 assert(nTrailingP == 1);
735 Block *entryBlock = func.addEntryBlock();
747 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
748 scf::IfOp ifN = builder.
create<scf::IfOp>(loc, condN,
false);
750 Value child = builder.
create<arith::SubIOp>(loc, start, first);
755 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
756 scf::IfOp ifNc = builder.
create<scf::IfOp>(loc, condNc,
false);
761 constexpr uint64_t numXBuffers = 1;
762 compareOperands.append(args.begin() +
xStartIdx,
771 auto getLargerChild = [&](
Value r) -> std::pair<Value, Value> {
772 Value lChild = builder.
create<arith::ShLIOp>(loc, r, c1);
773 lChild = builder.
create<arith::AddIOp>(loc, lChild, c1);
774 Value lChildIdx = builder.
create<arith::AddIOp>(loc, lChild, first);
775 Value rChild = builder.
create<arith::AddIOp>(loc, lChild, c1);
776 Value cond1 = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
780 builder.
create<scf::IfOp>(loc, ifTypes, cond1,
true);
782 Value rChildIdx = builder.
create<arith::AddIOp>(loc, rChild, first);
784 compareOperands[0] = lChildIdx;
785 compareOperands[1] = rChildIdx;
789 builder.
create<scf::IfOp>(loc, ifTypes, cond2,
true);
795 builder.
create<scf::YieldOp>(loc, if2.getResults());
799 return std::make_pair(if1.getResult(0), if1.getResult(1));
803 std::tie(child, childIdx) = getLargerChild(child);
807 scf::WhileOp whileOp = builder.
create<scf::WhileOp>(
816 compareOperands[0] = start;
817 compareOperands[1] = childIdx;
823 start = after->getArgument(0);
824 child = after->getArgument(1);
825 childIdx = after->getArgument(2);
827 swapOperands.append(args.begin() +
xStartIdx, args.end());
828 createSwap(builder, loc, swapOperands, xPerm, ny);
831 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
832 scf::IfOp if2 = builder.
create<scf::IfOp>(
835 auto [newChild, newChildIdx] = getLargerChild(child);
840 builder.
create<scf::YieldOp>(
841 loc,
ValueRange{start, if2.getResult(0), if2.getResult(1)});
844 builder.
create<func::ReturnOp>(loc);
861 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
862 uint32_t nTrailingP) {
865 assert(nTrailingP == 0);
867 Block *entryBlock = func.addEntryBlock();
874 Value n = builder.
create<arith::SubIOp>(loc, hi, lo);
880 Value up = builder.
create<arith::AddIOp>(loc, s, c1);
881 scf::ForOp forI = builder.
create<scf::ForOp>(loc, c0, up, c1);
883 Value i = builder.
create<arith::SubIOp>(loc, s, forI.getInductionVar());
884 Value lopi = builder.
create<arith::AddIOp>(loc, lo, i);
886 shiftDownOperands.append(args.begin() +
xStartIdx, args.end());
887 shiftDownOperands.push_back(n);
896 up = builder.
create<arith::SubIOp>(loc, n, c1);
897 scf::ForOp forL = builder.
create<scf::ForOp>(loc, c0, up, c1);
899 Value l = builder.
create<arith::SubIOp>(loc, n, forL.getInductionVar());
900 Value loplm1 = builder.
create<arith::AddIOp>(loc, lo, l);
901 loplm1 = builder.
create<arith::SubIOp>(loc, loplm1, c1);
903 swapOperands.append(args.begin() +
xStartIdx, args.end());
904 createSwap(builder, loc, swapOperands, xPerm, ny);
905 shiftDownOperands[1] = lo;
906 shiftDownOperands[shiftDownOperands.size() - 1] =
907 builder.
create<arith::SubIOp>(loc, l, c1);
912 builder.
create<func::ReturnOp>(loc);
918 static std::pair<Value, Value>
921 uint32_t nTrailingP) {
932 .
create<func::CallOp>(loc, partitionFunc,
934 args.drop_back(nTrailingP))
937 Value lenLow = builder.
create<arith::SubIOp>(loc, p, lo);
938 Value lenHigh = builder.
create<arith::SubIOp>(loc, hi, p);
941 Value len = builder.
create<arith::SubIOp>(loc, hi, lo);
943 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
944 scf::IfOp ifLenGtTwo =
945 builder.
create<scf::IfOp>(loc, types, lenGtTwo,
true);
952 Value cond = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
956 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, types, cond,
true);
960 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
961 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, cond,
false);
964 operands.append(args.begin() +
xStartIdx, args.end());
965 builder.
create<func::CallOp>(loc, func, operands);
972 mayRecursion(lo, p, lenLow);
976 mayRecursion(p, hi, lenHigh);
980 builder.
create<scf::YieldOp>(loc, ifOp.getResults());
983 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
1001 uint64_t ny, uint32_t nTrailingP) {
1004 assert(nTrailingP == 0);
1006 Block *entryBlock = func.addEntryBlock();
1015 Value lop1 = builder.
create<arith::AddIOp>(loc, lo, c1);
1018 scf::ForOp forOpI = builder.
create<scf::ForOp>(loc, lop1, hi, c1);
1020 Value i = forOpI.getInductionVar();
1024 operands.append(args.begin() +
xStartIdx, args.end());
1034 operands[0] = operands[1] = i;
1037 builder, loc, operands, xPerm, ny,
1039 d.push_back(builder.
create<memref::LoadOp>(loc, buffer, i));
1044 Value imp = builder.
create<arith::SubIOp>(loc, i, p);
1046 scf::ForOp forOpJ = builder.
create<scf::ForOp>(loc, c0, imp, c1);
1048 Value j = forOpJ.getInductionVar();
1051 operands[0] = builder.
create<arith::SubIOp>(loc, imj, c1);
1053 builder, loc, operands, xPerm, ny,
1055 Value t = builder.
create<memref::LoadOp>(loc, buffer, imjm1);
1056 builder.
create<memref::StoreOp>(loc, t, buffer, imj);
1061 operands[0] = operands[1] = p;
1063 builder, loc, operands, xPerm, ny,
1065 builder.
create<memref::StoreOp>(loc, d[k], buffer, p);
1069 builder.
create<func::ReturnOp>(loc);
1115 func::FuncOp func,
AffineMap xPerm, uint64_t ny,
1116 uint32_t nTrailingP) {
1117 assert(nTrailingP == 1 || nTrailingP == 0);
1118 bool isHybrid = (nTrailingP == 1);
1120 Block *entryBlock = func.addEntryBlock();
1130 scf::WhileOp whileOp =
1135 builder.
createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1137 lo = before->getArgument(0);
1138 hi = before->getArgument(1);
1142 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1143 builder.
create<scf::ConditionOp>(loc, needSort, before->getArguments());
1147 builder.
createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1149 lo = after->getArgument(0);
1150 hi = after->getArgument(1);
1155 Value len = builder.
create<arith::SubIOp>(loc, hi, lo);
1158 loc, arith::CmpIPredicate::ule, len, lenLimit);
1160 builder.
create<scf::IfOp>(loc, types, lenCond,
true);
1173 Value depthLimit = args.back();
1174 depthLimit = builder.
create<arith::SubIOp>(loc, depthLimit,
1177 builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1180 builder.
create<scf::IfOp>(loc, types, depthCond,
true);
1193 args.back() = depthLimit;
1199 lo = depthIf.getResult(0);
1200 hi = depthIf.getResult(1);
1204 lo = lenIf.getResult(0);
1205 hi = lenIf.getResult(1);
1216 builder.
create<func::ReturnOp>(loc);
1220 template <
typename OpTy>
1227 for (
Value v : xys) {
1229 if (!mtp.isDynamicDim(0)) {
1232 v = rewriter.
create<memref::CastOp>(loc, newMtp, v);
1234 operands.push_back(v);
1237 auto insertPoint = op->template getParentOfType<func::FuncOp>();
1243 uint32_t nTrailingP = 0;
1244 switch (op.getAlgorithm()) {
1245 case SparseTensorSortKind::HybridQuickSort: {
1254 rewriter.
create<arith::SubIOp>(loc, hi, lo));
1255 Value depthLimit = rewriter.
create<arith::SubIOp>(
1257 rewriter.
create<math::CountLeadingZerosOp>(loc, len));
1258 operands.push_back(depthLimit);
1261 case SparseTensorSortKind::QuickSort:
1265 case SparseTensorSortKind::InsertionSortStable:
1269 case SparseTensorSortKind::HeapSort:
1277 xPerm, ny, operands, funcGenerator, nTrailingP);
1291 PushBackRewriter(
MLIRContext *context,
bool enableInit)
1310 Value buffer = op.getInBuffer();
1311 Value capacity = rewriter.
create<memref::DimOp>(loc, buffer, c0);
1312 Value size = op.getCurSize();
1313 Value value = op.getValue();
1316 Value newSize = rewriter.
create<arith::AddIOp>(loc, size, n);
1317 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.
getDefiningOp());
1318 bool nIsOne = (nValue && nValue.value() == 1);
1320 if (!op.getInbounds()) {
1322 loc, arith::CmpIPredicate::ugt, newSize, capacity);
1327 scf::IfOp ifOp = rewriter.
create<scf::IfOp>(loc, bufferType, cond,
1332 capacity = rewriter.
create<arith::MulIOp>(loc, capacity, c2);
1336 scf::WhileOp whileOp =
1337 rewriter.
create<scf::WhileOp>(loc, capacity.
getType(), capacity);
1345 rewriter.
create<arith::MulIOp>(loc, before->getArgument(0), c2);
1346 cond = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1353 rewriter.
create<scf::YieldOp>(loc, after->getArguments());
1356 capacity = whileOp.getResult(0);
1360 rewriter.
create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1361 if (enableBufferInitialization) {
1362 Value fillSize = rewriter.
create<arith::SubIOp>(loc, capacity, newSize);
1364 Value subBuffer = rewriter.
create<memref::SubViewOp>(
1368 rewriter.
create<linalg::FillOp>(loc, fillValue, subBuffer);
1370 rewriter.
create<scf::YieldOp>(loc, newBuffer);
1374 rewriter.
create<scf::YieldOp>(loc, buffer);
1378 buffer = ifOp.getResult(0);
1383 rewriter.
create<memref::StoreOp>(loc, value, buffer, size);
1385 Value subBuffer = rewriter.
create<memref::SubViewOp>(
1388 rewriter.
create<linalg::FillOp>(loc, value, subBuffer);
1392 rewriter.
replaceOp(op, {buffer, newSize});
1397 bool enableBufferInitialization;
1408 xys.push_back(op.getXy());
1409 xys.append(op.getYs().begin(), op.getYs().end());
1411 auto xPerm = op.getPermMap();
1413 if (
auto nyAttr = op.getNyAttr())
1414 ny = nyAttr.getInt();
1427 bool enableBufferInitialization) {
1429 enableBufferInitialization);
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumpt...
static void forEachIJPairInXs(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static Value createInlinedCompareImplementation(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
Creates code to compare all the (xs[i], xs[j]) pairs.
static constexpr const char kQuickSortFuncNamePrefix[]
static constexpr uint64_t hiIdx
static constexpr const char kHeapSortFuncNamePrefix[]
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static constexpr const char kHybridQuickSortFuncNamePrefix[]
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to insert the 3rd element to a list of two sorted elements.
static constexpr const char kSortStableFuncNamePrefix[]
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
Looks up a function that is appropriate for the given operands being sorted, and creates such a funct...
static constexpr uint64_t loIdx
static void forEachIJPairInAllBuffers(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is less than xs[j].
static std::pair< Value, Value > createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
A helper for generating code to perform quick sort.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to perform quick sort partition on the values in the range of index [lo,...
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
Creates code to sort 5 elements.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n)
Computes (n-2)/2, assuming n has index type.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is equal to xs[j].
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y...
static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo],...
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [...
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to sort 3 elements.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sor...
static constexpr const char kBinarySearchFuncNamePrefix[]
static constexpr const char kPartitionFuncNamePrefix[]
static constexpr uint64_t xStartIdx
static std::pair< Value, Value > createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step ...
static constexpr const char kShiftDownFuncNamePrefix[]
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to heapify the subtree with root start within the full binary tree in the range of...
static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform insertion sort on the values in the range of index [lo,...
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b],...
static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
Creates a code block for swapping the values at indices i and j for all the buffers.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgListType getArguments()
IntegerType getIntegerType(unsigned width)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Value constantI64(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of i64 type.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.