MLIR
20.0.0git
|
#include "Utils/CodegenUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"
Go to the source code of this file.
Typedefs | |
using | FuncGeneratorType = function_ref< void(OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, uint32_t)> |
Functions | |
static void | getMangledSortHelperFuncName (llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands) |
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort; <namePrefix><xPerm>_<x type>coo<ny>_<y0 type>..._<yn type> for sort_coo. More... | |
static FlatSymbolRefAttr | getMangledSortHelperFunc (OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP=0) |
Looks up a function that is appropriate for the given operands being sorted, and creates such a function if it doesn't exist yet. More... | |
static void | forEachIJPairInXs (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder) |
Creates a code block to process each pair of (xs[i], xs[j]) for sorting. More... | |
static void | forEachIJPairInAllBuffers (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< void(uint64_t, Value, Value, Value)> bodyBuilder) |
Creates a code block to process each pair of (xys[i], xys[j]) for sorting. More... | |
static void | createSwap (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny) |
Creates a code block for swapping the values in index i and j for all the buffers. More... | |
static Value | createInlinedCompareImplementation (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref< Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder) |
Creates code to compare all the (xs[i], xs[j]) pairs. More... | |
static Value | createEqCompare (OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) |
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison. More... | |
static Value | createInlinedEqCompare (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0) |
Creates code to compare whether xs[i] is equal to xs[j]. More... | |
static Value | createLessThanCompare (OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) |
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison. More... | |
static Value | createInlinedLessThan (OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0) |
Creates code to compare whether xs[i] is less than xs[j]. More... | |
static void | createBinarySearchFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0) |
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sorted values xs[lo..hi). More... | |
static std::pair< Value, Value > | createScanLoop (OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step) |
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step > 0), or while (xs[i] > xs[p]) i += step (step < 0). The routine returns i as well as a boolean value to indicate whether xs[i] == xs[p]. More... | |
static scf::IfOp | createCompareThenSwap (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value a, Value b) |
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b], data[a]) returns true. More... | |
static void | createInsert3rd (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2) |
Creates code to insert the 3rd element to a list of two sorted elements. More... | |
static void | createSort3 (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2) |
Creates code to sort 3 elements. More... | |
static void | createSort5 (OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl< Value > &swapOperands, SmallVectorImpl< Value > &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4) |
Creates code to sort 5 elements. More... | |
static void | createChoosePivot (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args) |
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo], data[mi] and data[hi] are sorted in non-decreasing values. More... | |
static void | createPartitionFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP=0) |
Creates a function to perform quick sort partition on the values in the range of index [lo, hi), assuming lo < hi. More... | |
static Value | createSubTwoDividedByTwo (OpBuilder &builder, Location loc, Value n) |
Computes (n-2)/2, assuming n has index type. More... | |
static void | createShiftDownFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) |
Creates a function to heapify the subtree with root start within the full binary tree in the range of index [first, first + n). More... | |
static void | createHeapSortFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) |
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumption hi - lo >= 2. More... | |
static std::pair< Value, Value > | createQuickSort (OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) |
A helper for generating code to perform quick sort. More... | |
static void | createSortStableFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) |
Creates a function to perform insertion sort on the values in the range of index [lo, hi). More... | |
static void | createQuickSortFunc (OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) |
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [lo, hi). More... | |
template<typename OpTy > | |
LogicalResult | matchAndRewriteSortOp (OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter) |
Implements the rewriting for operator sort and sort_coo. More... | |
Variables | |
static constexpr uint64_t | loIdx = 0 |
static constexpr uint64_t | hiIdx = 1 |
static constexpr uint64_t | xStartIdx = 2 |
static constexpr const char | kPartitionFuncNamePrefix [] = "_sparse_partition_" |
static constexpr const char | kBinarySearchFuncNamePrefix [] |
static constexpr const char | kHybridQuickSortFuncNamePrefix [] |
static constexpr const char | kSortStableFuncNamePrefix [] |
static constexpr const char | kShiftDownFuncNamePrefix [] = "_sparse_shift_down_" |
static constexpr const char | kHeapSortFuncNamePrefix [] = "_sparse_heap_sort_" |
static constexpr const char | kQuickSortFuncNamePrefix [] = "_sparse_qsort_" |
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, uint32_t)> |
Definition at line 48 of file SparseBufferRewriting.cpp.
|
static |
Creates a function to use a binary search to find the insertion point for inserting xs[hi] to the sorted values xs[lo..hi).
Definition at line 326 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), mlir::OpBuilder::createBlock(), createInlinedLessThan(), mlir::Block::getArguments(), mlir::Value::getType(), hiIdx, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToEnd(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createSortStableFunc().
|
static |
Creates a code block to swap the values in indices lo, mi, and hi so that data[lo], data[mi] and data[hi] are sorted in non-decreasing values.
When the number of values in range [lo, hi) is more than a threshold, we also include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
Definition at line 509 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), createSort3(), createSort5(), mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createPartitionFunc().
|
static |
Creates and returns an IfOp to compare two elements and swap the elements if compareFunc(data[b], data[a]) returns true.
The new insertion point is right after the swap instructions.
Definition at line 434 of file SparseBufferRewriting.cpp.
References mlir::OpBuilder::create(), createInlinedLessThan(), createSwap(), and mlir::OpBuilder::setInsertionPointToStart().
Referenced by createInsert3rd(), createSort3(), and createSort5().
|
static |
Generates code to compare whether x[i] is equal to x[j] and returns the result of the comparison.
Definition at line 208 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantI1(), mlir::OpBuilder::create(), mlir::Builder::getIntegerType(), and mlir::OpBuilder::setInsertionPointToStart().
Referenced by createInlinedEqCompare().
|
static |
Creates a function to perform heap sort on the values in the range of index [lo, hi) with the assumption hi - lo >= 2.
Definition at line 859 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), createShiftDownFunc(), createSubTwoDividedByTwo(), createSwap(), mlir::Block::getArguments(), getMangledSortHelperFunc(), hiIdx, kShiftDownFuncNamePrefix, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createQuickSortFunc(), and matchAndRewriteSortOp().
|
static |
Creates code to compare all the (xs[i], xs[j]) pairs.
The method used to compare each pair is created via compareBuilder.
Definition at line 179 of file SparseBufferRewriting.cpp.
References mlir::OpBuilder::create(), forEachIJPairInXs(), mlir::Value::getDefiningOp(), mlir::AffineMap::getNumResults(), mlir::OpBuilder::setInsertionPointAfter(), and mlir::OpBuilder::setInsertionPointAfterValue().
Referenced by createInlinedEqCompare(), and createInlinedLessThan().
|
static |
Creates code to compare whether xs[i] is equal to xs[j].
Definition at line 249 of file SparseBufferRewriting.cpp.
References createEqCompare(), and createInlinedCompareImplementation().
Referenced by createScanLoop().
|
static |
Creates code to compare whether xs[i] is less than xs[j].
Definition at line 303 of file SparseBufferRewriting.cpp.
References createInlinedCompareImplementation(), and createLessThanCompare().
Referenced by createBinarySearchFunc(), createCompareThenSwap(), and createScanLoop().
|
static |
Creates code to insert the 3rd element to a list of two sorted elements.
Definition at line 452 of file SparseBufferRewriting.cpp.
References createCompareThenSwap(), and mlir::OpBuilder::setInsertionPointAfter().
Referenced by createSort3(), and createSort5().
|
static |
Generates code to compare whether x[i] is less than x[j] and returns the result of the comparison.
Definition at line 261 of file SparseBufferRewriting.cpp.
References mlir::OpBuilder::create(), mlir::Builder::getIntegerType(), and mlir::OpBuilder::setInsertionPointToStart().
Referenced by createInlinedLessThan().
|
static |
Creates a function to perform quick sort partition on the values in the range of index [lo, hi), assuming lo < hi.
Definition at line 577 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantI1(), mlir::sparse_tensor::constantIndex(), mlir::sparse_tensor::constantOne(), mlir::OpBuilder::create(), mlir::OpBuilder::createBlock(), createChoosePivot(), createScanLoop(), createSwap(), mlir::Block::getArguments(), mlir::Value::getType(), hiIdx, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToEnd(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createQuickSort().
|
static |
A helper for generating code to perform quick sort.
It partitions [lo, hi), recursively calls quick sort to process the smaller partition, and returns the bigger partition to be processed by the enclosing while-loop.
Definition at line 918 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), createPartitionFunc(), mlir::get(), getMangledSortHelperFunc(), mlir::Value::getType(), hiIdx, kPartitionFuncNamePrefix, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createQuickSortFunc().
|
static |
Creates a function to perform quick sort or a hybrid quick sort on the values in the range of index [lo, hi).
Definition at line 1113 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantI64(), mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), mlir::OpBuilder::createBlock(), createHeapSortFunc(), createQuickSort(), createSortStableFunc(), mlir::Block::getArguments(), getMangledSortHelperFunc(), mlir::Value::getType(), hiIdx, kHeapSortFuncNamePrefix, kSortStableFuncNamePrefix, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToEnd(), and mlir::OpBuilder::setInsertionPointToStart().
Referenced by matchAndRewriteSortOp().
|
static |
Creates code to advance i in a loop based on xs[p] as follows: while (xs[i] < xs[p]) i += step (step > 0), or while (xs[i] > xs[p]) i += step (step < 0). The routine returns i as well as a boolean value to indicate whether xs[i] == xs[p].
Definition at line 389 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), mlir::OpBuilder::createBlock(), createInlinedEqCompare(), createInlinedLessThan(), mlir::Value::getType(), mlir::OpBuilder::setInsertionPointAfter(), and mlir::OpBuilder::setInsertionPointToEnd().
Referenced by createPartitionFunc().
|
static |
Creates a function to heapify the subtree rooted at start within the full binary tree in the range of index [first, first + n).
Definition at line 728 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), createSubTwoDividedByTwo(), mlir::Block::getArguments(), hiIdx, loIdx, mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createHeapSortFunc().
|
static |
Creates code to sort 3 elements.
Definition at line 464 of file SparseBufferRewriting.cpp.
References createCompareThenSwap(), createInsert3rd(), and mlir::OpBuilder::setInsertionPointAfter().
Referenced by createChoosePivot(), and createSort5().
|
static |
Creates code to sort 5 elements.
Definition at line 479 of file SparseBufferRewriting.cpp.
References createCompareThenSwap(), createInsert3rd(), createSort3(), and mlir::OpBuilder::setInsertionPointAfter().
Referenced by createChoosePivot().
|
static |
Creates a function to perform insertion sort on the values in the range of index [lo, hi).
Definition at line 998 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), createBinarySearchFunc(), forEachIJPairInAllBuffers(), mlir::get(), mlir::Block::getArguments(), getMangledSortHelperFunc(), mlir::Value::getType(), hiIdx, kBinarySearchFuncNamePrefix, loIdx, mlir::OpBuilder::setInsertionPointAfter(), mlir::OpBuilder::setInsertionPointToStart(), and xStartIdx.
Referenced by createQuickSortFunc(), and matchAndRewriteSortOp().
|
static
Computes (n-2)/2, assuming n has index type.
Definition at line 690 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), and mlir::OpBuilder::create().
Referenced by createHeapSortFunc(), and createShiftDownFunc().
|
static |
Creates a code block for swapping the values in index i and j for all the buffers.
Definition at line 165 of file SparseBufferRewriting.cpp.
References mlir::OpBuilder::create(), and forEachIJPairInAllBuffers().
Referenced by createCompareThenSwap(), createHeapSortFunc(), and createPartitionFunc().
|
static |
Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
The code to process the value pairs is generated by bodyBuilder.
Definition at line 129 of file SparseBufferRewriting.cpp.
References mlir::detail::enumerate(), forEachIJPairInXs(), mlir::AffineMap::get(), mlir::Builder::getAffineDimExpr(), mlir::Builder::getContext(), mlir::AffineMap::getNumResults(), mlir::AffineMap::getResults(), mlir::AffineMap::isPermutation(), and xStartIdx.
Referenced by createSortStableFunc(), and createSwap().
|
static |
Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
The code to process the value pairs is generated by bodyBuilder.
Definition at line 109 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), mlir::AffineMap::getNumResults(), mlir::AffineMap::getResult(), and xStartIdx.
Referenced by createInlinedCompareImplementation(), and forEachIJPairInAllBuffers().
|
static |
Looks up a function that is appropriate for the given operands being sorted, and creates such a function if it doesn't exist yet.
The parameters xPerm and ny tell the number of x and y values provided by the buffer in xStartIdx.
Definition at line 78 of file SparseBufferRewriting.cpp.
References mlir::OpBuilder::create(), mlir::get(), getMangledSortHelperFuncName(), mlir::ValueRange::getTypes(), and mlir::OpBuilder::setInsertionPoint().
Referenced by createHeapSortFunc(), createQuickSort(), createQuickSortFunc(), createSortStableFunc(), and matchAndRewriteSortOp().
|
static |
Constructs a function name with this format to facilitate quick sort: <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort; <namePrefix><xPerm>_<x type>coo<ny>_<y0 type>..._<yn type> for sort_coo.
Definition at line 54 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::getMemRefType(), mlir::AffineMap::getResults(), and xStartIdx.
Referenced by getMangledSortHelperFunc().
template<typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter)
Implements the rewriting for operator sort and sort_coo.
Definition at line 1220 of file SparseBufferRewriting.cpp.
References mlir::sparse_tensor::constantI64(), mlir::sparse_tensor::constantIndex(), mlir::OpBuilder::create(), createHeapSortFunc(), createQuickSortFunc(), createSortStableFunc(), mlir::get(), mlir::Builder::getI64Type(), getMangledSortHelperFunc(), mlir::sparse_tensor::getMemRefType(), hiIdx, kHeapSortFuncNamePrefix, kHybridQuickSortFuncNamePrefix, kQuickSortFuncNamePrefix, kSortStableFuncNamePrefix, loIdx, and mlir::RewriterBase::replaceOpWithNewOp().
|
static constexpr uint64_t hiIdx
Definition at line 34 of file SparseBufferRewriting.cpp.
Referenced by createBinarySearchFunc(), createHeapSortFunc(), createPartitionFunc(), createQuickSort(), createQuickSortFunc(), createShiftDownFunc(), createSortStableFunc(), and matchAndRewriteSortOp().
|
static constexpr const char kBinarySearchFuncNamePrefix[]
Definition at line 38 of file SparseBufferRewriting.cpp.
Referenced by createSortStableFunc().
|
static constexpr const char kHeapSortFuncNamePrefix[]
Definition at line 45 of file SparseBufferRewriting.cpp.
Referenced by createQuickSortFunc(), and matchAndRewriteSortOp().
|
static constexpr const char kHybridQuickSortFuncNamePrefix[]
Definition at line 40 of file SparseBufferRewriting.cpp.
Referenced by matchAndRewriteSortOp().
|
static constexpr const char kPartitionFuncNamePrefix[]
Definition at line 37 of file SparseBufferRewriting.cpp.
Referenced by createQuickSort().
|
static constexpr const char kQuickSortFuncNamePrefix[]
Definition at line 46 of file SparseBufferRewriting.cpp.
Referenced by matchAndRewriteSortOp().
|
static constexpr const char kShiftDownFuncNamePrefix[]
Definition at line 44 of file SparseBufferRewriting.cpp.
Referenced by createHeapSortFunc().
|
static constexpr const char kSortStableFuncNamePrefix[]
Definition at line 42 of file SparseBufferRewriting.cpp.
Referenced by createQuickSortFunc(), and matchAndRewriteSortOp().
|
static constexpr uint64_t loIdx
Definition at line 33 of file SparseBufferRewriting.cpp.
Referenced by createBinarySearchFunc(), createHeapSortFunc(), createPartitionFunc(), createQuickSort(), createQuickSortFunc(), createShiftDownFunc(), createSortStableFunc(), and matchAndRewriteSortOp().
|
static constexpr uint64_t xStartIdx
Definition at line 35 of file SparseBufferRewriting.cpp.
Referenced by createBinarySearchFunc(), createChoosePivot(), createHeapSortFunc(), createPartitionFunc(), createQuickSort(), createShiftDownFunc(), createSortStableFunc(), forEachIJPairInAllBuffers(), forEachIJPairInXs(), and getMangledSortHelperFuncName().