MLIR 22.0.0git
SparseBufferRewriting.cpp File Reference

Go to the source code of this file.

Typedefs

using FuncGeneratorType

Functions

static void getMangledSortHelperFuncName (llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands)
 Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort, and <namePrefix><xPerm>_<x type>coo<ny>_<y0 type>..._<yn type> for sort_coo.
static FlatSymbolRefAttr getMangledSortHelperFunc (OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0)
 Looks up a function that is appropriate for the given operands being sorted, and creates such a function if it doesn't exist yet.
static void forEachIJPairInXs (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
 Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
static void forEachIJPairInAllBuffers (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder)
 Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
static void createSwap (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny)
 Creates a code block for swapping the values in index i and j for all the buffers.
static Value createInlinedCompareImplementation (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder)
 Creates code to compare all the (xs[i], xs[j]) pairs.
static Value createEqCompare (OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
 Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
static Value createInlinedEqCompare (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
 Creates code to compare whether xs[i] is equal to xs[j].
static Value createLessThanCompare (OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim)
 Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
static Value createInlinedLessThan (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
 Creates code to compare whether xs[i] is less than xs[j].
static void createBinarySearchFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
 Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sorted values xs[lo..hi).
static std::pair< Value, Value > createScanLoop (OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step)
 Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (when step > 0), or while (xs[i] > xs[p]) i += step (when step < 0). The routine returns i as well as a boolean value indicating whether xs[i] == xs[p].
static scf::IfOp createCompareThenSwap (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b)
 Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b], data[a]) returns true.
static void createInsert3rd (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
 Creates code to insert the 3rd element to a list of two sorted elements.
static void createSort3 (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2)
 Creates code to sort 3 elements.
static void createSort5 (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4)
 Creates code to sort 5 elements.
static void createChoosePivot (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args)
 Creates a code block to swap the values in indices lo, mi, and hi so that data[lo], data[mi] and data[hi] are sorted in non-decreasing order.
static void createPartitionFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0)
 Creates a function to perform quick sort partition on the values in the range of index [lo, hi), assuming lo < hi.
static Value createSubTwoDividedByTwo (OpBuilder &builder, Location loc, Value n)
 Computes (n-2)/2, assuming n has index type.
static void createShiftDownFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
 Creates a function to heapify the subtree with root start within the full binary tree in the range of index [first, first + n).
static void createHeapSortFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
 Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumption hi - lo >= 2.
static std::pair< Value, Value > createQuickSort (OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
 A helper for generating code to perform quick sort.
static void createSortStableFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
 Creates a function to perform insertion sort on the values in the range of index [lo, hi).
static void createQuickSortFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP)
 Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [lo, hi).
template<typename OpTy>
static LogicalResult matchAndRewriteSortOp (OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
 Implements the rewriting for operator sort and sort_coo.

Variables

static constexpr uint64_t loIdx = 0
static constexpr uint64_t hiIdx = 1
static constexpr uint64_t xStartIdx = 2
static constexpr const char kPartitionFuncNamePrefix [] = "_sparse_partition_"
static constexpr const char kBinarySearchFuncNamePrefix []
static constexpr const char kHybridQuickSortFuncNamePrefix []
static constexpr const char kSortStableFuncNamePrefix []
static constexpr const char kShiftDownFuncNamePrefix [] = "_sparse_shift_down_"
static constexpr const char kHeapSortFuncNamePrefix [] = "_sparse_heap_sort_"
static constexpr const char kQuickSortFuncNamePrefix [] = "_sparse_qsort_"

Typedef Documentation

◆ FuncGeneratorType

Initial value:
function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
AffineMap, uint64_t, uint32_t)>
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
This class helps build Operations.
Definition Builders.h:207
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152

Definition at line 48 of file SparseBufferRewriting.cpp.

Function Documentation

◆ createBinarySearchFunc()

void createBinarySearchFunc ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP = 0 )
static

◆ createChoosePivot()

void createChoosePivot ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
AffineMap xPerm,
uint64_t ny,
Value lo,
Value hi,
Value mi,
ValueRange args )
static

Creates a code block to swap the values in indices lo, mi, and hi so that data[lo], data[mi] and data[hi] are sorted in non-decreasing order.

