MLIR  20.0.0git
SparseBufferRewriting.cpp
Go to the documentation of this file.
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "Utils/CodegenUtils.h"
15 
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace mlir::sparse_tensor;
28 
29 //===---------------------------------------------------------------------===//
30 // Helper methods for the actual rewriting rules.
31 //===---------------------------------------------------------------------===//
32 
// Positions of the leading operands of the sort helpers. The helpers take
// (lo, hi, xs, ys): `lo` is operand 0, `hi` is operand 1, and the buffers
// start at operand 2.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes of the generated helper functions, one per kind of helper.
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47 
48 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49  AffineMap, uint64_t, uint32_t)>;
50 
51 /// Constructs a function name with this format to facilitate quick sort:
52 /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53 /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55  StringRef namePrefix, AffineMap xPerm,
56  uint64_t ny, ValueRange operands) {
57  nameOstream << namePrefix;
58  for (auto res : xPerm.getResults())
59  nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
60 
61  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62  nameOstream << "_coo_" << ny;
63 
64  constexpr uint64_t yBufferOffset = 1;
65  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
66  nameOstream << "_" << getMemRefType(v).getElementType();
67 }
68 
69 /// Looks up a function that is appropriate for the given operands being
70 /// sorted, and creates such a function if it doesn't exist yet. The
71 /// parameters `xPerm` and `ny` tell the number of x and y values provided
72 /// by the buffer in xStartIdx.
73 //
74 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
75 // parameters for the sorting functions. Other parameters, such as the recursive
76 // call depth, are appended to the end of the parameter list as
77 // "trailing parameters".
79  OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80  StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81  FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82  SmallString<32> nameBuffer;
83  llvm::raw_svector_ostream nameOstream(nameBuffer);
84  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85  operands.drop_back(nTrailingP));
86 
87  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88  MLIRContext *context = module.getContext();
89  auto result = SymbolRefAttr::get(context, nameOstream.str());
90  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91 
92  if (!func) {
93  // Create the function.
94  OpBuilder::InsertionGuard insertionGuard(builder);
95  builder.setInsertionPoint(insertPoint);
96  Location loc = insertPoint.getLoc();
97  func = builder.create<func::FuncOp>(
98  loc, nameOstream.str(),
99  FunctionType::get(context, operands.getTypes(), resultTypes));
100  func.setPrivate();
101  createFunc(builder, module, func, xPerm, ny, nTrailingP);
102  }
103 
104  return result;
105 }
106 
107 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108 /// The code to process the value pairs is generated by `bodyBuilder`.
109 static void forEachIJPairInXs(
110  OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111  uint64_t ny,
112  function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113  Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
114  Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115  Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116  for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117  unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
118  Value ak = constantIndex(builder, loc, actualK);
119  Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120  Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121  Value buffer = args[xStartIdx];
122 
123  bodyBuilder(k, i, j, buffer);
124  }
125 }
126 
127 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128 /// The code to process the value pairs is generated by `bodyBuilder`.
130  OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131  uint64_t ny,
132  function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133 
134  // Create code for the first (xPerm + ny) buffers.
135  SmallVector<AffineExpr> exps(xPerm.getResults());
136  for (unsigned y = 0; y < ny; y++) {
137  exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
138  }
139  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
140  assert(xyPerm.isPermutation());
141 
142  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
143 
144  constexpr uint64_t numHandledBuffers = 1;
145  // Create code for the remaining buffers.
146  Value i = args[0];
147  Value j = args[1];
148  for (const auto &arg :
149  llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
150  bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
151  }
152 }
153 
154 /// Creates a code block for swapping the values in index i and j for all the
155 /// buffers.
156 //
157 // The generated IR corresponds to this C like algorithm:
158 // swap(x0[i], x0[j]);
159 // swap(x1[i], x1[j]);
160 // ...
161 // swap(xn[i], xn[j]);
162 // swap(y0[i], y0[j]);
163 // ...
164 // swap(yn[i], yn[j]);
165 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
166  AffineMap xPerm, uint64_t ny) {
167  auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
168  Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
169  Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
170  builder.create<memref::StoreOp>(loc, vj, buffer, i);
171  builder.create<memref::StoreOp>(loc, vi, buffer, j);
172  };
173 
174  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
175 }
176 
177 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
178 /// each pair is create via `compareBuilder`.
180  OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
181  uint64_t ny,
182  function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
183  compareBuilder) {
184  Value result;
185  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
186  bool isFirstDim = (k == 0);
187  bool isLastDim = (k == xPerm.getNumResults() - 1);
188  Value val =
189  compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
190  if (isFirstDim) {
191  result = val;
192  } else if (!isLastDim) {
193  OpBuilder::InsertionGuard insertionGuard(builder);
194  auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
195  builder.setInsertionPointAfter(ifOp);
196  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
197  }
198  };
199 
200  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
201 
202  builder.setInsertionPointAfterValue(result);
203  return result;
204 }
205 
206 /// Generates code to compare whether x[i] is equal to x[j] and returns the
207 /// result of the comparison.
209  Value x, bool isFirstDim, bool isLastDim) {
210  Value vi = builder.create<memref::LoadOp>(loc, x, i);
211  Value vj = builder.create<memref::LoadOp>(loc, x, j);
212 
213  Value res;
214  if (isLastDim) {
215  res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
216  // For 1D, we create a compare without any control flow. Otherwise, we
217  // create YieldOp to return the result in the nested if-stmt.
218  if (!isFirstDim)
219  builder.create<scf::YieldOp>(loc, res);
220  } else {
221  Value ne =
222  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
223  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
224  ne, /*else=*/true);
225  // If (x[i] != x[j]).
226  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
227  Value f = constantI1(builder, loc, false);
228  builder.create<scf::YieldOp>(loc, f);
229 
230  // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
231  // checks the remaining dimensions.
232  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
233  res = ifOp.getResult(0);
234  }
235 
236  return res;
237 }
238 
239 /// Creates code to compare whether xs[i] is equal to xs[j].
240 //
241 // The generate IR corresponds to this C like algorithm:
242 // if (x0[i] != x0[j])
243 // return false;
244 // else
245 // if (x1[i] != x1[j])
246 // return false;
247 // else if (x2[2] != x2[j]))
248 // and so on ...
250  ValueRange args, AffineMap xPerm,
251  uint64_t ny, uint32_t nTrailingP = 0) {
252  // Compare functions don't use trailing parameters.
253  (void)nTrailingP;
254  assert(nTrailingP == 0);
255  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
257 }
258 
259 /// Generates code to compare whether x[i] is less than x[j] and returns the
260 /// result of the comparison.
262  Value j, Value x, bool isFirstDim,
263  bool isLastDim) {
264  Value vi = builder.create<memref::LoadOp>(loc, x, i);
265  Value vj = builder.create<memref::LoadOp>(loc, x, j);
266 
267  Value res;
268  if (isLastDim) {
269  res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
270  // For 1D, we create a compare without any control flow. Otherwise, we
271  // create YieldOp to return the result in the nested if-stmt.
272  if (!isFirstDim)
273  builder.create<scf::YieldOp>(loc, res);
274  } else {
275  Value ne =
276  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
277  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
278  ne, /*else=*/true);
279  // If (x[i] != x[j]).
280  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
281  Value lt =
282  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
283  builder.create<scf::YieldOp>(loc, lt);
284 
285  // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
286  // checks the remaining dimensions.
287  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
288  res = ifOp.getResult(0);
289  }
290 
291  return res;
292 }
293 
294 /// Creates code to compare whether xs[i] is less than xs[j].
295 //
296 // The generate IR corresponds to this C like algorithm:
297 // if (x0[i] != x0[j])
298 // return x0[i] < x0[j];
299 // else if (x1[j] != x1[i])
300 // return x1[i] < x1[j];
301 // else
302 // and so on ...
304  ValueRange args, AffineMap xPerm,
305  uint64_t ny, uint32_t nTrailingP = 0) {
306  // Compare functions don't use trailing parameters.
307  (void)nTrailingP;
308  assert(nTrailingP == 0);
309  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
311 }
312 
313 /// Creates a function to use a binary search to find the insertion point for
314 /// inserting xs[hi] to the sorted values xs[lo..hi).
315 //
316 // The generate IR corresponds to this C like algorithm:
317 // p = hi
318 // while (lo < hi)
319 // mid = (lo + hi) >> 1
320 // if (xs[p] < xs[mid])
321 // hi = mid
322 // else
323 // lo = mid - 1
324 // return lo;
325 //
326 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
327  func::FuncOp func, AffineMap xPerm,
328  uint64_t ny, uint32_t nTrailingP = 0) {
329  // Binary search doesn't use trailing parameters.
330  (void)nTrailingP;
331  assert(nTrailingP == 0);
332  OpBuilder::InsertionGuard insertionGuard(builder);
333  Block *entryBlock = func.addEntryBlock();
334  builder.setInsertionPointToStart(entryBlock);
335 
336  Location loc = func.getLoc();
337  ValueRange args = entryBlock->getArguments();
338  Value p = args[hiIdx];
339  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
340  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
341  loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
342 
343  // The before-region of the WhileOp.
344  Block *before =
345  builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
346  builder.setInsertionPointToEnd(before);
347  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
348  before->getArgument(0),
349  before->getArgument(1));
350  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
351 
352  // The after-region of the WhileOp.
353  Block *after =
354  builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
355  builder.setInsertionPointToEnd(after);
356  Value lo = after->getArgument(0);
357  Value hi = after->getArgument(1);
358  // Compute mid = (lo + hi) >> 1.
359  Value c1 = constantIndex(builder, loc, 1);
360  Value mid = builder.create<arith::ShRUIOp>(
361  loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
362  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
363 
364  // Compare xs[p] < xs[mid].
365  SmallVector<Value> compareOperands{p, mid};
366  constexpr uint64_t numXBuffers = 1;
367  compareOperands.append(args.begin() + xStartIdx,
368  args.begin() + xStartIdx + numXBuffers);
369  Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
370  // Update lo and hi for the WhileOp as follows:
371  // if (xs[p] < xs[mid]))
372  // hi = mid;
373  // else
374  // lo = mid + 1;
375  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
376  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
377  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
378 
379  builder.setInsertionPointAfter(whileOp);
380  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
381 }
382 
383 /// Creates code to advance i in a loop based on xs[p] as follows:
384 /// while (xs[i] < xs[p]) i += step (step > 0)
385 /// or
386 /// while (xs[i] > xs[p]) i += step (step < 0)
387 /// The routine returns i as well as a boolean value to indicate whether
388 /// xs[i] == xs[p].
389 static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
390  ModuleOp module,
391  func::FuncOp func, ValueRange xs,
392  Value i, Value p, AffineMap xPerm,
393  uint64_t ny, int step) {
394  Location loc = func.getLoc();
395  scf::WhileOp whileOp =
396  builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
397 
398  Block *before =
399  builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
400  builder.setInsertionPointToEnd(before);
401  SmallVector<Value> compareOperands;
402  if (step > 0) {
403  compareOperands.push_back(before->getArgument(0));
404  compareOperands.push_back(p);
405  } else {
406  assert(step < 0);
407  compareOperands.push_back(p);
408  compareOperands.push_back(before->getArgument(0));
409  }
410  compareOperands.append(xs.begin(), xs.end());
411  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
412  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
413 
414  Block *after =
415  builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
416  builder.setInsertionPointToEnd(after);
417  Value cs = constantIndex(builder, loc, step);
418  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
419  builder.create<scf::YieldOp>(loc, ValueRange{i});
420  i = whileOp.getResult(0);
421 
422  builder.setInsertionPointAfter(whileOp);
423  compareOperands[0] = i;
424  compareOperands[1] = p;
425  Value compareEq =
426  createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
427 
428  return std::make_pair(whileOp.getResult(0), compareEq);
429 }
430 
431 /// Creates and returns an IfOp to compare two elements and swap the elements
432 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
433 /// right after the swap instructions.
434 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
435  AffineMap xPerm, uint64_t ny,
436  SmallVectorImpl<Value> &swapOperands,
437  SmallVectorImpl<Value> &compareOperands,
438  Value a, Value b) {
439  // Compare(data[b], data[a]).
440  compareOperands[0] = b;
441  compareOperands[1] = a;
442  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
443  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
444  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
445  swapOperands[0] = b;
446  swapOperands[1] = a;
447  createSwap(builder, loc, swapOperands, xPerm, ny);
448  return ifOp;
449 }
450 
451 /// Creates code to insert the 3rd element to a list of two sorted elements.
452 static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
453  uint64_t ny, SmallVectorImpl<Value> &swapOperands,
454  SmallVectorImpl<Value> &compareOperands, Value v0,
455  Value v1, Value v2) {
456  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
457  compareOperands, v1, v2);
458  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
459  v0, v1);
460  builder.setInsertionPointAfter(ifOp);
461 }
462 
463 /// Creates code to sort 3 elements.
464 static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
465  uint64_t ny, SmallVectorImpl<Value> &swapOperands,
466  SmallVectorImpl<Value> &compareOperands, Value v0,
467  Value v1, Value v2) {
468  // Sort the first 2 elements.
469  scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
470  compareOperands, v0, v1);
471  builder.setInsertionPointAfter(ifOp1);
472 
473  // Insert the 3th element.
474  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
475  v1, v2);
476 }
477 
478 /// Creates code to sort 5 elements.
479 static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
480  uint64_t ny, SmallVectorImpl<Value> &swapOperands,
481  SmallVectorImpl<Value> &compareOperands, Value v0,
482  Value v1, Value v2, Value v3, Value v4) {
483  // Sort the first 3 elements.
484  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
485  v2);
486 
487  auto insert4th = [&]() {
488  scf::IfOp ifOp = createCompareThenSwap(
489  builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
490  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
491  v1, v2);
492  builder.setInsertionPointAfter(ifOp);
493  };
494 
495  // Insert the 4th element.
496  insert4th();
497 
498  // Insert the 5th element.
499  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
500  compareOperands, v3, v4);
501  insert4th();
502  builder.setInsertionPointAfter(ifOp);
503 }
504 
505 /// Creates a code block to swap the values in indices lo, mi, and hi so that
506 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
507 /// the number of values in range [lo, hi) is more than a threshold, we also
508 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
509 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
510  func::FuncOp func, AffineMap xPerm, uint64_t ny,
511  Value lo, Value hi, Value mi, ValueRange args) {
512  SmallVector<Value> compareOperands{mi, lo};
513  constexpr uint64_t numXBuffers = 1;
514  compareOperands.append(args.begin() + xStartIdx,
515  args.begin() + xStartIdx + numXBuffers);
516  SmallVector<Value> swapOperands{mi, lo};
517  swapOperands.append(args.begin() + xStartIdx, args.end());
518  Location loc = func.getLoc();
519  Value c1 = constantIndex(builder, loc, 1);
520  Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
521  Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
522  Value lenThreshold = constantIndex(builder, loc, 1000);
523  Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
524  len, lenThreshold);
525  scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
526 
527  // When len < 1000, choose pivot from median of 3 values.
528  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
529  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
530  hi);
531 
532  // When len >= 1000, choose pivot from median of 5 values.
533  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
534  Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
535  Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
536  // Value a is the middle between [loc, mi].
537  a = builder.create<arith::ShRUIOp>(loc, a, c1);
538  Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
539  // Value b is the middle between [mi, hi].
540  b = builder.create<arith::ShRUIOp>(loc, b, c1);
541  createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
542  b, hi);
543 
544  builder.setInsertionPointAfter(lenIf);
545 }
546 
547 /// Creates a function to perform quick sort partition on the values in the
548 /// range of index [lo, hi), assuming lo < hi.
549 //
550 // The generated IR corresponds to this C like algorithm:
551 // int partition(lo, hi, xs) {
552 // p = (lo+hi)/2 // pivot index
553 // i = lo
554 // j = hi-1
555 // while (true) do {
556 // while (xs[i] < xs[p]) i ++;
557 // i_eq = (xs[i] == xs[p]);
558 // while (xs[j] > xs[p]) j --;
559 // j_eq = (xs[j] == xs[p]);
560 //
561 // if (i >= j) return j + 1;
562 //
563 // if (i < j) {
564 // swap(xs[i], xs[j])
565 // if (i == p) {
566 // p = j;
567 // } else if (j == p) {
568 // p = i;
569 // }
570 // if (i_eq && j_eq) {
571 // ++i;
572 // --j;
573 // }
574 // }
575 // }
576 // }
577 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
578  func::FuncOp func, AffineMap xPerm, uint64_t ny,
579  uint32_t nTrailingP = 0) {
580  // Quick sort partition doesn't use trailing parameters.
581  (void)nTrailingP;
582  assert(nTrailingP == 0);
583  OpBuilder::InsertionGuard insertionGuard(builder);
584 
585  Block *entryBlock = func.addEntryBlock();
586  builder.setInsertionPointToStart(entryBlock);
587 
588  Location loc = func.getLoc();
589  ValueRange args = entryBlock->getArguments();
590  Value lo = args[loIdx];
591  Value hi = args[hiIdx];
592  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
593  Value c1 = constantIndex(builder, loc, 1);
594  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
595 
596  Value i = lo;
597  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
598  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
599  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
600  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
601  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
602  trueVal.getType()};
603  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
604 
605  // The before-region of the WhileOp.
606  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
607  {loc, loc, loc, loc});
608  builder.setInsertionPointToEnd(before);
609  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
610  before->getArguments());
611 
612  // The after-region of the WhileOp.
613  Block *after =
614  builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
615  builder.setInsertionPointToEnd(after);
616  i = after->getArgument(0);
617  j = after->getArgument(1);
618  p = after->getArgument(2);
619 
620  constexpr uint64_t numXBuffers = 1;
621  auto [iresult, iCompareEq] =
622  createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
623  i, p, xPerm, ny, 1);
624  i = iresult;
625  auto [jresult, jCompareEq] =
626  createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
627  j, p, xPerm, ny, -1);
628  j = jresult;
629 
630  // If i < j:
631  Value cond =
632  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
633  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
634  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
635  SmallVector<Value> swapOperands{i, j};
636  swapOperands.append(args.begin() + xStartIdx, args.end());
637  createSwap(builder, loc, swapOperands, xPerm, ny);
638  // If the pivot is moved, update p with the new pivot.
639  Value icond =
640  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
641  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
642  icond, /*else=*/true);
643  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
644  builder.create<scf::YieldOp>(loc, ValueRange{j});
645  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
646  Value jcond =
647  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
648  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
649  jcond, /*else=*/true);
650  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
651  builder.create<scf::YieldOp>(loc, ValueRange{i});
652  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
653  builder.create<scf::YieldOp>(loc, ValueRange{p});
654  builder.setInsertionPointAfter(ifOpJ);
655  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
656  builder.setInsertionPointAfter(ifOpI);
657  Value compareEqIJ =
658  builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
659  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
660  loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
661  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
662  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
663  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
664  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
665  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
666  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
667  builder.setInsertionPointAfter(ifOp2);
668  builder.create<scf::YieldOp>(
669  loc,
670  ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
671  /*cont=*/constantI1(builder, loc, true)});
672 
673  // False branch for if i < j (i.e., i >= j):
674  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
675  p = builder.create<arith::AddIOp>(loc, j,
676  constantOne(builder, loc, j.getType()));
677  builder.create<scf::YieldOp>(
678  loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
679 
680  // Return for the whileOp.
681  builder.setInsertionPointAfter(ifOp);
682  builder.create<scf::YieldOp>(loc, ifOp.getResults());
683 
684  // Return for the function.
685  builder.setInsertionPointAfter(whileOp);
686  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
687 }
688 
689 /// Computes (n-2)/n, assuming n has index type.
691  Value n) {
692  Value i2 = constantIndex(builder, loc, 2);
693  Value res = builder.create<arith::SubIOp>(loc, n, i2);
694  Value i1 = constantIndex(builder, loc, 1);
695  return builder.create<arith::ShRUIOp>(loc, res, i1);
696 }
697 
698 /// Creates a function to heapify the subtree with root `start` within the full
699 /// binary tree in the range of index [first, first + n).
700 //
701 // The generated IR corresponds to this C like algorithm:
702 // void shiftDown(first, start, n, data) {
703 // if (n >= 2) {
704 // child = start - first
705 // if ((n-2)/2 >= child) {
706 // // Left child exists.
707 // child = child * 2 + 1 // Initialize the bigger child to left child.
708 // childIndex = child + first
709 // if (child+1 < n && data[childIndex] < data[childIndex+1])
710 // // Right child exists and is bigger.
711 // childIndex++; child++;
712 // // Shift data[start] down to where it belongs in the subtree.
713 // while (data[start] < data[childIndex]) {
714 // swap(data[start], data[childIndex])
715 // start = childIndex
716 // if ((n - 2)/2 >= child) {
717 // // Left child exists.
718 // child = 2*child + 1
719 // childIndex = child + first
720 // if (child + 1) < n && data[childIndex] < data[childIndex+1]
721 // childIndex++; child++;
722 // }
723 // }
724 // }
725 // }
726 // }
727 //
728 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
729  func::FuncOp func, AffineMap xPerm, uint64_t ny,
730  uint32_t nTrailingP) {
731  // The value n is passed in as a trailing parameter.
732  assert(nTrailingP == 1);
733  OpBuilder::InsertionGuard insertionGuard(builder);
734  Block *entryBlock = func.addEntryBlock();
735  builder.setInsertionPointToStart(entryBlock);
736 
737  Location loc = func.getLoc();
738  Value n = entryBlock->getArguments().back();
739  ValueRange args = entryBlock->getArguments().drop_back();
740  Value first = args[loIdx];
741  Value start = args[hiIdx];
742 
743  // If (n >= 2).
744  Value c2 = constantIndex(builder, loc, 2);
745  Value condN =
746  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
747  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
748  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
749  Value child = builder.create<arith::SubIOp>(loc, start, first);
750 
751  // If ((n-2)/2 >= child).
752  Value t = createSubTwoDividedByTwo(builder, loc, n);
753  Value condNc =
754  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
755  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
756 
757  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
758  Value c1 = constantIndex(builder, loc, 1);
759  SmallVector<Value> compareOperands{start, start};
760  constexpr uint64_t numXBuffers = 1;
761  compareOperands.append(args.begin() + xStartIdx,
762  args.begin() + xStartIdx + numXBuffers);
763 
764  // Generate code to inspect the children of 'r' and return the larger child
765  // as follows:
766  // child = r * 2 + 1 // Left child.
767  // childIndex = child + first
768  // if (child+1 < n && data[childIndex] < data[childIndex+1])
769  // childIndex ++; child ++ // Right child is bigger.
770  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
771  Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
772  lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
773  Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
774  Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
775  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
776  rChild, n);
777  SmallVector<Type, 2> ifTypes(2, r.getType());
778  scf::IfOp if1 =
779  builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
780  builder.setInsertionPointToStart(&if1.getThenRegion().front());
781  Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
782  // Compare data[left] < data[right].
783  compareOperands[0] = lChildIdx;
784  compareOperands[1] = rChildIdx;
785  Value cond2 =
786  createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
787  scf::IfOp if2 =
788  builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
789  builder.setInsertionPointToStart(&if2.getThenRegion().front());
790  builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
791  builder.setInsertionPointToStart(&if2.getElseRegion().front());
792  builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
793  builder.setInsertionPointAfter(if2);
794  builder.create<scf::YieldOp>(loc, if2.getResults());
795  builder.setInsertionPointToStart(&if1.getElseRegion().front());
796  builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
797  builder.setInsertionPointAfter(if1);
798  return std::make_pair(if1.getResult(0), if1.getResult(1));
799  };
800 
801  Value childIdx;
802  std::tie(child, childIdx) = getLargerChild(child);
803 
804  // While (data[start] < data[childIndex]).
805  SmallVector<Type, 3> types(3, child.getType());
806  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
807  loc, types, SmallVector<Value, 2>{start, child, childIdx});
808 
809  // The before-region of the WhileOp.
810  SmallVector<Location, 3> locs(3, loc);
811  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
812  builder.setInsertionPointToEnd(before);
813  start = before->getArgument(0);
814  childIdx = before->getArgument(2);
815  compareOperands[0] = start;
816  compareOperands[1] = childIdx;
817  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
818  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
819 
820  // The after-region of the WhileOp.
821  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
822  start = after->getArgument(0);
823  child = after->getArgument(1);
824  childIdx = after->getArgument(2);
825  SmallVector<Value> swapOperands{start, childIdx};
826  swapOperands.append(args.begin() + xStartIdx, args.end());
827  createSwap(builder, loc, swapOperands, xPerm, ny);
828  start = childIdx;
829  Value cond2 =
830  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
831  scf::IfOp if2 = builder.create<scf::IfOp>(
832  loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
833  builder.setInsertionPointToStart(&if2.getThenRegion().front());
834  auto [newChild, newChildIdx] = getLargerChild(child);
835  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
836  builder.setInsertionPointToStart(&if2.getElseRegion().front());
837  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
838  builder.setInsertionPointAfter(if2);
839  builder.create<scf::YieldOp>(
840  loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
841 
842  builder.setInsertionPointAfter(ifN);
843  builder.create<func::ReturnOp>(loc);
844 }
845 
846 /// Creates a function to perform heap sort on the values in the range of index
847 /// [lo, hi) with the assumption hi - lo >= 2.
848 //
849 // The generate IR corresponds to this C like algorithm:
850 // void heapSort(lo, hi, data) {
851 // n = hi - lo
852 // for i = (n-2)/2 downto 0
853 // shiftDown(lo, lo+i, n)
854 //
855 // for l = n downto 2
856 // swap(lo, lo+l-1)
857 // shiftdown(lo, lo, l-1)
858 // }
859 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
860  func::FuncOp func, AffineMap xPerm, uint64_t ny,
861  uint32_t nTrailingP) {
862  // Heap sort function doesn't have trailing parameters.
863  (void)nTrailingP;
864  assert(nTrailingP == 0);
865  OpBuilder::InsertionGuard insertionGuard(builder);
866  Block *entryBlock = func.addEntryBlock();
867  builder.setInsertionPointToStart(entryBlock);
868 
869  Location loc = func.getLoc();
870  ValueRange args = entryBlock->getArguments();
871  Value lo = args[loIdx];
872  Value hi = args[hiIdx];
873  Value n = builder.create<arith::SubIOp>(loc, hi, lo);
874 
875  // For i = (n-2)/2 downto 0.
876  Value c0 = constantIndex(builder, loc, 0);
877  Value c1 = constantIndex(builder, loc, 1);
878  Value s = createSubTwoDividedByTwo(builder, loc, n);
879  Value up = builder.create<arith::AddIOp>(loc, s, c1);
880  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
881  builder.setInsertionPointToStart(forI.getBody());
882  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
883  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
884  SmallVector<Value> shiftDownOperands = {lo, lopi};
885  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
886  shiftDownOperands.push_back(n);
888  builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
889  shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
890  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
891  shiftDownOperands);
892 
893  builder.setInsertionPointAfter(forI);
894  // For l = n downto 2.
895  up = builder.create<arith::SubIOp>(loc, n, c1);
896  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
897  builder.setInsertionPointToStart(forL.getBody());
898  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
899  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
900  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
901  SmallVector<Value> swapOperands{lo, loplm1};
902  swapOperands.append(args.begin() + xStartIdx, args.end());
903  createSwap(builder, loc, swapOperands, xPerm, ny);
904  shiftDownOperands[1] = lo;
905  shiftDownOperands[shiftDownOperands.size() - 1] =
906  builder.create<arith::SubIOp>(loc, l, c1);
907  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
908  shiftDownOperands);
909 
910  builder.setInsertionPointAfter(forL);
911  builder.create<func::ReturnOp>(loc);
912 }
913 
914 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
915 /// recursively calls quick sort to process the smaller partition and returns
916 /// the bigger partition to be processed by the enclosed while-loop.
917 static std::pair<Value, Value>
918 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
919  ValueRange args, AffineMap xPerm, uint64_t ny,
920  uint32_t nTrailingP) {
921  MLIRContext *context = module.getContext();
922  Location loc = func.getLoc();
923  Value lo = args[loIdx];
924  Value hi = args[hiIdx];
925  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
926 
928  builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
929  ny, args.drop_back(nTrailingP), createPartitionFunc);
930  Value p = builder
931  .create<func::CallOp>(loc, partitionFunc,
932  TypeRange{IndexType::get(context)},
933  args.drop_back(nTrailingP))
934  .getResult(0);
935 
936  Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
937  Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
938  // Partition already sorts array with len <= 2
939  Value c2 = constantIndex(builder, loc, 2);
940  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
941  Value lenGtTwo =
942  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
943  scf::IfOp ifLenGtTwo =
944  builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
945  builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
946  // Returns an empty range to mark the entire region is fully sorted.
947  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
948 
949  // Else len > 2, need recursion.
950  builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
951  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
952  lenLow, lenHigh);
953 
954  Value c0 = constantIndex(builder, loc, 0);
955  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
956 
957  auto mayRecursion = [&](Value low, Value high, Value len) {
958  Value cond =
959  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
960  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
961  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
962  SmallVector<Value> operands{low, high};
963  operands.append(args.begin() + xStartIdx, args.end());
964  builder.create<func::CallOp>(loc, func, operands);
965  builder.setInsertionPointAfter(ifOp);
966  };
967 
968  // Recursively call quickSort to process the smaller partition and return
969  // the bigger partition to be processed by the enclosed while-loop.
970  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971  mayRecursion(lo, p, lenLow);
972  builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
973 
974  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
975  mayRecursion(p, hi, lenHigh);
976  builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
977 
978  builder.setInsertionPointAfter(ifOp);
979  builder.create<scf::YieldOp>(loc, ifOp.getResults());
980 
981  builder.setInsertionPointAfter(ifLenGtTwo);
982  return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
983 }
984 
985 /// Creates a function to perform insertion sort on the values in the range of
986 /// index [lo, hi).
987 //
988 // The generate IR corresponds to this C like algorithm:
989 // void insertionSort(lo, hi, data) {
990 // for (i = lo+1; i < hi; i++) {
991 // d = data[i];
992 // p = binarySearch(lo, i-1, data)
993 // for (j = 0; j > i - p; j++)
994 // data[i-j] = data[i-j-1]
995 // data[p] = d
996 // }
997 // }
998 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
999  func::FuncOp func, AffineMap xPerm,
1000  uint64_t ny, uint32_t nTrailingP) {
1001  // Stable sort function doesn't use trailing parameters.
1002  (void)nTrailingP;
1003  assert(nTrailingP == 0);
1004  OpBuilder::InsertionGuard insertionGuard(builder);
1005  Block *entryBlock = func.addEntryBlock();
1006  builder.setInsertionPointToStart(entryBlock);
1007 
1008  MLIRContext *context = module.getContext();
1009  Location loc = func.getLoc();
1010  ValueRange args = entryBlock->getArguments();
1011  Value c1 = constantIndex(builder, loc, 1);
1012  Value lo = args[loIdx];
1013  Value hi = args[hiIdx];
1014  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1015 
1016  // Start the outer for-stmt with induction variable i.
1017  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1018  builder.setInsertionPointToStart(forOpI.getBody());
1019  Value i = forOpI.getInductionVar();
1020 
1021  // Binary search to find the insertion point p.
1022  SmallVector<Value> operands{lo, i};
1023  operands.append(args.begin() + xStartIdx, args.end());
1025  builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1026  xPerm, ny, operands, createBinarySearchFunc);
1027  Value p = builder
1028  .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1029  operands)
1030  .getResult(0);
1031 
1032  // Move the value at data[i] to a temporary location.
1033  operands[0] = operands[1] = i;
1036  builder, loc, operands, xPerm, ny,
1037  [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1038  d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1039  });
1040 
1041  // Start the inner for-stmt with induction variable j, for moving data[p..i)
1042  // to data[p+1..i+1).
1043  Value imp = builder.create<arith::SubIOp>(loc, i, p);
1044  Value c0 = constantIndex(builder, loc, 0);
1045  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1046  builder.setInsertionPointToStart(forOpJ.getBody());
1047  Value j = forOpJ.getInductionVar();
1048  Value imj = builder.create<arith::SubIOp>(loc, i, j);
1049  operands[1] = imj;
1050  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1052  builder, loc, operands, xPerm, ny,
1053  [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1054  Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1055  builder.create<memref::StoreOp>(loc, t, buffer, imj);
1056  });
1057 
1058  // Store the value at data[i] to data[p].
1059  builder.setInsertionPointAfter(forOpJ);
1060  operands[0] = operands[1] = p;
1062  builder, loc, operands, xPerm, ny,
1063  [&](uint64_t k, Value p, Value usused, Value buffer) {
1064  builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1065  });
1066 
1067  builder.setInsertionPointAfter(forOpI);
1068  builder.create<func::ReturnOp>(loc);
1069 }
1070 
1071 /// Creates a function to perform quick sort or a hybrid quick sort on the
1072 /// values in the range of index [lo, hi).
1073 //
1074 //
1075 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1076 // void quickSort(lo, hi, data) {
1077 // while (lo + 1 < hi) {
1078 // p = partition(low, high, data);
1079 // if (len(lo, p) < len(p+1, hi)) {
1080 // quickSort(lo, p, data);
1081 // lo = p+1;
1082 // } else {
1083 // quickSort(p + 1, hi, data);
1084 // hi = p;
1085 // }
1086 // }
1087 // }
1088 //
1089 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1090 // void hybridQuickSort(lo, hi, data, depthLimit) {
1091 // while (lo + 1 < hi) {
1092 // len = hi - lo;
1093 // if (len <= limit) {
1094 // insertionSort(lo, hi, data);
1095 // } else {
1096 // depthLimit --;
1097 // if (depthLimit <= 0) {
1098 // heapSort(lo, hi, data);
1099 // } else {
1100 // p = partition(low, high, data);
1101 // if (len(lo, p) < len(p+1, hi)) {
1102 // quickSort(lo, p, data, depthLimit);
1103 // lo = p+1;
1104 // } else {
1105 // quickSort(p + 1, hi, data, depthLimit);
1106 // hi = p;
1107 // }
1108 // }
1109 // }
1110 // }
1111 // }
1112 //
1113 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1114  func::FuncOp func, AffineMap xPerm, uint64_t ny,
1115  uint32_t nTrailingP) {
1116  assert(nTrailingP == 1 || nTrailingP == 0);
1117  bool isHybrid = (nTrailingP == 1);
1118  OpBuilder::InsertionGuard insertionGuard(builder);
1119  Block *entryBlock = func.addEntryBlock();
1120  builder.setInsertionPointToStart(entryBlock);
1121 
1122  Location loc = func.getLoc();
1123  SmallVector<Value> args;
1124  args.append(entryBlock->getArguments().begin(),
1125  entryBlock->getArguments().end());
1126  Value lo = args[loIdx];
1127  Value hi = args[hiIdx];
1128  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1129  scf::WhileOp whileOp =
1130  builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1131 
1132  // The before-region of the WhileOp.
1133  Block *before =
1134  builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1135  builder.setInsertionPointToEnd(before);
1136  lo = before->getArgument(0);
1137  hi = before->getArgument(1);
1138  Value loP1 =
1139  builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1140  Value needSort =
1141  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1142  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1143 
1144  // The after-region of the WhileOp.
1145  Block *after =
1146  builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1147  builder.setInsertionPointToEnd(after);
1148  lo = after->getArgument(0);
1149  hi = after->getArgument(1);
1150  args[0] = lo;
1151  args[1] = hi;
1152 
1153  if (isHybrid) {
1154  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1155  Value lenLimit = constantIndex(builder, loc, 30);
1156  Value lenCond = builder.create<arith::CmpIOp>(
1157  loc, arith::CmpIPredicate::ule, len, lenLimit);
1158  scf::IfOp lenIf =
1159  builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1160 
1161  // When len <= limit.
1162  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1163  FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1164  builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1165  ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1166  builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1167  ValueRange(args).drop_back(nTrailingP));
1168  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1169 
1170  // When len > limit.
1171  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1172  Value depthLimit = args.back();
1173  depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1174  constantI64(builder, loc, 1));
1175  Value depthCond =
1176  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1177  depthLimit, constantI64(builder, loc, 0));
1178  scf::IfOp depthIf =
1179  builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1180 
1181  // When depth exceeds limit.
1182  builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1184  builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1185  ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1186  builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1187  ValueRange(args).drop_back(nTrailingP));
1188  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1189 
1190  // When depth doesn't exceed limit.
1191  builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1192  args.back() = depthLimit;
1193  std::tie(lo, hi) =
1194  createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1195  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1196 
1197  builder.setInsertionPointAfter(depthIf);
1198  lo = depthIf.getResult(0);
1199  hi = depthIf.getResult(1);
1200  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1201 
1202  builder.setInsertionPointAfter(lenIf);
1203  lo = lenIf.getResult(0);
1204  hi = lenIf.getResult(1);
1205  } else {
1206  std::tie(lo, hi) =
1207  createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1208  }
1209 
1210  // New [lo, hi) for the next while-loop iteration.
1211  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1212 
1213  // After the while-loop.
1214  builder.setInsertionPointAfter(whileOp);
1215  builder.create<func::ReturnOp>(loc);
1216 }
1217 
1218 /// Implements the rewriting for operator sort and sort_coo.
1219 template <typename OpTy>
1220 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
1221  uint64_t ny, PatternRewriter &rewriter) {
1222  Location loc = op.getLoc();
1223  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1224 
1225  // Convert `values` to have dynamic shape and append them to `operands`.
1226  for (Value v : xys) {
1227  auto mtp = getMemRefType(v);
1228  if (!mtp.isDynamicDim(0)) {
1229  auto newMtp =
1230  MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1231  v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1232  }
1233  operands.push_back(v);
1234  }
1235 
1236  auto insertPoint = op->template getParentOfType<func::FuncOp>();
1237  if (!insertPoint)
1238  return failure();
1239 
1240  SmallString<32> funcName;
1241  FuncGeneratorType funcGenerator;
1242  uint32_t nTrailingP = 0;
1243  switch (op.getAlgorithm()) {
1244  case SparseTensorSortKind::HybridQuickSort: {
1245  funcName = kHybridQuickSortFuncNamePrefix;
1246  funcGenerator = createQuickSortFunc;
1247  nTrailingP = 1;
1248  // As a heuristics, set depthLimit = 2 * log2(n).
1249  Value lo = operands[loIdx];
1250  Value hi = operands[hiIdx];
1251  Value len = rewriter.create<arith::IndexCastOp>(
1252  loc, rewriter.getI64Type(),
1253  rewriter.create<arith::SubIOp>(loc, hi, lo));
1254  Value depthLimit = rewriter.create<arith::SubIOp>(
1255  loc, constantI64(rewriter, loc, 64),
1256  rewriter.create<math::CountLeadingZerosOp>(loc, len));
1257  operands.push_back(depthLimit);
1258  break;
1259  }
1260  case SparseTensorSortKind::QuickSort:
1261  funcName = kQuickSortFuncNamePrefix;
1262  funcGenerator = createQuickSortFunc;
1263  break;
1264  case SparseTensorSortKind::InsertionSortStable:
1265  funcName = kSortStableFuncNamePrefix;
1266  funcGenerator = createSortStableFunc;
1267  break;
1268  case SparseTensorSortKind::HeapSort:
1269  funcName = kHeapSortFuncNamePrefix;
1270  funcGenerator = createHeapSortFunc;
1271  break;
1272  }
1273 
1274  FlatSymbolRefAttr func =
1275  getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1276  xPerm, ny, operands, funcGenerator, nTrailingP);
1277  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1278  return success();
1279 }
1280 
1281 //===---------------------------------------------------------------------===//
1282 // The actual sparse buffer rewriting rules.
1283 //===---------------------------------------------------------------------===//
1284 
1285 namespace {
1286 /// Sparse rewriting rule for the push_back operator.
1287 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1288 public:
1290  PushBackRewriter(MLIRContext *context, bool enableInit)
1291  : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1292  LogicalResult matchAndRewrite(PushBackOp op,
1293  PatternRewriter &rewriter) const override {
1294  // Rewrite push_back(buffer, value, n) to:
1295  // new_size = size(buffer) + n
1296  // if (new_size > capacity(buffer))
1297  // while new_size > new_capacity
1298  // new_capacity = new_capacity*2
1299  // new_buffer = realloc(buffer, new_capacity)
1300  // buffer = new_buffer
1301  // subBuffer = subviewof(buffer)
1302  // linalg.fill subBuffer value
1303  //
1304  // size(buffer) += n
1305  //
1306  // The capacity check is skipped when the attribute inbounds is presented.
1307  Location loc = op->getLoc();
1308  Value c0 = constantIndex(rewriter, loc, 0);
1309  Value buffer = op.getInBuffer();
1310  Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1311  Value size = op.getCurSize();
1312  Value value = op.getValue();
1313 
1314  Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1315  Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1316  auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1317  bool nIsOne = (nValue && nValue.value() == 1);
1318 
1319  if (!op.getInbounds()) {
1320  Value cond = rewriter.create<arith::CmpIOp>(
1321  loc, arith::CmpIPredicate::ugt, newSize, capacity);
1322 
1323  Value c2 = constantIndex(rewriter, loc, 2);
1324  auto bufferType =
1325  MemRefType::get({ShapedType::kDynamic}, value.getType());
1326  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1327  /*else=*/true);
1328  // True branch.
1329  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1330  if (nIsOne) {
1331  capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1332  } else {
1333  // Use a do-while loop to calculate the new capacity as follows:
1334  // do { new_capacity *= 2 } while (size > new_capacity)
1335  scf::WhileOp whileOp =
1336  rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1337 
1338  // The before-region of the WhileOp.
1339  Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1340  {capacity.getType()}, {loc});
1341  rewriter.setInsertionPointToEnd(before);
1342 
1343  capacity =
1344  rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1345  cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1346  newSize, capacity);
1347  rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1348  // The after-region of the WhileOp.
1349  Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1350  {capacity.getType()}, {loc});
1351  rewriter.setInsertionPointToEnd(after);
1352  rewriter.create<scf::YieldOp>(loc, after->getArguments());
1353 
1354  rewriter.setInsertionPointAfter(whileOp);
1355  capacity = whileOp.getResult(0);
1356  }
1357 
1358  Value newBuffer =
1359  rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1360  if (enableBufferInitialization) {
1361  Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1362  Value fillValue = constantZero(rewriter, loc, value.getType());
1363  Value subBuffer = rewriter.create<memref::SubViewOp>(
1364  loc, newBuffer, /*offset=*/ValueRange{newSize},
1365  /*size=*/ValueRange{fillSize},
1366  /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1367  rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1368  }
1369  rewriter.create<scf::YieldOp>(loc, newBuffer);
1370 
1371  // False branch.
1372  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1373  rewriter.create<scf::YieldOp>(loc, buffer);
1374 
1375  // Prepare for adding the value to the end of the buffer.
1376  rewriter.setInsertionPointAfter(ifOp);
1377  buffer = ifOp.getResult(0);
1378  }
1379 
1380  // Add the value to the end of the buffer.
1381  if (nIsOne) {
1382  rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1383  } else {
1384  Value subBuffer = rewriter.create<memref::SubViewOp>(
1385  loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1386  /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1387  rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1388  }
1389 
1390  // Update the buffer size.
1391  rewriter.replaceOp(op, {buffer, newSize});
1392  return success();
1393  }
1394 
1395 private:
1396  bool enableBufferInitialization;
1397 };
1398 
1399 /// Sparse rewriting rule for the sort_coo operator.
1400 struct SortRewriter : public OpRewritePattern<SortOp> {
1401 public:
1403 
1404  LogicalResult matchAndRewrite(SortOp op,
1405  PatternRewriter &rewriter) const override {
1406  SmallVector<Value> xys;
1407  xys.push_back(op.getXy());
1408  xys.append(op.getYs().begin(), op.getYs().end());
1409 
1410  auto xPerm = op.getPermMap();
1411  uint64_t ny = 0;
1412  if (auto nyAttr = op.getNyAttr())
1413  ny = nyAttr.getInt();
1414 
1415  return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1416  }
1417 };
1418 
1419 } // namespace
1420 
1421 //===---------------------------------------------------------------------===//
1422 // Methods that add patterns described in this file to a pattern list.
1423 //===---------------------------------------------------------------------===//
1424 
1426  bool enableBufferInitialization) {
1427  patterns.add<PushBackRewriter>(patterns.getContext(),
1428  enableBufferInitialization);
1429  patterns.add<SortRewriter>(patterns.getContext());
1430 }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumpt...
static void forEachIJPairInXs(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static Value createInlinedCompareImplementation(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
Creates code to compare all the (xs[i], xs[j]) pairs.
static constexpr const char kQuickSortFuncNamePrefix[]
static constexpr uint64_t hiIdx
static constexpr const char kHeapSortFuncNamePrefix[]
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static constexpr const char kHybridQuickSortFuncNamePrefix[]
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to insert the 3rd element to a list of two sorted elements.
static constexpr const char kSortStableFuncNamePrefix[]
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
Looks up a function that is appropriate for the given operands being sorted, and creates such a funct...
static constexpr uint64_t loIdx
static void forEachIJPairInAllBuffers(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is less than xs[j].
static std::pair< Value, Value > createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
A helper for generating code to perform quick sort.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to perform quick sort partition on the values in the range of index [lo,...
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
Creates code to sort 5 elements.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n)
Computes (n-2)/2, assuming n has index type.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is equal to xs[j].
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y...
static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo],...
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [...
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to sort 3 elements.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sor...
static constexpr const char kBinarySearchFuncNamePrefix[]
static constexpr const char kPartitionFuncNamePrefix[]
static constexpr uint64_t xStartIdx
static std::pair< Value, Value > createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step ...
static constexpr const char kShiftDownFuncNamePrefix[]
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to heapify the subtree with root start within the full binary tree in the range of...
static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform insertion sort on the values in the range of index [lo,...
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b],...
static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
Creates a code block for swapping the values in index i and j for all the buffers.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
BlockArgListType getArguments()
Definition: Block.h:87
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:56
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:430
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:421
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:309
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Definition: CodegenUtils.h:320
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:356
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:168
Value constantI64(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of i64 type.
Definition: CodegenUtils.h:336
Include the generated interface declarations.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.