MLIR 22.0.0git
SparseBufferRewriting.cpp
Go to the documentation of this file.
1//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensor
10// primitives with memref operands.
11//
12//===----------------------------------------------------------------------===//
13
14#include "Utils/CodegenUtils.h"
15
24#include "mlir/Support/LLVM.h"
25
26using namespace mlir;
27using namespace mlir::sparse_tensor;
28
29//===---------------------------------------------------------------------===//
30// Helper methods for the actual rewriting rules.
31//===---------------------------------------------------------------------===//
32
// Operand positions shared by the generated sort helpers, whose operand lists
// start with (lo, hi, xs, ...).
static constexpr uint64_t loIdx = 0;     // Position of `lo`.
static constexpr uint64_t hiIdx = 1;     // Position of `hi`.
static constexpr uint64_t xStartIdx = 2; // Position of the buffer holding xs.

// Name prefixes used when mangling the symbols of the generated helper
// functions.
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47
48using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49 AffineMap, uint64_t, uint32_t)>;
50
51/// Constructs a function name with this format to facilitate quick sort:
52/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55 StringRef namePrefix, AffineMap xPerm,
56 uint64_t ny, ValueRange operands) {
57 nameOstream << namePrefix;
58 for (auto res : xPerm.getResults())
59 nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
60
61 nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62 nameOstream << "_coo_" << ny;
63
64 constexpr uint64_t yBufferOffset = 1;
65 for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
66 nameOstream << "_" << getMemRefType(v).getElementType();
67}
68
69/// Looks up a function that is appropriate for the given operands being
70/// sorted, and creates such a function if it doesn't exist yet. The
71/// parameters `xPerm` and `ny` tell the number of x and y values provided
72/// by the buffer in xStartIdx.
73//
74// All sorting function generators take (lo, hi, xs, ys) in `operands` as
75// parameters for the sorting functions. Other parameters, such as the recursive
76// call depth, are appended to the end of the parameter list as
77// "trailing parameters".
79 OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80 StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81 FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82 SmallString<32> nameBuffer;
83 llvm::raw_svector_ostream nameOstream(nameBuffer);
84 getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85 operands.drop_back(nTrailingP));
86
87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88 MLIRContext *context = module.getContext();
89 auto result = SymbolRefAttr::get(context, nameOstream.str());
90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91
92 if (!func) {
93 // Create the function.
94 OpBuilder::InsertionGuard insertionGuard(builder);
95 builder.setInsertionPoint(insertPoint);
96 Location loc = insertPoint.getLoc();
97 func = func::FuncOp::create(
98 builder, loc, nameOstream.str(),
99 FunctionType::get(context, operands.getTypes(), resultTypes));
100 func.setPrivate();
101 createFunc(builder, module, func, xPerm, ny, nTrailingP);
102 }
103
104 return result;
105}
106
107/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108/// The code to process the value pairs is generated by `bodyBuilder`.
110 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111 uint64_t ny,
112 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113 Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
114 Value iOffset = arith::MulIOp::create(builder, loc, args[0], cstep);
115 Value jOffset = arith::MulIOp::create(builder, loc, args[1], cstep);
116 for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117 unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
118 Value ak = constantIndex(builder, loc, actualK);
119 Value i = arith::AddIOp::create(builder, loc, ak, iOffset);
120 Value j = arith::AddIOp::create(builder, loc, ak, jOffset);
121 Value buffer = args[xStartIdx];
122
123 bodyBuilder(k, i, j, buffer);
124 }
125}
126
127/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128/// The code to process the value pairs is generated by `bodyBuilder`.
130 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131 uint64_t ny,
132 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133
134 // Create code for the first (xPerm + ny) buffers.
136 for (unsigned y = 0; y < ny; y++) {
137 exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
138 }
139 AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
140 assert(xyPerm.isPermutation());
141
142 forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
143
144 constexpr uint64_t numHandledBuffers = 1;
145 // Create code for the remaining buffers.
146 Value i = args[0];
147 Value j = args[1];
148 for (const auto &arg :
149 llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
150 bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
151 }
152}
153
154/// Creates a code block for swapping the values in index i and j for all the
155/// buffers.
156//
157// The generated IR corresponds to this C like algorithm:
158// swap(x0[i], x0[j]);
159// swap(x1[i], x1[j]);
160// ...
161// swap(xn[i], xn[j]);
162// swap(y0[i], y0[j]);
163// ...
164// swap(yn[i], yn[j]);
165static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
166 AffineMap xPerm, uint64_t ny) {
167 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
168 Value vi = memref::LoadOp::create(builder, loc, buffer, i);
169 Value vj = memref::LoadOp::create(builder, loc, buffer, j);
170 memref::StoreOp::create(builder, loc, vj, buffer, i);
171 memref::StoreOp::create(builder, loc, vi, buffer, j);
172 };
173
174 forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
175}
176
177/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
178/// each pair is created via `compareBuilder`.
180 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
181 uint64_t ny,
183 compareBuilder) {
185 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
186 bool isFirstDim = (k == 0);
187 bool isLastDim = (k == xPerm.getNumResults() - 1);
188 Value val =
189 compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
190 if (isFirstDim) {
191 result = val;
192 } else if (!isLastDim) {
193 OpBuilder::InsertionGuard insertionGuard(builder);
194 auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
195 builder.setInsertionPointAfter(ifOp);
196 scf::YieldOp::create(builder, loc, ifOp.getResult(0));
197 }
198 };
199
200 forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
201
203 return result;
204}
205
206/// Generates code to compare whether x[i] is equal to x[j] and returns the
207/// result of the comparison.
209 Value x, bool isFirstDim, bool isLastDim) {
210 Value vi = memref::LoadOp::create(builder, loc, x, i);
211 Value vj = memref::LoadOp::create(builder, loc, x, j);
212
213 Value res;
214 if (isLastDim) {
215 res = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, vi, vj);
216 // For 1D, we create a compare without any control flow. Otherwise, we
217 // create YieldOp to return the result in the nested if-stmt.
218 if (!isFirstDim)
219 scf::YieldOp::create(builder, loc, res);
220 } else {
221 Value ne =
222 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj);
223 scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIntegerType(1),
224 ne, /*else=*/true);
225 // If (x[i] != x[j]).
226 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
227 Value f = constantI1(builder, loc, false);
228 scf::YieldOp::create(builder, loc, f);
229
230 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
231 // checks the remaining dimensions.
232 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
233 res = ifOp.getResult(0);
234 }
235
236 return res;
237}
238
239/// Creates code to compare whether xs[i] is equal to xs[j].
240//
241// The generate IR corresponds to this C like algorithm:
242// if (x0[i] != x0[j])
243// return false;
244// else
245// if (x1[i] != x1[j])
246// return false;
247// else if (x2[i] != x2[j])
248// and so on ...
250 ValueRange args, AffineMap xPerm,
251 uint64_t ny, uint32_t nTrailingP = 0) {
252 // Compare functions don't use trailing parameters.
253 (void)nTrailingP;
254 assert(nTrailingP == 0);
255 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
257}
258
259/// Generates code to compare whether x[i] is less than x[j] and returns the
260/// result of the comparison.
262 Value j, Value x, bool isFirstDim,
263 bool isLastDim) {
264 Value vi = memref::LoadOp::create(builder, loc, x, i);
265 Value vj = memref::LoadOp::create(builder, loc, x, j);
266
267 Value res;
268 if (isLastDim) {
269 res =
270 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj);
271 // For 1D, we create a compare without any control flow. Otherwise, we
272 // create YieldOp to return the result in the nested if-stmt.
273 if (!isFirstDim)
274 scf::YieldOp::create(builder, loc, res);
275 } else {
276 Value ne =
277 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, vi, vj);
278 scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.getIntegerType(1),
279 ne, /*else=*/true);
280 // If (x[i] != x[j]).
281 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
282 Value lt =
283 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, vi, vj);
284 scf::YieldOp::create(builder, loc, lt);
285
286 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
287 // checks the remaining dimensions.
288 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
289 res = ifOp.getResult(0);
290 }
291
292 return res;
293}
294
295/// Creates code to compare whether xs[i] is less than xs[j].
296//
297// The generate IR corresponds to this C like algorithm:
298// if (x0[i] != x0[j])
299// return x0[i] < x0[j];
300// else if (x1[j] != x1[i])
301// return x1[i] < x1[j];
302// else
303// and so on ...
305 ValueRange args, AffineMap xPerm,
306 uint64_t ny, uint32_t nTrailingP = 0) {
307 // Compare functions don't use trailing parameters.
308 (void)nTrailingP;
309 assert(nTrailingP == 0);
310 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
312}
313
314/// Creates a function to use a binary search to find the insertion point for
315/// inserting xs[hi] to the sorted values xs[lo..hi).
316//
317// The generate IR corresponds to this C like algorithm:
318// p = hi
319// while (lo < hi)
320// mid = (lo + hi) >> 1
321// if (xs[p] < xs[mid])
322// hi = mid
323// else
324// lo = mid - 1
325// return lo;
326//
327static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
328 func::FuncOp func, AffineMap xPerm,
329 uint64_t ny, uint32_t nTrailingP = 0) {
330 // Binary search doesn't use trailing parameters.
331 (void)nTrailingP;
332 assert(nTrailingP == 0);
333 OpBuilder::InsertionGuard insertionGuard(builder);
334 Block *entryBlock = func.addEntryBlock();
335 builder.setInsertionPointToStart(entryBlock);
336
337 Location loc = func.getLoc();
338 ValueRange args = entryBlock->getArguments();
339 Value p = args[hiIdx];
340 SmallVector<Type, 2> types(2, p.getType()); // Only two types.
341 scf::WhileOp whileOp = scf::WhileOp::create(
342 builder, loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
343
344 // The before-region of the WhileOp.
345 Block *before =
346 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
347 builder.setInsertionPointToEnd(before);
348 Value cond1 =
349 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
350 before->getArgument(0), before->getArgument(1));
351 scf::ConditionOp::create(builder, loc, cond1, before->getArguments());
352
353 // The after-region of the WhileOp.
354 Block *after =
355 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
356 builder.setInsertionPointToEnd(after);
357 Value lo = after->getArgument(0);
358 Value hi = after->getArgument(1);
359 // Compute mid = (lo + hi) >> 1.
360 Value c1 = constantIndex(builder, loc, 1);
361 Value mid = arith::ShRUIOp::create(
362 builder, loc, arith::AddIOp::create(builder, loc, lo, hi), c1);
363 Value midp1 = arith::AddIOp::create(builder, loc, mid, c1);
364
365 // Compare xs[p] < xs[mid].
366 SmallVector<Value> compareOperands{p, mid};
367 constexpr uint64_t numXBuffers = 1;
368 compareOperands.append(args.begin() + xStartIdx,
369 args.begin() + xStartIdx + numXBuffers);
370 Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
371 // Update lo and hi for the WhileOp as follows:
372 // if (xs[p] < xs[mid]))
373 // hi = mid;
374 // else
375 // lo = mid + 1;
376 Value newLo = arith::SelectOp::create(builder, loc, cond2, lo, midp1);
377 Value newHi = arith::SelectOp::create(builder, loc, cond2, mid, hi);
378 scf::YieldOp::create(builder, loc, ValueRange{newLo, newHi});
379
380 builder.setInsertionPointAfter(whileOp);
381 func::ReturnOp::create(builder, loc, whileOp.getResult(0));
382}
383
384/// Creates code to advance i in a loop based on xs[p] as follows:
385/// while (xs[i] < xs[p]) i += step (step > 0)
386/// or
387/// while (xs[i] > xs[p]) i += step (step < 0)
388/// The routine returns i as well as a boolean value to indicate whether
389/// xs[i] == xs[p].
390static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
391 ModuleOp module,
392 func::FuncOp func, ValueRange xs,
393 Value i, Value p, AffineMap xPerm,
394 uint64_t ny, int step) {
395 Location loc = func.getLoc();
396 scf::WhileOp whileOp =
397 scf::WhileOp::create(builder, loc, TypeRange{i.getType()}, ValueRange{i});
398
399 Block *before =
400 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
401 builder.setInsertionPointToEnd(before);
402 SmallVector<Value> compareOperands;
403 if (step > 0) {
404 compareOperands.push_back(before->getArgument(0));
405 compareOperands.push_back(p);
406 } else {
407 assert(step < 0);
408 compareOperands.push_back(p);
409 compareOperands.push_back(before->getArgument(0));
410 }
411 compareOperands.append(xs.begin(), xs.end());
412 Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
413 scf::ConditionOp::create(builder, loc, cond, before->getArguments());
414
415 Block *after =
416 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
417 builder.setInsertionPointToEnd(after);
418 Value cs = constantIndex(builder, loc, step);
419 i = arith::AddIOp::create(builder, loc, after->getArgument(0), cs);
420 scf::YieldOp::create(builder, loc, ValueRange{i});
421 i = whileOp.getResult(0);
422
423 builder.setInsertionPointAfter(whileOp);
424 compareOperands[0] = i;
425 compareOperands[1] = p;
426 Value compareEq =
427 createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
428
429 return std::make_pair(whileOp.getResult(0), compareEq);
430}
431
432/// Creates and returns an IfOp to compare two elements and swap the elements
433/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
434/// right after the swap instructions.
435static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
436 AffineMap xPerm, uint64_t ny,
437 SmallVectorImpl<Value> &swapOperands,
438 SmallVectorImpl<Value> &compareOperands,
439 Value a, Value b) {
440 // Compare(data[b], data[a]).
441 compareOperands[0] = b;
442 compareOperands[1] = a;
443 Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
444 scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond, /*else=*/false);
445 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
446 swapOperands[0] = b;
447 swapOperands[1] = a;
448 createSwap(builder, loc, swapOperands, xPerm, ny);
449 return ifOp;
450}
451
452/// Creates code to insert the 3rd element to a list of two sorted elements.
453static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
454 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
455 SmallVectorImpl<Value> &compareOperands, Value v0,
456 Value v1, Value v2) {
457 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
458 compareOperands, v1, v2);
459 createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
460 v0, v1);
461 builder.setInsertionPointAfter(ifOp);
462}
463
464/// Creates code to sort 3 elements.
465static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
466 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
467 SmallVectorImpl<Value> &compareOperands, Value v0,
468 Value v1, Value v2) {
469 // Sort the first 2 elements.
470 scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
471 compareOperands, v0, v1);
472 builder.setInsertionPointAfter(ifOp1);
473
474 // Insert the 3th element.
475 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
476 v1, v2);
477}
478
479/// Creates code to sort 5 elements.
480static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
481 uint64_t ny, SmallVectorImpl<Value> &swapOperands,
482 SmallVectorImpl<Value> &compareOperands, Value v0,
483 Value v1, Value v2, Value v3, Value v4) {
484 // Sort the first 3 elements.
485 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
486 v2);
487
488 auto insert4th = [&]() {
489 scf::IfOp ifOp = createCompareThenSwap(
490 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
491 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
492 v1, v2);
493 builder.setInsertionPointAfter(ifOp);
494 };
495
496 // Insert the 4th element.
497 insert4th();
498
499 // Insert the 5th element.
500 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
501 compareOperands, v3, v4);
502 insert4th();
503 builder.setInsertionPointAfter(ifOp);
504}
505
506/// Creates a code block to swap the values in indices lo, mi, and hi so that
507/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
508/// the number of values in range [lo, hi) is more than a threshold, we also
509/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
510static void createChoosePivot(OpBuilder &builder, ModuleOp module,
511 func::FuncOp func, AffineMap xPerm, uint64_t ny,
512 Value lo, Value hi, Value mi, ValueRange args) {
513 SmallVector<Value> compareOperands{mi, lo};
514 constexpr uint64_t numXBuffers = 1;
515 compareOperands.append(args.begin() + xStartIdx,
516 args.begin() + xStartIdx + numXBuffers);
517 SmallVector<Value> swapOperands{mi, lo};
518 swapOperands.append(args.begin() + xStartIdx, args.end());
519 Location loc = func.getLoc();
520 Value c1 = constantIndex(builder, loc, 1);
521 Value hiP1 = arith::AddIOp::create(builder, loc, hi, c1);
522 Value len = arith::SubIOp::create(builder, loc, hiP1, lo);
523 Value lenThreshold = constantIndex(builder, loc, 1000);
524 Value lenCond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
525 len, lenThreshold);
526 scf::IfOp lenIf = scf::IfOp::create(builder, loc, lenCond, /*else=*/true);
527
528 // When len < 1000, choose pivot from median of 3 values.
529 builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
530 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
531 hi);
532
533 // When len >= 1000, choose pivot from median of 5 values.
534 builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
535 Value miP1 = arith::AddIOp::create(builder, loc, hi, c1);
536 Value a = arith::AddIOp::create(builder, loc, lo, miP1);
537 // Value a is the middle between [loc, mi].
538 a = arith::ShRUIOp::create(builder, loc, a, c1);
539 Value b = arith::AddIOp::create(builder, loc, mi, hiP1);
540 // Value b is the middle between [mi, hi].
541 b = arith::ShRUIOp::create(builder, loc, b, c1);
542 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
543 b, hi);
544
545 builder.setInsertionPointAfter(lenIf);
546}
547
548/// Creates a function to perform quick sort partition on the values in the
549/// range of index [lo, hi), assuming lo < hi.
550//
551// The generated IR corresponds to this C like algorithm:
552// int partition(lo, hi, xs) {
553// p = (lo+hi)/2 // pivot index
554// i = lo
555// j = hi-1
556// while (true) do {
557// while (xs[i] < xs[p]) i ++;
558// i_eq = (xs[i] == xs[p]);
559// while (xs[j] > xs[p]) j --;
560// j_eq = (xs[j] == xs[p]);
561//
562// if (i >= j) return j + 1;
563//
564// if (i < j) {
565// swap(xs[i], xs[j])
566// if (i == p) {
567// p = j;
568// } else if (j == p) {
569// p = i;
570// }
571// if (i_eq && j_eq) {
572// ++i;
573// --j;
574// }
575// }
576// }
577// }
578static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
579 func::FuncOp func, AffineMap xPerm, uint64_t ny,
580 uint32_t nTrailingP = 0) {
581 // Quick sort partition doesn't use trailing parameters.
582 (void)nTrailingP;
583 assert(nTrailingP == 0);
584 OpBuilder::InsertionGuard insertionGuard(builder);
585
586 Block *entryBlock = func.addEntryBlock();
587 builder.setInsertionPointToStart(entryBlock);
588
589 Location loc = func.getLoc();
590 ValueRange args = entryBlock->getArguments();
591 Value lo = args[loIdx];
592 Value hi = args[hiIdx];
593 Value sum = arith::AddIOp::create(builder, loc, lo, hi);
594 Value c1 = constantIndex(builder, loc, 1);
595 Value p = arith::ShRUIOp::create(builder, loc, sum, c1);
596
597 Value i = lo;
598 Value j = arith::SubIOp::create(builder, loc, hi, c1);
599 createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
600 Value trueVal = constantI1(builder, loc, true); // The value for while (true)
601 SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
602 SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
603 trueVal.getType()};
604 scf::WhileOp whileOp = scf::WhileOp::create(builder, loc, types, operands);
605
606 // The before-region of the WhileOp.
607 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
608 {loc, loc, loc, loc});
609 builder.setInsertionPointToEnd(before);
610 scf::ConditionOp::create(builder, loc, before->getArgument(3),
611 before->getArguments());
612
613 // The after-region of the WhileOp.
614 Block *after =
615 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
616 builder.setInsertionPointToEnd(after);
617 i = after->getArgument(0);
618 j = after->getArgument(1);
619 p = after->getArgument(2);
620
621 constexpr uint64_t numXBuffers = 1;
622 auto [iresult, iCompareEq] =
623 createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
624 i, p, xPerm, ny, 1);
625 i = iresult;
626 auto [jresult, jCompareEq] =
627 createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
628 j, p, xPerm, ny, -1);
629 j = jresult;
630
631 // If i < j:
632 Value cond =
633 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, i, j);
634 scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
635 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
636 SmallVector<Value> swapOperands{i, j};
637 swapOperands.append(args.begin() + xStartIdx, args.end());
638 createSwap(builder, loc, swapOperands, xPerm, ny);
639 // If the pivot is moved, update p with the new pivot.
640 Value icond =
641 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, i, p);
642 scf::IfOp ifOpI = scf::IfOp::create(builder, loc, TypeRange{p.getType()},
643 icond, /*else=*/true);
644 builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
645 scf::YieldOp::create(builder, loc, ValueRange{j});
646 builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
647 Value jcond =
648 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, j, p);
649 scf::IfOp ifOpJ = scf::IfOp::create(builder, loc, TypeRange{p.getType()},
650 jcond, /*else=*/true);
651 builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
652 scf::YieldOp::create(builder, loc, ValueRange{i});
653 builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
654 scf::YieldOp::create(builder, loc, ValueRange{p});
655 builder.setInsertionPointAfter(ifOpJ);
656 scf::YieldOp::create(builder, loc, ifOpJ.getResults());
657 builder.setInsertionPointAfter(ifOpI);
658 Value compareEqIJ =
659 arith::AndIOp::create(builder, loc, iCompareEq, jCompareEq);
660 scf::IfOp ifOp2 =
661 scf::IfOp::create(builder, loc, TypeRange{i.getType(), j.getType()},
662 compareEqIJ, /*else=*/true);
663 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
664 Value i2 = arith::AddIOp::create(builder, loc, i, c1);
665 Value j2 = arith::SubIOp::create(builder, loc, j, c1);
666 scf::YieldOp::create(builder, loc, ValueRange{i2, j2});
667 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
668 scf::YieldOp::create(builder, loc, ValueRange{i, j});
669 builder.setInsertionPointAfter(ifOp2);
670 scf::YieldOp::create(builder, loc,
671 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1),
672 ifOpI.getResult(0),
673 /*cont=*/constantI1(builder, loc, true)});
674
675 // False branch for if i < j (i.e., i >= j):
676 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
677 p = arith::AddIOp::create(builder, loc, j,
678 constantOne(builder, loc, j.getType()));
679 scf::YieldOp::create(
680 builder, loc,
681 ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
682
683 // Return for the whileOp.
684 builder.setInsertionPointAfter(ifOp);
685 scf::YieldOp::create(builder, loc, ifOp.getResults());
686
687 // Return for the function.
688 builder.setInsertionPointAfter(whileOp);
689 func::ReturnOp::create(builder, loc, whileOp.getResult(2));
690}
691
692/// Computes (n-2)/2, assuming n has index type.
694 Value n) {
695 Value i2 = constantIndex(builder, loc, 2);
696 Value res = arith::SubIOp::create(builder, loc, n, i2);
697 Value i1 = constantIndex(builder, loc, 1);
698 return arith::ShRUIOp::create(builder, loc, res, i1);
699}
700
701/// Creates a function to heapify the subtree with root `start` within the full
702/// binary tree in the range of index [first, first + n).
703//
704// The generated IR corresponds to this C like algorithm:
705// void shiftDown(first, start, n, data) {
706// if (n >= 2) {
707// child = start - first
708// if ((n-2)/2 >= child) {
709// // Left child exists.
710// child = child * 2 + 1 // Initialize the bigger child to left child.
711// childIndex = child + first
712// if (child+1 < n && data[childIndex] < data[childIndex+1])
713// // Right child exits and is bigger.
714// childIndex++; child++;
715// // Shift data[start] down to where it belongs in the subtree.
716// while (data[start] < data[childIndex) {
717// swap(data[start], data[childIndex])
718// start = childIndex
719// if ((n - 2)/2 >= child) {
720// // Left child exists.
721// child = 2*child + 1
722// childIndex = child + 1
723// if (child + 1) < n && data[childIndex] < data[childIndex+1]
724// childIndex++; child++;
725// }
726// }
727// }
728// }
729// }
730//
731static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
732 func::FuncOp func, AffineMap xPerm, uint64_t ny,
733 uint32_t nTrailingP) {
734 // The value n is passed in as a trailing parameter.
735 assert(nTrailingP == 1);
736 OpBuilder::InsertionGuard insertionGuard(builder);
737 Block *entryBlock = func.addEntryBlock();
738 builder.setInsertionPointToStart(entryBlock);
739
740 Location loc = func.getLoc();
741 Value n = entryBlock->getArguments().back();
742 ValueRange args = entryBlock->getArguments().drop_back();
743 Value first = args[loIdx];
744 Value start = args[hiIdx];
745
746 // If (n >= 2).
747 Value c2 = constantIndex(builder, loc, 2);
748 Value condN =
749 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, n, c2);
750 scf::IfOp ifN = scf::IfOp::create(builder, loc, condN, /*else=*/false);
751 builder.setInsertionPointToStart(&ifN.getThenRegion().front());
752 Value child = arith::SubIOp::create(builder, loc, start, first);
753
754 // If ((n-2)/2 >= child).
755 Value t = createSubTwoDividedByTwo(builder, loc, n);
756 Value condNc =
757 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, t, child);
758 scf::IfOp ifNc = scf::IfOp::create(builder, loc, condNc, /*else=*/false);
759
760 builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
761 Value c1 = constantIndex(builder, loc, 1);
762 SmallVector<Value> compareOperands{start, start};
763 constexpr uint64_t numXBuffers = 1;
764 compareOperands.append(args.begin() + xStartIdx,
765 args.begin() + xStartIdx + numXBuffers);
766
767 // Generate code to inspect the children of 'r' and return the larger child
768 // as follows:
769 // child = r * 2 + 1 // Left child.
770 // childIndex = child + first
771 // if (child+1 < n && data[childIndex] < data[childIndex+1])
772 // childIndex ++; child ++ // Right child is bigger.
773 auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
774 Value lChild = arith::ShLIOp::create(builder, loc, r, c1);
775 lChild = arith::AddIOp::create(builder, loc, lChild, c1);
776 Value lChildIdx = arith::AddIOp::create(builder, loc, lChild, first);
777 Value rChild = arith::AddIOp::create(builder, loc, lChild, c1);
778 Value cond1 = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
779 rChild, n);
780 SmallVector<Type, 2> ifTypes(2, r.getType());
781 scf::IfOp if1 =
782 scf::IfOp::create(builder, loc, ifTypes, cond1, /*else=*/true);
783 builder.setInsertionPointToStart(&if1.getThenRegion().front());
784 Value rChildIdx = arith::AddIOp::create(builder, loc, rChild, first);
785 // Compare data[left] < data[right].
786 compareOperands[0] = lChildIdx;
787 compareOperands[1] = rChildIdx;
788 Value cond2 =
789 createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
790 scf::IfOp if2 =
791 scf::IfOp::create(builder, loc, ifTypes, cond2, /*else=*/true);
792 builder.setInsertionPointToStart(&if2.getThenRegion().front());
793 scf::YieldOp::create(builder, loc, ValueRange{rChild, rChildIdx});
794 builder.setInsertionPointToStart(&if2.getElseRegion().front());
795 scf::YieldOp::create(builder, loc, ValueRange{lChild, lChildIdx});
796 builder.setInsertionPointAfter(if2);
797 scf::YieldOp::create(builder, loc, if2.getResults());
798 builder.setInsertionPointToStart(&if1.getElseRegion().front());
799 scf::YieldOp::create(builder, loc, ValueRange{lChild, lChildIdx});
800 builder.setInsertionPointAfter(if1);
801 return std::make_pair(if1.getResult(0), if1.getResult(1));
802 };
803
804 Value childIdx;
805 std::tie(child, childIdx) = getLargerChild(child);
806
807 // While (data[start] < data[childIndex]).
808 SmallVector<Type, 3> types(3, child.getType());
809 scf::WhileOp whileOp = scf::WhileOp::create(
810 builder, loc, types, SmallVector<Value, 2>{start, child, childIdx});
811
812 // The before-region of the WhileOp.
813 SmallVector<Location, 3> locs(3, loc);
814 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
815 builder.setInsertionPointToEnd(before);
816 start = before->getArgument(0);
817 childIdx = before->getArgument(2);
818 compareOperands[0] = start;
819 compareOperands[1] = childIdx;
820 Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
821 scf::ConditionOp::create(builder, loc, cond, before->getArguments());
822
823 // The after-region of the WhileOp.
824 Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
825 start = after->getArgument(0);
826 child = after->getArgument(1);
827 childIdx = after->getArgument(2);
828 SmallVector<Value> swapOperands{start, childIdx};
829 swapOperands.append(args.begin() + xStartIdx, args.end());
830 createSwap(builder, loc, swapOperands, xPerm, ny);
831 start = childIdx;
832 Value cond2 =
833 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::uge, t, child);
834 scf::IfOp if2 = scf::IfOp::create(builder, loc,
835 TypeRange{child.getType(), child.getType()},
836 cond2, /*else=*/true);
837 builder.setInsertionPointToStart(&if2.getThenRegion().front());
838 auto [newChild, newChildIdx] = getLargerChild(child);
839 scf::YieldOp::create(builder, loc, ValueRange{newChild, newChildIdx});
840 builder.setInsertionPointToStart(&if2.getElseRegion().front());
841 scf::YieldOp::create(builder, loc, ValueRange{child, childIdx});
842 builder.setInsertionPointAfter(if2);
843 scf::YieldOp::create(builder, loc,
844 ValueRange{start, if2.getResult(0), if2.getResult(1)});
845
846 builder.setInsertionPointAfter(ifN);
847 func::ReturnOp::create(builder, loc);
848}
849
850/// Creates a function to perform heap sort on the values in the range of index
851/// [lo, hi) with the assumption hi - lo >= 2.
852//
853// The generate IR corresponds to this C like algorithm:
854// void heapSort(lo, hi, data) {
855// n = hi - lo
856// for i = (n-2)/2 downto 0
857// shiftDown(lo, lo+i, n)
858//
859// for l = n downto 2
860// swap(lo, lo+l-1)
861// shiftdown(lo, lo, l-1)
862// }
863static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
864 func::FuncOp func, AffineMap xPerm, uint64_t ny,
865 uint32_t nTrailingP) {
866 // Heap sort function doesn't have trailing parameters.
867 (void)nTrailingP;
868 assert(nTrailingP == 0);
869 OpBuilder::InsertionGuard insertionGuard(builder);
870 Block *entryBlock = func.addEntryBlock();
871 builder.setInsertionPointToStart(entryBlock);
872
873 Location loc = func.getLoc();
874 ValueRange args = entryBlock->getArguments();
875 Value lo = args[loIdx];
876 Value hi = args[hiIdx];
877 Value n = arith::SubIOp::create(builder, loc, hi, lo);
878
879 // For i = (n-2)/2 downto 0.
880 Value c0 = constantIndex(builder, loc, 0);
881 Value c1 = constantIndex(builder, loc, 1);
882 Value s = createSubTwoDividedByTwo(builder, loc, n);
883 Value up = arith::AddIOp::create(builder, loc, s, c1);
884 scf::ForOp forI = scf::ForOp::create(builder, loc, c0, up, c1);
885 builder.setInsertionPointToStart(forI.getBody());
886 Value i = arith::SubIOp::create(builder, loc, s, forI.getInductionVar());
887 Value lopi = arith::AddIOp::create(builder, loc, lo, i);
888 SmallVector<Value> shiftDownOperands = {lo, lopi};
889 shiftDownOperands.append(args.begin() + xStartIdx, args.end());
890 shiftDownOperands.push_back(n);
892 builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
893 shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
894 func::CallOp::create(builder, loc, shiftDownFunc, TypeRange(),
895 shiftDownOperands);
896
897 builder.setInsertionPointAfter(forI);
898 // For l = n downto 2.
899 up = arith::SubIOp::create(builder, loc, n, c1);
900 scf::ForOp forL = scf::ForOp::create(builder, loc, c0, up, c1);
901 builder.setInsertionPointToStart(forL.getBody());
902 Value l = arith::SubIOp::create(builder, loc, n, forL.getInductionVar());
903 Value loplm1 = arith::AddIOp::create(builder, loc, lo, l);
904 loplm1 = arith::SubIOp::create(builder, loc, loplm1, c1);
905 SmallVector<Value> swapOperands{lo, loplm1};
906 swapOperands.append(args.begin() + xStartIdx, args.end());
907 createSwap(builder, loc, swapOperands, xPerm, ny);
908 shiftDownOperands[1] = lo;
909 shiftDownOperands[shiftDownOperands.size() - 1] =
910 arith::SubIOp::create(builder, loc, l, c1);
911 func::CallOp::create(builder, loc, shiftDownFunc, TypeRange(),
912 shiftDownOperands);
913
914 builder.setInsertionPointAfter(forL);
915 func::ReturnOp::create(builder, loc);
916}
917
918/// A helper for generating code to perform quick sort. It partitions [lo, hi),
919/// recursively calls quick sort to process the smaller partition and returns
920/// the bigger partition to be processed by the enclosed while-loop.
921static std::pair<Value, Value>
922createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
923 ValueRange args, AffineMap xPerm, uint64_t ny,
924 uint32_t nTrailingP) {
925 MLIRContext *context = module.getContext();
926 Location loc = func.getLoc();
927 Value lo = args[loIdx];
928 Value hi = args[hiIdx];
929 SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
930
932 builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
933 ny, args.drop_back(nTrailingP), createPartitionFunc);
934 Value p = func::CallOp::create(builder, loc, partitionFunc,
935 TypeRange{IndexType::get(context)},
936 args.drop_back(nTrailingP))
937 .getResult(0);
938
939 Value lenLow = arith::SubIOp::create(builder, loc, p, lo);
940 Value lenHigh = arith::SubIOp::create(builder, loc, hi, p);
941 // Partition already sorts array with len <= 2
942 Value c2 = constantIndex(builder, loc, 2);
943 Value len = arith::SubIOp::create(builder, loc, hi, lo);
944 Value lenGtTwo =
945 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ugt, len, c2);
946 scf::IfOp ifLenGtTwo =
947 scf::IfOp::create(builder, loc, types, lenGtTwo, /*else=*/true);
948 builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
949 // Returns an empty range to mark the entire region is fully sorted.
950 scf::YieldOp::create(builder, loc, ValueRange{lo, lo});
951
952 // Else len > 2, need recursion.
953 builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
954 Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule,
955 lenLow, lenHigh);
956
957 Value c0 = constantIndex(builder, loc, 0);
958 scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond, /*else=*/true);
959
960 auto mayRecursion = [&](Value low, Value high, Value len) {
961 Value cond =
962 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, len, c0);
963 scf::IfOp ifOp = scf::IfOp::create(builder, loc, cond, /*else=*/false);
964 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
965 SmallVector<Value> operands{low, high};
966 operands.append(args.begin() + xStartIdx, args.end());
967 func::CallOp::create(builder, loc, func, operands);
968 builder.setInsertionPointAfter(ifOp);
969 };
970
971 // Recursively call quickSort to process the smaller partition and return
972 // the bigger partition to be processed by the enclosed while-loop.
973 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
974 mayRecursion(lo, p, lenLow);
975 scf::YieldOp::create(builder, loc, ValueRange{p, hi});
976
977 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
978 mayRecursion(p, hi, lenHigh);
979 scf::YieldOp::create(builder, loc, ValueRange{lo, p});
980
981 builder.setInsertionPointAfter(ifOp);
982 scf::YieldOp::create(builder, loc, ifOp.getResults());
983
984 builder.setInsertionPointAfter(ifLenGtTwo);
985 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
986}
987
988/// Creates a function to perform insertion sort on the values in the range of
989/// index [lo, hi).
990//
991// The generate IR corresponds to this C like algorithm:
992// void insertionSort(lo, hi, data) {
993// for (i = lo+1; i < hi; i++) {
994// d = data[i];
995// p = binarySearch(lo, i-1, data)
996// for (j = 0; j > i - p; j++)
997// data[i-j] = data[i-j-1]
998// data[p] = d
999// }
1000// }
1001static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
1002 func::FuncOp func, AffineMap xPerm,
1003 uint64_t ny, uint32_t nTrailingP) {
1004 // Stable sort function doesn't use trailing parameters.
1005 (void)nTrailingP;
1006 assert(nTrailingP == 0);
1007 OpBuilder::InsertionGuard insertionGuard(builder);
1008 Block *entryBlock = func.addEntryBlock();
1009 builder.setInsertionPointToStart(entryBlock);
1010
1011 MLIRContext *context = module.getContext();
1012 Location loc = func.getLoc();
1013 ValueRange args = entryBlock->getArguments();
1014 Value c1 = constantIndex(builder, loc, 1);
1015 Value lo = args[loIdx];
1016 Value hi = args[hiIdx];
1017 Value lop1 = arith::AddIOp::create(builder, loc, lo, c1);
1018
1019 // Start the outer for-stmt with induction variable i.
1020 scf::ForOp forOpI = scf::ForOp::create(builder, loc, lop1, hi, c1);
1021 builder.setInsertionPointToStart(forOpI.getBody());
1022 Value i = forOpI.getInductionVar();
1023
1024 // Binary search to find the insertion point p.
1025 SmallVector<Value> operands{lo, i};
1026 operands.append(args.begin() + xStartIdx, args.end());
1028 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1029 xPerm, ny, operands, createBinarySearchFunc);
1030 Value p = func::CallOp::create(builder, loc, searchFunc,
1031 TypeRange{c1.getType()}, operands)
1032 .getResult(0);
1033
1034 // Move the value at data[i] to a temporary location.
1035 operands[0] = operands[1] = i;
1038 builder, loc, operands, xPerm, ny,
1039 [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1040 d.push_back(memref::LoadOp::create(builder, loc, buffer, i));
1041 });
1042
1043 // Start the inner for-stmt with induction variable j, for moving data[p..i)
1044 // to data[p+1..i+1).
1045 Value imp = arith::SubIOp::create(builder, loc, i, p);
1046 Value c0 = constantIndex(builder, loc, 0);
1047 scf::ForOp forOpJ = scf::ForOp::create(builder, loc, c0, imp, c1);
1048 builder.setInsertionPointToStart(forOpJ.getBody());
1049 Value j = forOpJ.getInductionVar();
1050 Value imj = arith::SubIOp::create(builder, loc, i, j);
1051 operands[1] = imj;
1052 operands[0] = arith::SubIOp::create(builder, loc, imj, c1);
1054 builder, loc, operands, xPerm, ny,
1055 [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1056 Value t = memref::LoadOp::create(builder, loc, buffer, imjm1);
1057 memref::StoreOp::create(builder, loc, t, buffer, imj);
1058 });
1059
1060 // Store the value at data[i] to data[p].
1061 builder.setInsertionPointAfter(forOpJ);
1062 operands[0] = operands[1] = p;
1064 builder, loc, operands, xPerm, ny,
1065 [&](uint64_t k, Value p, Value usused, Value buffer) {
1066 memref::StoreOp::create(builder, loc, d[k], buffer, p);
1067 });
1068
1069 builder.setInsertionPointAfter(forOpI);
1070 func::ReturnOp::create(builder, loc);
1071}
1072
1073/// Creates a function to perform quick sort or a hybrid quick sort on the
1074/// values in the range of index [lo, hi).
1075//
1076//
1077// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1078// void quickSort(lo, hi, data) {
1079// while (lo + 1 < hi) {
1080// p = partition(low, high, data);
1081// if (len(lo, p) < len(p+1, hi)) {
1082// quickSort(lo, p, data);
1083// lo = p+1;
1084// } else {
1085// quickSort(p + 1, hi, data);
1086// hi = p;
1087// }
1088// }
1089// }
1090//
1091// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1092// void hybridQuickSort(lo, hi, data, depthLimit) {
1093// while (lo + 1 < hi) {
1094// len = hi - lo;
1095// if (len <= limit) {
1096// insertionSort(lo, hi, data);
1097// } else {
1098// depthLimit --;
1099// if (depthLimit <= 0) {
1100// heapSort(lo, hi, data);
1101// } else {
1102// p = partition(low, high, data);
1103// if (len(lo, p) < len(p+1, hi)) {
1104// quickSort(lo, p, data, depthLimit);
1105// lo = p+1;
1106// } else {
1107// quickSort(p + 1, hi, data, depthLimit);
1108// hi = p;
1109// }
1110// }
1111// }
1112// }
1113// }
1114//
1115static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1116 func::FuncOp func, AffineMap xPerm, uint64_t ny,
1117 uint32_t nTrailingP) {
1118 assert(nTrailingP == 1 || nTrailingP == 0);
1119 bool isHybrid = (nTrailingP == 1);
1120 OpBuilder::InsertionGuard insertionGuard(builder);
1121 Block *entryBlock = func.addEntryBlock();
1122 builder.setInsertionPointToStart(entryBlock);
1123
1124 Location loc = func.getLoc();
1125 SmallVector<Value> args;
1126 args.append(entryBlock->getArguments().begin(),
1127 entryBlock->getArguments().end());
1128 Value lo = args[loIdx];
1129 Value hi = args[hiIdx];
1130 SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1131 scf::WhileOp whileOp =
1132 scf::WhileOp::create(builder, loc, types, SmallVector<Value, 2>{lo, hi});
1133
1134 // The before-region of the WhileOp.
1135 Block *before =
1136 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1137 builder.setInsertionPointToEnd(before);
1138 lo = before->getArgument(0);
1139 hi = before->getArgument(1);
1140 Value loP1 =
1141 arith::AddIOp::create(builder, loc, lo, constantIndex(builder, loc, 1));
1142 Value needSort =
1143 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult, loP1, hi);
1144 scf::ConditionOp::create(builder, loc, needSort, before->getArguments());
1145
1146 // The after-region of the WhileOp.
1147 Block *after =
1148 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1149 builder.setInsertionPointToEnd(after);
1150 lo = after->getArgument(0);
1151 hi = after->getArgument(1);
1152 args[0] = lo;
1153 args[1] = hi;
1154
1155 if (isHybrid) {
1156 Value len = arith::SubIOp::create(builder, loc, hi, lo);
1157 Value lenLimit = constantIndex(builder, loc, 30);
1158 Value lenCond = arith::CmpIOp::create(
1159 builder, loc, arith::CmpIPredicate::ule, len, lenLimit);
1160 scf::IfOp lenIf =
1161 scf::IfOp::create(builder, loc, types, lenCond, /*else=*/true);
1162
1163 // When len <= limit.
1164 builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1165 FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1166 builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1167 ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1168 func::CallOp::create(builder, loc, insertionSortFunc, TypeRange(),
1169 ValueRange(args).drop_back(nTrailingP));
1170 scf::YieldOp::create(builder, loc, ValueRange{lo, lo});
1171
1172 // When len > limit.
1173 builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1174 Value depthLimit = args.back();
1175 depthLimit = arith::SubIOp::create(builder, loc, depthLimit,
1176 constantI64(builder, loc, 1));
1177 Value depthCond =
1178 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ule,
1179 depthLimit, constantI64(builder, loc, 0));
1180 scf::IfOp depthIf =
1181 scf::IfOp::create(builder, loc, types, depthCond, /*else=*/true);
1182
1183 // When depth exceeds limit.
1184 builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1186 builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1187 ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1188 func::CallOp::create(builder, loc, heapSortFunc, TypeRange(),
1189 ValueRange(args).drop_back(nTrailingP));
1190 scf::YieldOp::create(builder, loc, ValueRange{lo, lo});
1191
1192 // When depth doesn't exceed limit.
1193 builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1194 args.back() = depthLimit;
1195 std::tie(lo, hi) =
1196 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1197 scf::YieldOp::create(builder, loc, ValueRange{lo, hi});
1198
1199 builder.setInsertionPointAfter(depthIf);
1200 lo = depthIf.getResult(0);
1201 hi = depthIf.getResult(1);
1202 scf::YieldOp::create(builder, loc, ValueRange{lo, hi});
1203
1204 builder.setInsertionPointAfter(lenIf);
1205 lo = lenIf.getResult(0);
1206 hi = lenIf.getResult(1);
1207 } else {
1208 std::tie(lo, hi) =
1209 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1210 }
1211
1212 // New [lo, hi) for the next while-loop iteration.
1213 scf::YieldOp::create(builder, loc, ValueRange{lo, hi});
1214
1215 // After the while-loop.
1216 builder.setInsertionPointAfter(whileOp);
1217 func::ReturnOp::create(builder, loc);
1218}
1219
1220/// Implements the rewriting for operator sort and sort_coo.
1221template <typename OpTy>
1222static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys,
1223 AffineMap xPerm, uint64_t ny,
1224 PatternRewriter &rewriter) {
1225 Location loc = op.getLoc();
1226 SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1227
1228 // Convert `values` to have dynamic shape and append them to `operands`.
1229 for (Value v : xys) {
1230 auto mtp = getMemRefType(v);
1231 if (!mtp.isDynamicDim(0)) {
1232 auto newMtp =
1233 MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1234 v = memref::CastOp::create(rewriter, loc, newMtp, v);
1235 }
1236 operands.push_back(v);
1237 }
1238
1239 auto insertPoint = op->template getParentOfType<func::FuncOp>();
1240 if (!insertPoint)
1241 return failure();
1242
1243 SmallString<32> funcName;
1244 FuncGeneratorType funcGenerator;
1245 uint32_t nTrailingP = 0;
1246 switch (op.getAlgorithm()) {
1247 case SparseTensorSortKind::HybridQuickSort: {
1249 funcGenerator = createQuickSortFunc;
1250 nTrailingP = 1;
1251 // As a heuristics, set depthLimit = 2 * log2(n).
1252 Value lo = operands[loIdx];
1253 Value hi = operands[hiIdx];
1254 Value len = arith::IndexCastOp::create(
1255 rewriter, loc, rewriter.getI64Type(),
1256 arith::SubIOp::create(rewriter, loc, hi, lo));
1257 Value depthLimit = arith::SubIOp::create(
1258 rewriter, loc, constantI64(rewriter, loc, 64),
1259 math::CountLeadingZerosOp::create(rewriter, loc, len));
1260 operands.push_back(depthLimit);
1261 break;
1262 }
1263 case SparseTensorSortKind::QuickSort:
1264 funcName = kQuickSortFuncNamePrefix;
1265 funcGenerator = createQuickSortFunc;
1266 break;
1267 case SparseTensorSortKind::InsertionSortStable:
1268 funcName = kSortStableFuncNamePrefix;
1269 funcGenerator = createSortStableFunc;
1270 break;
1271 case SparseTensorSortKind::HeapSort:
1272 funcName = kHeapSortFuncNamePrefix;
1273 funcGenerator = createHeapSortFunc;
1274 break;
1275 }
1276
1278 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1279 xPerm, ny, operands, funcGenerator, nTrailingP);
1280 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1281 return success();
1282}
1283
1284//===---------------------------------------------------------------------===//
1285// The actual sparse buffer rewriting rules.
1286//===---------------------------------------------------------------------===//
1287
1288namespace {
1289/// Sparse rewriting rule for the push_back operator.
1290struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1291public:
1292 using OpRewritePattern<PushBackOp>::OpRewritePattern;
1293 PushBackRewriter(MLIRContext *context, bool enableInit)
1294 : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1295 LogicalResult matchAndRewrite(PushBackOp op,
1296 PatternRewriter &rewriter) const override {
1297 // Rewrite push_back(buffer, value, n) to:
1298 // new_size = size(buffer) + n
1299 // if (new_size > capacity(buffer))
1300 // while new_size > new_capacity
1301 // new_capacity = new_capacity*2
1302 // new_buffer = realloc(buffer, new_capacity)
1303 // buffer = new_buffer
1304 // subBuffer = subviewof(buffer)
1305 // linalg.fill subBuffer value
1306 //
1307 // size(buffer) += n
1308 //
1309 // The capacity check is skipped when the attribute inbounds is presented.
1310 Location loc = op->getLoc();
1311 Value c0 = constantIndex(rewriter, loc, 0);
1312 Value buffer = op.getInBuffer();
1313 Value capacity = memref::DimOp::create(rewriter, loc, buffer, c0);
1314 Value size = op.getCurSize();
1315 Value value = op.getValue();
1316
1317 Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1318 Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
1319 auto nValue = n.getDefiningOp<arith::ConstantIndexOp>();
1320 bool nIsOne = (nValue && nValue.value() == 1);
1321
1322 if (!op.getInbounds()) {
1323 Value cond = arith::CmpIOp::create(
1324 rewriter, loc, arith::CmpIPredicate::ugt, newSize, capacity);
1325
1326 Value c2 = constantIndex(rewriter, loc, 2);
1327 auto bufferType =
1328 MemRefType::get({ShapedType::kDynamic}, value.getType());
1329 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, bufferType, cond,
1330 /*else=*/true);
1331 // True branch.
1332 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1333 if (nIsOne) {
1334 capacity = arith::MulIOp::create(rewriter, loc, capacity, c2);
1335 } else {
1336 // Use a do-while loop to calculate the new capacity as follows:
1337 // do { new_capacity *= 2 } while (size > new_capacity)
1338 scf::WhileOp whileOp =
1339 scf::WhileOp::create(rewriter, loc, capacity.getType(), capacity);
1340
1341 // The before-region of the WhileOp.
1342 Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1343 {capacity.getType()}, {loc});
1344 rewriter.setInsertionPointToEnd(before);
1345
1346 capacity =
1347 arith::MulIOp::create(rewriter, loc, before->getArgument(0), c2);
1348 cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ugt,
1349 newSize, capacity);
1350 scf::ConditionOp::create(rewriter, loc, cond, ValueRange{capacity});
1351 // The after-region of the WhileOp.
1352 Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1353 {capacity.getType()}, {loc});
1354 rewriter.setInsertionPointToEnd(after);
1355 scf::YieldOp::create(rewriter, loc, after->getArguments());
1356
1357 rewriter.setInsertionPointAfter(whileOp);
1358 capacity = whileOp.getResult(0);
1359 }
1360
1361 Value newBuffer = memref::ReallocOp::create(rewriter, loc, bufferType,
1362 buffer, capacity);
1363 if (enableBufferInitialization) {
1364 Value fillSize =
1365 arith::SubIOp::create(rewriter, loc, capacity, newSize);
1366 Value fillValue = constantZero(rewriter, loc, value.getType());
1367 Value subBuffer = memref::SubViewOp::create(
1368 rewriter, loc, newBuffer, /*offsets=*/ValueRange{newSize},
1369 /*sizes=*/ValueRange{fillSize},
1370 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1371 linalg::FillOp::create(rewriter, loc, fillValue, subBuffer);
1372 }
1373 scf::YieldOp::create(rewriter, loc, newBuffer);
1374
1375 // False branch.
1376 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1377 scf::YieldOp::create(rewriter, loc, buffer);
1378
1379 // Prepare for adding the value to the end of the buffer.
1380 rewriter.setInsertionPointAfter(ifOp);
1381 buffer = ifOp.getResult(0);
1382 }
1383
1384 // Add the value to the end of the buffer.
1385 if (nIsOne) {
1386 memref::StoreOp::create(rewriter, loc, value, buffer, size);
1387 } else {
1388 Value subBuffer = memref::SubViewOp::create(
1389 rewriter, loc, buffer, /*offsets=*/ValueRange{size},
1390 /*sizes=*/ValueRange{n},
1391 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1392 linalg::FillOp::create(rewriter, loc, value, subBuffer);
1393 }
1394
1395 // Update the buffer size.
1396 rewriter.replaceOp(op, {buffer, newSize});
1397 return success();
1398 }
1399
1400private:
1401 bool enableBufferInitialization;
1402};
1403
1404/// Sparse rewriting rule for the sort_coo operator.
1405struct SortRewriter : public OpRewritePattern<SortOp> {
1406public:
1407 using OpRewritePattern<SortOp>::OpRewritePattern;
1408
1409 LogicalResult matchAndRewrite(SortOp op,
1410 PatternRewriter &rewriter) const override {
1411 SmallVector<Value> xys;
1412 xys.push_back(op.getXy());
1413 xys.append(op.getYs().begin(), op.getYs().end());
1414
1415 auto xPerm = op.getPermMap();
1416 uint64_t ny = 0;
1417 if (auto nyAttr = op.getNyAttr())
1418 ny = nyAttr.getInt();
1419
1420 return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1421 }
1422};
1423
1424} // namespace
1425
1426//===---------------------------------------------------------------------===//
1427// Methods that add patterns described in this file to a pattern list.
1428//===---------------------------------------------------------------------===//
1429
1431 bool enableBufferInitialization) {
1432 patterns.add<PushBackRewriter>(patterns.getContext(),
1433 enableBufferInitialization);
1434 patterns.add<SortRewriter>(patterns.getContext());
1435}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumpt...
static void forEachIJPairInXs(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static Value createInlinedCompareImplementation(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
Creates code to compare all the (xs[i], xs[j]) pairs.
static constexpr const char kQuickSortFuncNamePrefix[]
static constexpr uint64_t hiIdx
static constexpr const char kHeapSortFuncNamePrefix[]
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static constexpr const char kHybridQuickSortFuncNamePrefix[]
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to insert the 3rd element to a list of two sorted elements.
static constexpr const char kSortStableFuncNamePrefix[]
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
Looks up a function that is appropriate for the given operands being sorted, and creates such a funct...
static constexpr uint64_t loIdx
static void forEachIJPairInAllBuffers(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is less than xs[j].
static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to perform quick sort partition on the values in the range of index [lo,...
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
Creates code to sort 5 elements.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n)
Computes (n-2)/2, assuming n has index type.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates code to compare whether xs[i] is equal to xs[j].
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y...
static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo],...
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [...
static std::pair< Value, Value > createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
A helper for generating code to perform quick sort.
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
Creates code to sort 3 elements.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sor...
static constexpr const char kBinarySearchFuncNamePrefix[]
static constexpr const char kPartitionFuncNamePrefix[]
static constexpr uint64_t xStartIdx
static constexpr const char kShiftDownFuncNamePrefix[]
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to heapify the subtree with root start within the full binary tree in the range of...
static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
Creates a function to perform insertion sort on the values in the range of index [lo,...
static std::pair< Value, Value > createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step ...
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b],...
function_ref< void(OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, uint32_t)> FuncGeneratorType
static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
Creates a code block for swapping the values in index i and j for all the buffers.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
BlockArgListType getArguments()
Definition Block.h:87
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
MLIRContext * getContext() const
Definition Builders.h:56
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Value constantI64(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of i64 type.
Include the generated interface declarations.
void populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization)
const FrozenRewritePatternSet & patterns
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.