When the number of values in range [lo, hi) is more than a threshold, we also include the middle of [lo, mi) and [mi, hi) and sort a total of five values.

Definition at line 510 of file SparseBufferRewriting.cpp.

References b, mlir::sparse_tensor::constantIndex(), createSort3(), createSort5(), mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.

Referenced by createPartitionFunc().

◆ createCompareThenSwap()

scf::IfOp createCompareThenSwap ( OpBuilder & builder,
Location loc,
AffineMap xPerm,
uint64_t ny,
SmallVectorImpl< Value > & swapOperands,
SmallVectorImpl< Value > & compareOperands,
Value a,
Value b )
static

Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b], data[a]) returns true.

The new insertion point is right after the swap instructions.

Definition at line 435 of file SparseBufferRewriting.cpp.

References b, createInlinedLessThan(), createSwap(), and mlir::OpBuilder::setInsertionPointToStart().

Referenced by createInsert3rd(), createSort3(), and createSort5().

◆ createEqCompare()

Value createEqCompare ( OpBuilder & builder,
Location loc,
Value i,
Value j,
Value x,
bool isFirstDim,
bool isLastDim )
static

Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.

Definition at line 208 of file SparseBufferRewriting.cpp.

References mlir::sparse_tensor::constantI1(), mlir::Builder::getIntegerType(), and mlir::OpBuilder::setInsertionPointToStart().

Referenced by createInlinedEqCompare().

◆ createHeapSortFunc()

void createHeapSortFunc ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP )
static

◆ createInlinedCompareImplementation()

Value createInlinedCompareImplementation ( OpBuilder & builder,
Location loc,
ValueRange args,
AffineMap xPerm,
uint64_t ny,
function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder )
static

Creates code to compare all the (xs[i], xs[j]) pairs.

The method to compare each pair is created via compareBuilder.

Definition at line 179 of file SparseBufferRewriting.cpp.

References forEachIJPairInXs(), mlir::Value::getDefiningOp(), mlir::AffineMap::getNumResults(), result, mlir::OpBuilder::setInsertionPointAfter(), and mlir::OpBuilder::setInsertionPointAfterValue().

Referenced by createInlinedEqCompare(), and createInlinedLessThan().

◆ createInlinedEqCompare()

Value createInlinedEqCompare ( OpBuilder & builder,
Location loc,
ValueRange args,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP = 0 )
static

Creates code to compare whether xs[i] is equal to xs[j].

Definition at line 249 of file SparseBufferRewriting.cpp.

References createEqCompare(), and createInlinedCompareImplementation().

Referenced by createScanLoop().

◆ createInlinedLessThan()

Value createInlinedLessThan ( OpBuilder & builder,
Location loc,
ValueRange args,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP = 0 )
static

Creates code to compare whether xs[i] is less than xs[j].

Definition at line 304 of file SparseBufferRewriting.cpp.

References createInlinedCompareImplementation(), and createLessThanCompare().

Referenced by createBinarySearchFunc(), createCompareThenSwap(), createScanLoop(), and createShiftDownFunc().

◆ createInsert3rd()

void createInsert3rd ( OpBuilder & builder,
Location loc,
AffineMap xPerm,
uint64_t ny,
SmallVectorImpl< Value > & swapOperands,
SmallVectorImpl< Value > & compareOperands,
Value v0,
Value v1,
Value v2 )
static

Creates code to insert the 3rd element to a list of two sorted elements.

Definition at line 453 of file SparseBufferRewriting.cpp.

References createCompareThenSwap(), and mlir::OpBuilder::setInsertionPointAfter().

Referenced by createSort3(), and createSort5().

◆ createLessThanCompare()

Value createLessThanCompare ( OpBuilder & builder,
Location loc,
Value i,
Value j,
Value x,
bool isFirstDim,
bool isLastDim )
static

Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.

Definition at line 261 of file SparseBufferRewriting.cpp.

References mlir::Builder::getIntegerType(), and mlir::OpBuilder::setInsertionPointToStart().

Referenced by createInlinedLessThan().

◆ createPartitionFunc()

void createPartitionFunc ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP = 0 )
static

◆ createQuickSort()

