MLIR  18.0.0git
SparseBufferRewriting.cpp
Go to the documentation of this file.
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace mlir::sparse_tensor;
28 
29 //===---------------------------------------------------------------------===//
30 // Helper methods for the actual rewriting rules.
31 //===---------------------------------------------------------------------===//
32 
// Positions of the leading operands shared by the generated sort helpers:
// two index operands followed by the first buffer.
static constexpr uint64_t loIdx = 0;     // First index operand (e.g. `lo`).
static constexpr uint64_t hiIdx = 1;     // Second index operand (e.g. `hi`).
static constexpr uint64_t xStartIdx = 2; // The buffer holding the x values.

// Name prefixes of the private helper functions emitted into the module; a
// mangled suffix is appended to one of these to form each symbol name.
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47 
48 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49  AffineMap, uint64_t, uint32_t)>;
50 
51 /// Constructs a function name with this format to facilitate quick sort:
52 /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53 /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55  StringRef namePrefix, AffineMap xPerm,
56  uint64_t ny, ValueRange operands) {
57  nameOstream << namePrefix;
58  for (auto res : xPerm.getResults())
59  nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
60 
61  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62  nameOstream << "_coo_" << ny;
63 
64  constexpr uint64_t yBufferOffset = 1;
65  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
66  nameOstream << "_" << getMemRefType(v).getElementType();
67 }
68 
69 /// Looks up a function that is appropriate for the given operands being
70 /// sorted, and creates such a function if it doesn't exist yet. The
71 /// parameters `xPerm` and `ny` tell the number of x and y values provided
72 /// by the buffer in xStartIdx.
73 //
74 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
75 // parameters for the sorting functions. Other parameters, such as the recursive
76 // call depth, are appended to the end of the parameter list as
77 // "trailing parameters".
79  OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80  StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81  FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82  SmallString<32> nameBuffer;
83  llvm::raw_svector_ostream nameOstream(nameBuffer);
84  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85  operands.drop_back(nTrailingP));
86 
87  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88  MLIRContext *context = module.getContext();
89  auto result = SymbolRefAttr::get(context, nameOstream.str());
90  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91 
92  if (!func) {
93  // Create the function.
94  OpBuilder::InsertionGuard insertionGuard(builder);
95  builder.setInsertionPoint(insertPoint);
96  Location loc = insertPoint.getLoc();
97  func = builder.create<func::FuncOp>(
98  loc, nameOstream.str(),
99  FunctionType::get(context, operands.getTypes(), resultTypes));
100  func.setPrivate();
101  createFunc(builder, module, func, xPerm, ny, nTrailingP);
102  }
103 
104  return result;
105 }
106 
107 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108 /// The code to process the value pairs is generated by `bodyBuilder`.
109 static void forEachIJPairInXs(
110  OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111  uint64_t ny,
112  function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113  Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
114  Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115  Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116  for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117  unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
118  Value ak = constantIndex(builder, loc, actualK);
119  Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120  Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121  Value buffer = args[xStartIdx];
122 
123  bodyBuilder(k, i, j, buffer);
124  }
125 }
126 
127 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128 /// The code to process the value pairs is generated by `bodyBuilder`.
130  OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131  uint64_t ny,
132  function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133 
134  // Create code for the first (xPerm + ny) buffers.
135  SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
136  xPerm.getResults().end());
137  for (unsigned y = 0; y < ny; y++) {
138  exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
139  }
140  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
141  assert(xyPerm.isPermutation());
142 
143  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
144 
145  constexpr uint64_t numHandledBuffers = 1;
146  // Create code for the remaining buffers.
147  Value i = args[0];
148  Value j = args[1];
149  for (const auto &arg :
150  llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
151  bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
152  }
153 }
154 
155 /// Creates a code block for swapping the values in index i and j for all the
156 /// buffers.
157 //
158 // The generated IR corresponds to this C like algorithm:
159 // swap(x0[i], x0[j]);
160 // swap(x1[i], x1[j]);
161 // ...
162 // swap(xn[i], xn[j]);
163 // swap(y0[i], y0[j]);
164 // ...
165 // swap(yn[i], yn[j]);
166 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
167  AffineMap xPerm, uint64_t ny) {
168  auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
169  Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
170  Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
171  builder.create<memref::StoreOp>(loc, vj, buffer, i);
172  builder.create<memref::StoreOp>(loc, vi, buffer, j);
173  };
174 
175  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
176 }
177 
178 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
179 /// each pair is create via `compareBuilder`.
181  OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
182  uint64_t ny,
183  function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
184  compareBuilder) {
185  Value result;
186  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
187  bool isFirstDim = (k == 0);
188  bool isLastDim = (k == xPerm.getNumResults() - 1);
189  Value val =
190  compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
191  if (isFirstDim) {
192  result = val;
193  } else if (!isLastDim) {
194  OpBuilder::InsertionGuard insertionGuard(builder);
195  auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
196  builder.setInsertionPointAfter(ifOp);
197  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
198  }
199  };
200 
201  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
202 
203  builder.setInsertionPointAfterValue(result);
204  return result;
205 }
206 
207 /// Generates code to compare whether x[i] is equal to x[j] and returns the
208 /// result of the comparison.
210  Value x, bool isFirstDim, bool isLastDim) {
211  Value vi = builder.create<memref::LoadOp>(loc, x, i);
212  Value vj = builder.create<memref::LoadOp>(loc, x, j);
213 
214  Value res;
215  if (isLastDim) {
216  res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
217  // For 1D, we create a compare without any control flow. Otherwise, we
218  // create YieldOp to return the result in the nested if-stmt.
219  if (!isFirstDim)
220  builder.create<scf::YieldOp>(loc, res);
221  } else {
222  Value ne =
223  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
224  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
225  ne, /*else=*/true);
226  // If (x[i] != x[j]).
227  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
228  Value f = constantI1(builder, loc, false);
229  builder.create<scf::YieldOp>(loc, f);
230 
231  // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
232  // checks the remaining dimensions.
233  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
234  res = ifOp.getResult(0);
235  }
236 
237  return res;
238 }
239 
/// Creates code to compare whether xs[i] is equal to xs[j].
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else if (x1[i] != x1[j])
//     return false;
//   else if (x2[i] != x2[j])
//     return false;
//   ... and so on, until the last dimension:
//   else
//     return xn[i] == xn[j];
251  ValueRange args, AffineMap xPerm,
252  uint64_t ny, uint32_t nTrailingP = 0) {
253  // Compare functions don't use trailing parameters.
254  (void)nTrailingP;
255  assert(nTrailingP == 0);
256  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
258 }
259 
260 /// Generates code to compare whether x[i] is less than x[j] and returns the
261 /// result of the comparison.
263  Value j, Value x, bool isFirstDim,
264  bool isLastDim) {
265  Value vi = builder.create<memref::LoadOp>(loc, x, i);
266  Value vj = builder.create<memref::LoadOp>(loc, x, j);
267 
268  Value res;
269  if (isLastDim) {
270  res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
271  // For 1D, we create a compare without any control flow. Otherwise, we
272  // create YieldOp to return the result in the nested if-stmt.
273  if (!isFirstDim)
274  builder.create<scf::YieldOp>(loc, res);
275  } else {
276  Value ne =
277  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
278  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
279  ne, /*else=*/true);
280  // If (x[i] != x[j]).
281  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
282  Value lt =
283  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
284  builder.create<scf::YieldOp>(loc, lt);
285 
286  // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
287  // checks the remaining dimensions.
288  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
289  res = ifOp.getResult(0);
290  }
291 
292  return res;
293 }
294 
/// Creates code to compare whether xs[i] is less than xs[j]. All comparisons
/// are unsigned.
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return x0[i] < x0[j];
//   else if (x1[i] != x1[j])
//     return x1[i] < x1[j];
//   ... and so on, until the last dimension:
//   else
//     return xn[i] < xn[j];
305  ValueRange args, AffineMap xPerm,
306  uint64_t ny, uint32_t nTrailingP = 0) {
307  // Compare functions don't use trailing parameters.
308  (void)nTrailingP;
309  assert(nTrailingP == 0);
310  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
312 }
313 
314 /// Creates a function to use a binary search to find the insertion point for
315 /// inserting xs[hi] to the sorted values xs[lo..hi).
316 //
317 // The generate IR corresponds to this C like algorithm:
318 // p = hi
319 // while (lo < hi)
320 // mid = (lo + hi) >> 1
321 // if (xs[p] < xs[mid])
322 // hi = mid
323 // else
324 // lo = mid - 1
325 // return lo;
326 //
327 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
328  func::FuncOp func, AffineMap xPerm,
329  uint64_t ny, uint32_t nTrailingP = 0) {
330  // Binary search doesn't use trailing parameters.
331  (void)nTrailingP;
332  assert(nTrailingP == 0);
333  OpBuilder::InsertionGuard insertionGuard(builder);
334  Block *entryBlock = func.addEntryBlock();
335  builder.setInsertionPointToStart(entryBlock);
336 
337  Location loc = func.getLoc();
338  ValueRange args = entryBlock->getArguments();
339  Value p = args[hiIdx];
340  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
341  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
342  loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
343 
344  // The before-region of the WhileOp.
345  Block *before =
346  builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
347  builder.setInsertionPointToEnd(before);
348  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
349  before->getArgument(0),
350  before->getArgument(1));
351  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
352 
353  // The after-region of the WhileOp.
354  Block *after =
355  builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
356  builder.setInsertionPointToEnd(after);
357  Value lo = after->getArgument(0);
358  Value hi = after->getArgument(1);
359  // Compute mid = (lo + hi) >> 1.
360  Value c1 = constantIndex(builder, loc, 1);
361  Value mid = builder.create<arith::ShRUIOp>(
362  loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
363  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
364 
365  // Compare xs[p] < xs[mid].
366  SmallVector<Value> compareOperands{p, mid};
367  constexpr uint64_t numXBuffers = 1;
368  compareOperands.append(args.begin() + xStartIdx,
369  args.begin() + xStartIdx + numXBuffers);
370  Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
371  // Update lo and hi for the WhileOp as follows:
372  // if (xs[p] < xs[mid]))
373  // hi = mid;
374  // else
375  // lo = mid + 1;
376  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
377  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
378  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
379 
380  builder.setInsertionPointAfter(whileOp);
381  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
382 }
383 
384 /// Creates code to advance i in a loop based on xs[p] as follows:
385 /// while (xs[i] < xs[p]) i += step (step > 0)
386 /// or
387 /// while (xs[i] > xs[p]) i += step (step < 0)
388 /// The routine returns i as well as a boolean value to indicate whether
389 /// xs[i] == xs[p].
390 static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
391  ModuleOp module,
392  func::FuncOp func, ValueRange xs,
393  Value i, Value p, AffineMap xPerm,
394  uint64_t ny, int step) {
395  Location loc = func.getLoc();
396  scf::WhileOp whileOp =
397  builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
398 
399  Block *before =
400  builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
401  builder.setInsertionPointToEnd(before);
402  SmallVector<Value> compareOperands;
403  if (step > 0) {
404  compareOperands.push_back(before->getArgument(0));
405  compareOperands.push_back(p);
406  } else {
407  assert(step < 0);
408  compareOperands.push_back(p);
409  compareOperands.push_back(before->getArgument(0));
410  }
411  compareOperands.append(xs.begin(), xs.end());
412  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
413  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
414 
415  Block *after =
416  builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
417  builder.setInsertionPointToEnd(after);
418  Value cs = constantIndex(builder, loc, step);
419  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
420  builder.create<scf::YieldOp>(loc, ValueRange{i});
421  i = whileOp.getResult(0);
422 
423  builder.setInsertionPointAfter(whileOp);
424  compareOperands[0] = i;
425  compareOperands[1] = p;
426  Value compareEq =
427  createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
428 
429  return std::make_pair(whileOp.getResult(0), compareEq);
430 }
431 
432 /// Creates and returns an IfOp to compare two elements and swap the elements
433 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
434 /// right after the swap instructions.
435 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
436  AffineMap xPerm, uint64_t ny,
437  SmallVectorImpl<Value> &swapOperands,
438  SmallVectorImpl<Value> &compareOperands,
439  Value a, Value b) {
440  // Compare(data[b], data[a]).
441  compareOperands[0] = b;
442  compareOperands[1] = a;
443  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
444  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
445  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
446  swapOperands[0] = b;
447  swapOperands[1] = a;
448  createSwap(builder, loc, swapOperands, xPerm, ny);
449  return ifOp;
450 }
451 
452 /// Creates code to insert the 3rd element to a list of two sorted elements.
453 static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
454  uint64_t ny, SmallVectorImpl<Value> &swapOperands,
455  SmallVectorImpl<Value> &compareOperands, Value v0,
456  Value v1, Value v2) {
457  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
458  compareOperands, v1, v2);
459  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
460  v0, v1);
461  builder.setInsertionPointAfter(ifOp);
462 }
463 
464 /// Creates code to sort 3 elements.
465 static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
466  uint64_t ny, SmallVectorImpl<Value> &swapOperands,
467  SmallVectorImpl<Value> &compareOperands, Value v0,
468  Value v1, Value v2) {
469  // Sort the first 2 elements.
470  scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
471  compareOperands, v0, v1);
472  builder.setInsertionPointAfter(ifOp1);
473 
474  // Insert the 3th element.
475  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
476  v1, v2);
477 }
478 
479 /// Creates code to sort 5 elements.
480 static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
481  uint64_t ny, SmallVectorImpl<Value> &swapOperands,
482  SmallVectorImpl<Value> &compareOperands, Value v0,
483  Value v1, Value v2, Value v3, Value v4) {
484  // Sort the first 3 elements.
485  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
486  v2);
487 
488  auto insert4th = [&]() {
489  scf::IfOp ifOp = createCompareThenSwap(
490  builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
491  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
492  v1, v2);
493  builder.setInsertionPointAfter(ifOp);
494  };
495 
496  // Insert the 4th element.
497  insert4th();
498 
499  // Insert the 5th element.
500  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
501  compareOperands, v3, v4);
502  insert4th();
503  builder.setInsertionPointAfter(ifOp);
504 }
505 
506 /// Creates a code block to swap the values in indices lo, mi, and hi so that
507 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
508 /// the number of values in range [lo, hi) is more than a threshold, we also
509 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
510 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
511  func::FuncOp func, AffineMap xPerm, uint64_t ny,
512  Value lo, Value hi, Value mi, ValueRange args) {
513  SmallVector<Value> compareOperands{mi, lo};
514  constexpr uint64_t numXBuffers = 1;
515  compareOperands.append(args.begin() + xStartIdx,
516  args.begin() + xStartIdx + numXBuffers);
517  SmallVector<Value> swapOperands{mi, lo};
518  swapOperands.append(args.begin() + xStartIdx, args.end());
519  Location loc = func.getLoc();
520  Value c1 = constantIndex(builder, loc, 1);
521  Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
522  Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
523  Value lenThreshold = constantIndex(builder, loc, 1000);
524  Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
525  len, lenThreshold);
526  scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
527 
528  // When len < 1000, choose pivot from median of 3 values.
529  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
530  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
531  hi);
532 
533  // When len >= 1000, choose pivot from median of 5 values.
534  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
535  Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
536  Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
537  // Value a is the middle between [loc, mi].
538  a = builder.create<arith::ShRUIOp>(loc, a, c1);
539  Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
540  // Value b is the middle between [mi, hi].
541  b = builder.create<arith::ShRUIOp>(loc, b, c1);
542  createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
543  b, hi);
544 
545  builder.setInsertionPointAfter(lenIf);
546 }
547 
548 /// Creates a function to perform quick sort partition on the values in the
549 /// range of index [lo, hi), assuming lo < hi.
550 //
551 // The generated IR corresponds to this C like algorithm:
552 // int partition(lo, hi, xs) {
553 // p = (lo+hi)/2 // pivot index
554 // i = lo
555 // j = hi-1
556 // while (true) do {
557 // while (xs[i] < xs[p]) i ++;
558 // i_eq = (xs[i] == xs[p]);
559 // while (xs[j] > xs[p]) j --;
560 // j_eq = (xs[j] == xs[p]);
561 //
562 // if (i >= j) return j + 1;
563 //
564 // if (i < j) {
565 // swap(xs[i], xs[j])
566 // if (i == p) {
567 // p = j;
568 // } else if (j == p) {
569 // p = i;
570 // }
571 // if (i_eq && j_eq) {
572 // ++i;
573 // --j;
574 // }
575 // }
576 // }
577 // }
578 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
579  func::FuncOp func, AffineMap xPerm, uint64_t ny,
580  uint32_t nTrailingP = 0) {
581  // Quick sort partition doesn't use trailing parameters.
582  (void)nTrailingP;
583  assert(nTrailingP == 0);
584  OpBuilder::InsertionGuard insertionGuard(builder);
585 
586  Block *entryBlock = func.addEntryBlock();
587  builder.setInsertionPointToStart(entryBlock);
588 
589  Location loc = func.getLoc();
590  ValueRange args = entryBlock->getArguments();
591  Value lo = args[loIdx];
592  Value hi = args[hiIdx];
593  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
594  Value c1 = constantIndex(builder, loc, 1);
595  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
596 
597  Value i = lo;
598  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
599  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
600  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
601  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
602  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
603  trueVal.getType()};
604  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
605 
606  // The before-region of the WhileOp.
607  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
608  {loc, loc, loc, loc});
609  builder.setInsertionPointToEnd(before);
610  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
611  before->getArguments());
612 
613  // The after-region of the WhileOp.
614  Block *after =
615  builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
616  builder.setInsertionPointToEnd(after);
617  i = after->getArgument(0);
618  j = after->getArgument(1);
619  p = after->getArgument(2);
620 
621  constexpr uint64_t numXBuffers = 1;
622  auto [iresult, iCompareEq] =
623  createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
624  i, p, xPerm, ny, 1);
625  i = iresult;
626  auto [jresult, jCompareEq] =
627  createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
628  j, p, xPerm, ny, -1);
629  j = jresult;
630 
631  // If i < j:
632  Value cond =
633  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
634  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
635  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
636  SmallVector<Value> swapOperands{i, j};
637  swapOperands.append(args.begin() + xStartIdx, args.end());
638  createSwap(builder, loc, swapOperands, xPerm, ny);
639  // If the pivot is moved, update p with the new pivot.
640  Value icond =
641  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
642  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
643  icond, /*else=*/true);
644  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
645  builder.create<scf::YieldOp>(loc, ValueRange{j});
646  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
647  Value jcond =
648  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
649  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
650  jcond, /*else=*/true);
651  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
652  builder.create<scf::YieldOp>(loc, ValueRange{i});
653  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
654  builder.create<scf::YieldOp>(loc, ValueRange{p});
655  builder.setInsertionPointAfter(ifOpJ);
656  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
657  builder.setInsertionPointAfter(ifOpI);
658  Value compareEqIJ =
659  builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
660  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
661  loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
662  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
663  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
664  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
665  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
666  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
667  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
668  builder.setInsertionPointAfter(ifOp2);
669  builder.create<scf::YieldOp>(
670  loc,
671  ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
672  /*cont=*/constantI1(builder, loc, true)});
673 
674  // False branch for if i < j (i.e., i >= j):
675  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
676  p = builder.create<arith::AddIOp>(loc, j,
677  constantOne(builder, loc, j.getType()));
678  builder.create<scf::YieldOp>(
679  loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
680 
681  // Return for the whileOp.
682  builder.setInsertionPointAfter(ifOp);
683  builder.create<scf::YieldOp>(loc, ifOp.getResults());
684 
685  // Return for the function.
686  builder.setInsertionPointAfter(whileOp);
687  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
688 }
689 
690 /// Computes (n-2)/n, assuming n has index type.
692  Value n) {
693  Value i2 = constantIndex(builder, loc, 2);
694  Value res = builder.create<arith::SubIOp>(loc, n, i2);
695  Value i1 = constantIndex(builder, loc, 1);
696  return builder.create<arith::ShRUIOp>(loc, res, i1);
697 }
698 
/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1]) {
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       }
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + first
//           if (child + 1 < n && data[childIndex] < data[childIndex+1]) {
//             childIndex++; child++;
//           }
//         }
//       }
//     }
//   }
// }
728 //
729 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
730  func::FuncOp func, AffineMap xPerm, uint64_t ny,
731  uint32_t nTrailingP) {
732  // The value n is passed in as a trailing parameter.
733  assert(nTrailingP == 1);
734  OpBuilder::InsertionGuard insertionGuard(builder);
735  Block *entryBlock = func.addEntryBlock();
736  builder.setInsertionPointToStart(entryBlock);
737 
738  Location loc = func.getLoc();
739  Value n = entryBlock->getArguments().back();
740  ValueRange args = entryBlock->getArguments().drop_back();
741  Value first = args[loIdx];
742  Value start = args[hiIdx];
743 
744  // If (n >= 2).
745  Value c2 = constantIndex(builder, loc, 2);
746  Value condN =
747  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
748  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
749  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
750  Value child = builder.create<arith::SubIOp>(loc, start, first);
751 
752  // If ((n-2)/2 >= child).
753  Value t = createSubTwoDividedByTwo(builder, loc, n);
754  Value condNc =
755  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
756  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
757 
758  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
759  Value c1 = constantIndex(builder, loc, 1);
760  SmallVector<Value> compareOperands{start, start};
761  constexpr uint64_t numXBuffers = 1;
762  compareOperands.append(args.begin() + xStartIdx,
763  args.begin() + xStartIdx + numXBuffers);
764 
765  // Generate code to inspect the children of 'r' and return the larger child
766  // as follows:
767  // child = r * 2 + 1 // Left child.
768  // childIndex = child + first
769  // if (child+1 < n && data[childIndex] < data[childIndex+1])
770  // childIndex ++; child ++ // Right child is bigger.
771  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
772  Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
773  lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
774  Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
775  Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
776  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
777  rChild, n);
778  SmallVector<Type, 2> ifTypes(2, r.getType());
779  scf::IfOp if1 =
780  builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
781  builder.setInsertionPointToStart(&if1.getThenRegion().front());
782  Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
783  // Compare data[left] < data[right].
784  compareOperands[0] = lChildIdx;
785  compareOperands[1] = rChildIdx;
786  Value cond2 =
787  createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
788  scf::IfOp if2 =
789  builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
790  builder.setInsertionPointToStart(&if2.getThenRegion().front());
791  builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
792  builder.setInsertionPointToStart(&if2.getElseRegion().front());
793  builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
794  builder.setInsertionPointAfter(if2);
795  builder.create<scf::YieldOp>(loc, if2.getResults());
796  builder.setInsertionPointToStart(&if1.getElseRegion().front());
797  builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
798  builder.setInsertionPointAfter(if1);
799  return std::make_pair(if1.getResult(0), if1.getResult(1));
800  };
801 
802  Value childIdx;
803  std::tie(child, childIdx) = getLargerChild(child);
804 
805  // While (data[start] < data[childIndex]).
806  SmallVector<Type, 3> types(3, child.getType());
807  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
808  loc, types, SmallVector<Value, 2>{start, child, childIdx});
809 
810  // The before-region of the WhileOp.
811  SmallVector<Location, 3> locs(3, loc);
812  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
813  builder.setInsertionPointToEnd(before);
814  start = before->getArgument(0);
815  childIdx = before->getArgument(2);
816  compareOperands[0] = start;
817  compareOperands[1] = childIdx;
818  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
819  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
820 
821  // The after-region of the WhileOp.
822  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
823  start = after->getArgument(0);
824  child = after->getArgument(1);
825  childIdx = after->getArgument(2);
826  SmallVector<Value> swapOperands{start, childIdx};
827  swapOperands.append(args.begin() + xStartIdx, args.end());
828  createSwap(builder, loc, swapOperands, xPerm, ny);
829  start = childIdx;
830  Value cond2 =
831  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
832  scf::IfOp if2 = builder.create<scf::IfOp>(
833  loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
834  builder.setInsertionPointToStart(&if2.getThenRegion().front());
835  auto [newChild, newChildIdx] = getLargerChild(child);
836  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
837  builder.setInsertionPointToStart(&if2.getElseRegion().front());
838  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
839  builder.setInsertionPointAfter(if2);
840  builder.create<scf::YieldOp>(
841  loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
842 
843  builder.setInsertionPointAfter(ifN);
844  builder.create<func::ReturnOp>(loc);
845 }
846 
847 /// Creates a function to perform heap sort on the values in the range of index
848 /// [lo, hi) with the assumption hi - lo >= 2.
849 //
850 // The generate IR corresponds to this C like algorithm:
851 // void heapSort(lo, hi, data) {
852 // n = hi - lo
853 // for i = (n-2)/2 downto 0
854 // shiftDown(lo, lo+i, n)
855 //
856 // for l = n downto 2
857 // swap(lo, lo+l-1)
858 // shiftdown(lo, lo, l-1)
859 // }
860 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
861  func::FuncOp func, AffineMap xPerm, uint64_t ny,
862  uint32_t nTrailingP) {
863  // Heap sort function doesn't have trailing parameters.
864  (void)nTrailingP;
865  assert(nTrailingP == 0);
866  OpBuilder::InsertionGuard insertionGuard(builder);
867  Block *entryBlock = func.addEntryBlock();
868  builder.setInsertionPointToStart(entryBlock);
869 
870  Location loc = func.getLoc();
871  ValueRange args = entryBlock->getArguments();
872  Value lo = args[loIdx];
873  Value hi = args[hiIdx];
874  Value n = builder.create<arith::SubIOp>(loc, hi, lo);
875 
876  // For i = (n-2)/2 downto 0.
877  Value c0 = constantIndex(builder, loc, 0);
878  Value c1 = constantIndex(builder, loc, 1);
879  Value s = createSubTwoDividedByTwo(builder, loc, n);
880  Value up = builder.create<arith::AddIOp>(loc, s, c1);
881  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
882  builder.setInsertionPointToStart(forI.getBody());
883  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
884  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
885  SmallVector<Value> shiftDownOperands = {lo, lopi};
886  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
887  shiftDownOperands.push_back(n);
889  builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
890  shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
891  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
892  shiftDownOperands);
893 
894  builder.setInsertionPointAfter(forI);
895  // For l = n downto 2.
896  up = builder.create<arith::SubIOp>(loc, n, c1);
897  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
898  builder.setInsertionPointToStart(forL.getBody());
899  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
900  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
901  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
902  SmallVector<Value> swapOperands{lo, loplm1};
903  swapOperands.append(args.begin() + xStartIdx, args.end());
904  createSwap(builder, loc, swapOperands, xPerm, ny);
905  shiftDownOperands[1] = lo;
906  shiftDownOperands[shiftDownOperands.size() - 1] =
907  builder.create<arith::SubIOp>(loc, l, c1);
908  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
909  shiftDownOperands);
910 
911  builder.setInsertionPointAfter(forL);
912  builder.create<func::ReturnOp>(loc);
913 }
914 
915 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
916 /// recursively calls quick sort to process the smaller partition and returns
917 /// the bigger partition to be processed by the enclosed while-loop.
918 static std::pair<Value, Value>
919 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
920  ValueRange args, AffineMap xPerm, uint64_t ny,
921  uint32_t nTrailingP) {
922  MLIRContext *context = module.getContext();
923  Location loc = func.getLoc();
924  Value lo = args[loIdx];
925  Value hi = args[hiIdx];
926  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
927 
929  builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
930  ny, args.drop_back(nTrailingP), createPartitionFunc);
931  Value p = builder
932  .create<func::CallOp>(loc, partitionFunc,
933  TypeRange{IndexType::get(context)},
934  args.drop_back(nTrailingP))
935  .getResult(0);
936 
937  Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
938  Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
939  // Partition already sorts array with len <= 2
940  Value c2 = constantIndex(builder, loc, 2);
941  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
942  Value lenGtTwo =
943  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
944  scf::IfOp ifLenGtTwo =
945  builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
946  builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
947  // Returns an empty range to mark the entire region is fully sorted.
948  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
949 
950  // Else len > 2, need recursion.
951  builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
952  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
953  lenLow, lenHigh);
954 
955  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
956 
957  Value c0 = constantIndex(builder, loc, 0);
958  auto mayRecursion = [&](Value low, Value high, Value len) {
959  Value cond =
960  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
961  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
962  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
963  SmallVector<Value> operands{low, high};
964  operands.append(args.begin() + xStartIdx, args.end());
965  builder.create<func::CallOp>(loc, func, operands);
966  builder.setInsertionPointAfter(ifOp);
967  };
968 
969  // Recursively call quickSort to process the smaller partition and return
970  // the bigger partition to be processed by the enclosed while-loop.
971  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
972  mayRecursion(lo, p, lenLow);
973  builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
974 
975  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
976  mayRecursion(p, hi, lenHigh);
977  builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
978 
979  builder.setInsertionPointAfter(ifOp);
980  builder.create<scf::YieldOp>(loc, ifOp.getResults());
981 
982  builder.setInsertionPointAfter(ifLenGtTwo);
983  return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
984 }
985 
986 /// Creates a function to perform insertion sort on the values in the range of
987 /// index [lo, hi).
988 //
989 // The generate IR corresponds to this C like algorithm:
990 // void insertionSort(lo, hi, data) {
991 // for (i = lo+1; i < hi; i++) {
992 // d = data[i];
993 // p = binarySearch(lo, i-1, data)
994 // for (j = 0; j > i - p; j++)
995 // data[i-j] = data[i-j-1]
996 // data[p] = d
997 // }
998 // }
999 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
1000  func::FuncOp func, AffineMap xPerm,
1001  uint64_t ny, uint32_t nTrailingP) {
1002  // Stable sort function doesn't use trailing parameters.
1003  (void)nTrailingP;
1004  assert(nTrailingP == 0);
1005  OpBuilder::InsertionGuard insertionGuard(builder);
1006  Block *entryBlock = func.addEntryBlock();
1007  builder.setInsertionPointToStart(entryBlock);
1008 
1009  MLIRContext *context = module.getContext();
1010  Location loc = func.getLoc();
1011  ValueRange args = entryBlock->getArguments();
1012  Value c1 = constantIndex(builder, loc, 1);
1013  Value lo = args[loIdx];
1014  Value hi = args[hiIdx];
1015  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1016 
1017  // Start the outer for-stmt with induction variable i.
1018  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1019  builder.setInsertionPointToStart(forOpI.getBody());
1020  Value i = forOpI.getInductionVar();
1021 
1022  // Binary search to find the insertion point p.
1023  SmallVector<Value> operands{lo, i};
1024  operands.append(args.begin() + xStartIdx, args.end());
1026  builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1027  xPerm, ny, operands, createBinarySearchFunc);
1028  Value p = builder
1029  .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1030  operands)
1031  .getResult(0);
1032 
1033  // Move the value at data[i] to a temporary location.
1034  operands[0] = operands[1] = i;
1037  builder, loc, operands, xPerm, ny,
1038  [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1039  d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1040  });
1041 
1042  // Start the inner for-stmt with induction variable j, for moving data[p..i)
1043  // to data[p+1..i+1).
1044  Value imp = builder.create<arith::SubIOp>(loc, i, p);
1045  Value c0 = constantIndex(builder, loc, 0);
1046  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1047  builder.setInsertionPointToStart(forOpJ.getBody());
1048  Value j = forOpJ.getInductionVar();
1049  Value imj = builder.create<arith::SubIOp>(loc, i, j);
1050  operands[1] = imj;
1051  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1053  builder, loc, operands, xPerm, ny,
1054  [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1055  Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1056  builder.create<memref::StoreOp>(loc, t, buffer, imj);
1057  });
1058 
1059  // Store the value at data[i] to data[p].
1060  builder.setInsertionPointAfter(forOpJ);
1061  operands[0] = operands[1] = p;
1063  builder, loc, operands, xPerm, ny,
1064  [&](uint64_t k, Value p, Value usused, Value buffer) {
1065  builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1066  });
1067 
1068  builder.setInsertionPointAfter(forOpI);
1069  builder.create<func::ReturnOp>(loc);
1070 }
1071 
1072 /// Creates a function to perform quick sort or a hybrid quick sort on the
1073 /// values in the range of index [lo, hi).
1074 //
1075 //
1076 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1077 // void quickSort(lo, hi, data) {
1078 // while (lo + 1 < hi) {
1079 // p = partition(low, high, data);
1080 // if (len(lo, p) < len(p+1, hi)) {
1081 // quickSort(lo, p, data);
1082 // lo = p+1;
1083 // } else {
1084 // quickSort(p + 1, hi, data);
1085 // hi = p;
1086 // }
1087 // }
1088 // }
1089 //
1090 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1091 // void hybridQuickSort(lo, hi, data, depthLimit) {
1092 // while (lo + 1 < hi) {
1093 // len = hi - lo;
1094 // if (len <= limit) {
1095 // insertionSort(lo, hi, data);
1096 // } else {
1097 // depthLimit --;
1098 // if (depthLimit <= 0) {
1099 // heapSort(lo, hi, data);
1100 // } else {
1101 // p = partition(low, high, data);
1102 // if (len(lo, p) < len(p+1, hi)) {
1103 // quickSort(lo, p, data, depthLimit);
1104 // lo = p+1;
1105 // } else {
1106 // quickSort(p + 1, hi, data, depthLimit);
1107 // hi = p;
1108 // }
1109 // }
1110 // }
1111 // }
1112 // }
1113 //
1114 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1115  func::FuncOp func, AffineMap xPerm, uint64_t ny,
1116  uint32_t nTrailingP) {
1117  assert(nTrailingP == 1 || nTrailingP == 0);
1118  bool isHybrid = (nTrailingP == 1);
1119  OpBuilder::InsertionGuard insertionGuard(builder);
1120  Block *entryBlock = func.addEntryBlock();
1121  builder.setInsertionPointToStart(entryBlock);
1122 
1123  Location loc = func.getLoc();
1124  SmallVector<Value> args;
1125  args.append(entryBlock->getArguments().begin(),
1126  entryBlock->getArguments().end());
1127  Value lo = args[loIdx];
1128  Value hi = args[hiIdx];
1129  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1130  scf::WhileOp whileOp =
1131  builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1132 
1133  // The before-region of the WhileOp.
1134  Block *before =
1135  builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1136  builder.setInsertionPointToEnd(before);
1137  lo = before->getArgument(0);
1138  hi = before->getArgument(1);
1139  Value loP1 =
1140  builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1141  Value needSort =
1142  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1143  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1144 
1145  // The after-region of the WhileOp.
1146  Block *after =
1147  builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1148  builder.setInsertionPointToEnd(after);
1149  lo = after->getArgument(0);
1150  hi = after->getArgument(1);
1151  args[0] = lo;
1152  args[1] = hi;
1153 
1154  if (isHybrid) {
1155  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1156  Value lenLimit = constantIndex(builder, loc, 30);
1157  Value lenCond = builder.create<arith::CmpIOp>(
1158  loc, arith::CmpIPredicate::ule, len, lenLimit);
1159  scf::IfOp lenIf =
1160  builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1161 
1162  // When len <= limit.
1163  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1164  FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1165  builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1166  ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1167  builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1168  ValueRange(args).drop_back(nTrailingP));
1169  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1170 
1171  // When len > limit.
1172  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1173  Value depthLimit = args.back();
1174  depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1175  constantI64(builder, loc, 1));
1176  Value depthCond =
1177  builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1178  depthLimit, constantI64(builder, loc, 0));
1179  scf::IfOp depthIf =
1180  builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1181 
1182  // When depth exceeds limit.
1183  builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1185  builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1186  ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1187  builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1188  ValueRange(args).drop_back(nTrailingP));
1189  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1190 
1191  // When depth doesn't exceed limit.
1192  builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1193  args.back() = depthLimit;
1194  std::tie(lo, hi) =
1195  createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1196  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1197 
1198  builder.setInsertionPointAfter(depthIf);
1199  lo = depthIf.getResult(0);
1200  hi = depthIf.getResult(1);
1201  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1202 
1203  builder.setInsertionPointAfter(lenIf);
1204  lo = lenIf.getResult(0);
1205  hi = lenIf.getResult(1);
1206  } else {
1207  std::tie(lo, hi) =
1208  createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1209  }
1210 
1211  // New [lo, hi) for the next while-loop iteration.
1212  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1213 
1214  // After the while-loop.
1215  builder.setInsertionPointAfter(whileOp);
1216  builder.create<func::ReturnOp>(loc);
1217 }
1218 
1219 /// Implements the rewriting for operator sort and sort_coo.
1220 template <typename OpTy>
1222  uint64_t ny, PatternRewriter &rewriter) {
1223  Location loc = op.getLoc();
1224  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1225 
1226  // Convert `values` to have dynamic shape and append them to `operands`.
1227  for (Value v : xys) {
1228  auto mtp = getMemRefType(v);
1229  if (!mtp.isDynamicDim(0)) {
1230  auto newMtp =
1231  MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1232  v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1233  }
1234  operands.push_back(v);
1235  }
1236 
1237  auto insertPoint = op->template getParentOfType<func::FuncOp>();
1238  if (!insertPoint)
1239  return failure();
1240 
1241  SmallString<32> funcName;
1242  FuncGeneratorType funcGenerator;
1243  uint32_t nTrailingP = 0;
1244  switch (op.getAlgorithm()) {
1245  case SparseTensorSortKind::HybridQuickSort: {
1246  funcName = kHybridQuickSortFuncNamePrefix;
1247  funcGenerator = createQuickSortFunc;
1248  nTrailingP = 1;
1249  // As a heuristics, set depthLimit = 2 * log2(n).
1250  Value lo = operands[loIdx];
1251  Value hi = operands[hiIdx];
1252  Value len = rewriter.create<arith::IndexCastOp>(
1253  loc, rewriter.getI64Type(),
1254  rewriter.create<arith::SubIOp>(loc, hi, lo));
1255  Value depthLimit = rewriter.create<arith::SubIOp>(
1256  loc, constantI64(rewriter, loc, 64),
1257  rewriter.create<math::CountLeadingZerosOp>(loc, len));
1258  operands.push_back(depthLimit);
1259  break;
1260  }
1261  case SparseTensorSortKind::QuickSort:
1262  funcName = kQuickSortFuncNamePrefix;
1263  funcGenerator = createQuickSortFunc;
1264  break;
1265  case SparseTensorSortKind::InsertionSortStable:
1266  funcName = kSortStableFuncNamePrefix;
1267  funcGenerator = createSortStableFunc;
1268  break;
1269  case SparseTensorSortKind::HeapSort:
1270  funcName = kHeapSortFuncNamePrefix;
1271  funcGenerator = createHeapSortFunc;
1272  break;
1273  }
1274 
1275  FlatSymbolRefAttr func =
1276  getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1277  xPerm, ny, operands, funcGenerator, nTrailingP);
1278  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1279  return success();
1280 }
1281 
1282 //===---------------------------------------------------------------------===//
1283 // The actual sparse buffer rewriting rules.
1284 //===---------------------------------------------------------------------===//
1285 
1286 namespace {
1287 /// Sparse rewriting rule for the push_back operator.
1288 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1289 public:
1291  PushBackRewriter(MLIRContext *context, bool enableInit)
1292  : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1293  LogicalResult matchAndRewrite(PushBackOp op,
1294  PatternRewriter &rewriter) const override {
1295  // Rewrite push_back(buffer, value, n) to:
1296  // new_size = size(buffer) + n
1297  // if (new_size > capacity(buffer))
1298  // while new_size > new_capacity
1299  // new_capacity = new_capacity*2
1300  // new_buffer = realloc(buffer, new_capacity)
1301  // buffer = new_buffer
1302  // subBuffer = subviewof(buffer)
1303  // linalg.fill subBuffer value
1304  //
1305  // size(buffer) += n
1306  //
1307  // The capacity check is skipped when the attribute inbounds is presented.
1308  Location loc = op->getLoc();
1309  Value c0 = constantIndex(rewriter, loc, 0);
1310  Value buffer = op.getInBuffer();
1311  Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1312  Value size = op.getCurSize();
1313  Value value = op.getValue();
1314 
1315  Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1316  Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1317  auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1318  bool nIsOne = (nValue && nValue.value() == 1);
1319 
1320  if (!op.getInbounds()) {
1321  Value cond = rewriter.create<arith::CmpIOp>(
1322  loc, arith::CmpIPredicate::ugt, newSize, capacity);
1323 
1324  Value c2 = constantIndex(rewriter, loc, 2);
1325  auto bufferType =
1326  MemRefType::get({ShapedType::kDynamic}, value.getType());
1327  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1328  /*else=*/true);
1329  // True branch.
1330  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1331  if (nIsOne) {
1332  capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1333  } else {
1334  // Use a do-while loop to calculate the new capacity as follows:
1335  // do { new_capacity *= 2 } while (size > new_capacity)
1336  scf::WhileOp whileOp =
1337  rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1338 
1339  // The before-region of the WhileOp.
1340  Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1341  {capacity.getType()}, {loc});
1342  rewriter.setInsertionPointToEnd(before);
1343 
1344  capacity =
1345  rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1346  cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1347  newSize, capacity);
1348  rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1349  // The after-region of the WhileOp.
1350  Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1351  {capacity.getType()}, {loc});
1352  rewriter.setInsertionPointToEnd(after);
1353  rewriter.create<scf::YieldOp>(loc, after->getArguments());
1354 
1355  rewriter.setInsertionPointAfter(whileOp);
1356  capacity = whileOp.getResult(0);
1357  }
1358 
1359  Value newBuffer =
1360  rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1361  if (enableBufferInitialization) {
1362  Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1363  Value fillValue = constantZero(rewriter, loc, value.getType());
1364  Value subBuffer = rewriter.create<memref::SubViewOp>(
1365  loc, newBuffer, /*offset=*/ValueRange{newSize},
1366  /*size=*/ValueRange{fillSize},
1367  /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1368  rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1369  }
1370  rewriter.create<scf::YieldOp>(loc, newBuffer);
1371 
1372  // False branch.
1373  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1374  rewriter.create<scf::YieldOp>(loc, buffer);
1375 
1376  // Prepare for adding the value to the end of the buffer.
1377  rewriter.setInsertionPointAfter(ifOp);
1378  buffer = ifOp.getResult(0);
1379  }
1380 
1381  // Add the value to the end of the buffer.
1382  if (nIsOne) {
1383  rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1384  } else {
1385  Value subBuffer = rewriter.create<memref::SubViewOp>(
1386  loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1387  /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1388  rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1389  }
1390 
1391  // Update the buffer size.
1392  rewriter.replaceOp(op, {buffer, newSize});
1393  return success();
1394  }
1395 
1396 private:
1397  bool enableBufferInitialization;
1398 };
1399 
1400 /// Sparse rewriting rule for the sort_coo operator.
1401 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
1402 public:
1404 
1405  LogicalResult matchAndRewrite(SortCooOp op,
1406  PatternRewriter &rewriter) const override {
1407  SmallVector<Value> xys;
1408  xys.push_back(op.getXy());
1409  xys.append(op.getYs().begin(), op.getYs().end());
1410 
1411  auto xPerm = op.getPermMap();
1412  uint64_t ny = 0;
1413  if (auto nyAttr = op.getNyAttr())
1414  ny = nyAttr.getInt();
1415 
1416  return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1417  }
1418 };
1419 
1420 } // namespace
1421 
1422 //===---------------------------------------------------------------------===//
1423 // Methods that add patterns described in this file to a pattern list.
1424 //===---------------------------------------------------------------------===//
1425 
1427  bool enableBufferInitialization) {
1428  patterns.add<PushBackRewriter>(patterns.getContext(),
1429  enableBufferInitialization);
1430  patterns.add<SortCooRewriter>(patterns.getContext());
1431 }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumpt...
static void forEachIJPairInXs(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static Value createInlinedCompareImplementation(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
Creates code to compare all the (xs[i], xs[j]) pairs.
static constexpr const char kQuickSortFuncNamePrefix[]
static constexpr uint64_t hiIdx
static constexpr const char kHeapSortFuncNamePrefix[]
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static constexpr const char kHybridQuickSortFuncNamePrefix[]
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to insert the 3rd element to a list of two sorted elements.
static constexpr const char kSortStableFuncNamePrefix[]
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
Looks up a function that is appropriate for the given operands being sorted, and creates such a funct...
static constexpr uint64_t loIdx
static void forEachIJPairInAllBuffers(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is less than xs[j].
static std::pair< Value, Value > createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
A helper for generating code to perform quick sort.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to perform quick sort partition on the values in the range of index [lo,...
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
Creates code to sort 5 elements.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n)
Computes (n-2)/2, assuming n has index type.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is equal to xs[j].
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y...
static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo],...
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [...
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to sort 3 elements.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sor...
static constexpr const char kBinarySearchFuncNamePrefix[]
static constexpr const char kPartitionFuncNamePrefix[]
static constexpr uint64_t xStartIdx
static std::pair< Value, Value > createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step ...
static constexpr const char kShiftDownFuncNamePrefix[]
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to heapify the subtree with root start within the full binary tree in the range of...
static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform insertion sort on the values in the range of index [lo,...
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b],...
static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
Creates a code block for swapping the values in index i and j for all the buffers.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
U cast() const
Definition: AffineExpr.h:293
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:44
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:350
unsigned getNumResults() const
Definition: AffineMap.cpp:345
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:354
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:564
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
BlockArgListType getArguments()
Definition: Block.h:80
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:353
MLIRContext * getContext() const
Definition: Builders.h:55
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:350
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:328
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Definition: CodegenUtils.h:339
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:375
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:104
Value constantI64(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of i64 type.
Definition: CodegenUtils.h:355
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.