std::pair< Value, Value > createQuickSort ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
ValueRange args,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP )
static

A helper for generating code to perform quick sort.

It partitions [lo, hi), recursively calls quick sort to process the smaller partition and returns the bigger partition to be processed by the enclosed while-loop.

Definition at line 922 of file SparseBufferRewriting.cpp.

References mlir::sparse_tensor::constantIndex(), createPartitionFunc(), getMangledSortHelperFunc(), mlir::Value::getType(), hiIdx, kPartitionFuncNamePrefix, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.

Referenced by createQuickSortFunc().

◆ createQuickSortFunc()

◆ createScanLoop()

std::pair< Value, Value > createScanLoop ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
ValueRange xs,
Value i,
Value p,
AffineMap xPerm,
uint64_t ny,
int step )
static

Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (when step > 0), or while (xs[i] > xs[p]) i += step (when step < 0). The routine returns i as well as a boolean value indicating whether xs[i] == xs[p].

Definition at line 390 of file SparseBufferRewriting.cpp.

References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::createBlock(), createInlinedEqCompare(), createInlinedLessThan(), mlir::Block::getArgument(), mlir::Value::getType(), mlir::OpBuilder::setInsertionPointAfter(), and mlir::OpBuilder::setInsertionPointToEnd().

Referenced by createPartitionFunc().

◆ createShiftDownFunc()

void createShiftDownFunc ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP )
static

◆ createSort3()

void createSort3 ( OpBuilder & builder,
Location loc,
AffineMap xPerm,
uint64_t ny,
SmallVectorImpl< Value > & swapOperands,
SmallVectorImpl< Value > & compareOperands,
Value v0,
Value v1,
Value v2 )
static

Creates code to sort 3 elements.

Definition at line 465 of file SparseBufferRewriting.cpp.

References createCompareThenSwap(), createInsert3rd(), and mlir::OpBuilder::setInsertionPointAfter().

Referenced by createChoosePivot(), and createSort5().

◆ createSort5()

void createSort5 ( OpBuilder & builder,
Location loc,
AffineMap xPerm,
uint64_t ny,
SmallVectorImpl< Value > & swapOperands,
SmallVectorImpl< Value > & compareOperands,
Value v0,
Value v1,
Value v2,
Value v3,
Value v4 )
static

Creates code to sort 5 elements.

Definition at line 480 of file SparseBufferRewriting.cpp.

References createCompareThenSwap(), createInsert3rd(), createSort3(), and mlir::OpBuilder::setInsertionPointAfter().

Referenced by createChoosePivot().

◆ createSortStableFunc()

void createSortStableFunc ( OpBuilder & builder,
ModuleOp module,
func::FuncOp func,
AffineMap xPerm,
uint64_t ny,
uint32_t nTrailingP )
static

◆ createSubTwoDividedByTwo()

Value createSubTwoDividedByTwo ( OpBuilder & builder,
Location loc,
Value n )
static

Computes (n-2)/2, assuming n has index type.

Definition at line 693 of file SparseBufferRewriting.cpp.

References mlir::sparse_tensor::constantIndex().

Referenced by createHeapSortFunc(), and createShiftDownFunc().

◆ createSwap()

void createSwap ( OpBuilder & builder,
Location loc,
ValueRange args,
AffineMap xPerm,
uint64_t ny )
static

Creates a code block for swapping the values in index i and j for all the buffers.

Definition at line 165 of file SparseBufferRewriting.cpp.

References forEachIJPairInAllBuffers().

Referenced by createCompareThenSwap(), createHeapSortFunc(), createPartitionFunc(), and createShiftDownFunc().

◆ forEachIJPairInAllBuffers()

void forEachIJPairInAllBuffers ( OpBuilder & builder,
Location loc,
ValueRange args,
AffineMap xPerm,
uint64_t ny,
function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder )
static

Creates a code block to process each pair of (xys[i], xys[j]) for sorting.

The code to process the value pairs is generated by bodyBuilder.

Definition at line 129 of file SparseBufferRewriting.cpp.

References forEachIJPairInXs(), mlir::AffineMap::get(), mlir::Builder::getAffineDimExpr(), mlir::Builder::getContext(), mlir::AffineMap::getNumResults(), mlir::AffineMap::getResults(), mlir::AffineMap::isPermutation(), and xStartIdx.

Referenced by createSortStableFunc(), and createSwap().

◆ forEachIJPairInXs()

void forEachIJPairInXs ( OpBuilder & builder,
Location loc,
ValueRange args,
AffineMap xPerm,
uint64_t ny,
function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder )
static

Creates a code block to process each pair of (xs[i], xs[j]) for sorting.

The code to process the value pairs is generated by bodyBuilder.

Definition at line 109 of file SparseBufferRewriting.cpp.

References mlir::sparse_tensor::constantIndex(), mlir::AffineMap::getNumResults(), mlir::AffineMap::getResult(), and xStartIdx.

Referenced by createInlinedCompareImplementation(), and forEachIJPairInAllBuffers().

◆ getMangledSortHelperFunc()

FlatSymbolRefAttr getMangledSortHelperFunc ( OpBuilder & builder,
func::FuncOp insertPoint,
TypeRange resultTypes,
StringRef namePrefix,
AffineMap xPerm,
uint64_t ny,
ValueRange operands,
FuncGeneratorType createFunc,
uint32_t nTrailingP = 0 )
static

Looks up a function that is appropriate for the given operands being sorted, and creates such a function if it doesn't exist yet.

The parameters xPerm and ny specify the number of x and y values provided by the buffer in xStartIdx.

Definition at line 78 of file SparseBufferRewriting.cpp.

References getMangledSortHelperFuncName(), mlir::ValueRange::getTypes(), result, and mlir::OpBuilder::setInsertionPoint().

Referenced by createHeapSortFunc(), createQuickSort(), createQuickSortFunc(), createSortStableFunc(), and matchAndRewriteSortOp().

◆ getMangledSortHelperFuncName()

void getMangledSortHelperFuncName ( llvm::raw_svector_ostream & nameOstream,
StringRef namePrefix,
AffineMap xPerm,
uint64_t ny,
ValueRange operands )
static

Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort, and <namePrefix><xPerm>_<x type>coo<ny>_<y0 type>..._<yn type> for sort_coo.

Definition at line 54 of file SparseBufferRewriting.cpp.

References mlir::sparse_tensor::getMemRefType(), mlir::AffineMap::getResults(), and xStartIdx.

Referenced by getMangledSortHelperFunc().

◆ matchAndRewriteSortOp()

Variable Documentation

◆ hiIdx

◆ kBinarySearchFuncNamePrefix

const char kBinarySearchFuncNamePrefix[]
static constexpr
Initial value:
=
"_sparse_binary_search_"

Definition at line 38 of file SparseBufferRewriting.cpp.

Referenced by createSortStableFunc().

◆ kHeapSortFuncNamePrefix

const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_"
static constexpr

Definition at line 45 of file SparseBufferRewriting.cpp.

Referenced by createQuickSortFunc(), and matchAndRewriteSortOp().

◆ kHybridQuickSortFuncNamePrefix

const char kHybridQuickSortFuncNamePrefix[]
static constexpr
Initial value:
=
"_sparse_hybrid_qsort_"

Definition at line 40 of file SparseBufferRewriting.cpp.

Referenced by matchAndRewriteSortOp().

◆ kPartitionFuncNamePrefix

const char kPartitionFuncNamePrefix[] = "_sparse_partition_"
static constexpr

Definition at line 37 of file SparseBufferRewriting.cpp.

Referenced by createQuickSort().

◆ kQuickSortFuncNamePrefix

const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_"
static constexpr

Definition at line 46 of file SparseBufferRewriting.cpp.

Referenced by matchAndRewriteSortOp().

◆ kShiftDownFuncNamePrefix

const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_"
static constexpr

Definition at line 44 of file SparseBufferRewriting.cpp.

Referenced by createHeapSortFunc().

◆ kSortStableFuncNamePrefix

const char kSortStableFuncNamePrefix[]
static constexpr
Initial value:
=
"_sparse_sort_stable_"

Definition at line 42 of file SparseBufferRewriting.cpp.

Referenced by createQuickSortFunc(), and matchAndRewriteSortOp().

◆ loIdx

◆ xStartIdx