12 #include "llvm/ADT/ArrayRef.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/DebugLog.h"
16 #include "llvm/Support/InterleavedRange.h"
17 #include "llvm/Support/MathExtras.h"
18 #include "llvm/Support/raw_ostream.h"
24 #define DEBUG_TYPE "linalg-transforms"
// NOTE(review): fragment of the CopyMappingInfo constructor body. The
// signature and several interior statements are missing from this excerpt
// (gaps in the embedded original line numbers), so the comments below only
// describe what is visible here.
38 int64_t desiredBitAlignment,
40 bool favorPredication,
41 int64_t elementalBitwidth) {
// Only rank-1 to rank-3 copies are handled.
42 assert(!copySizes.empty() && copySizes.size() <= 3 &&
43 "only 1,2,3-D copies are supported for now");
45 LDBG() <<
"START CopyMappingInfo, favorPredication: " << favorPredication;
46 LDBG() <<
"--copy shape: " << llvm::interleaved(copySizes);
// Greedily pick the widest contiguous transfer (in elements) permitted by
// the desired alignment and the innermost copy dimension.
51 int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer(
52 desiredBitAlignment, copySizes.back(), elementalBitwidth);
54 LDBG() <<
"--greedily determined vectorSize: " << desiredVectorSize
55 <<
" elements of " << elementalBitwidth <<
"b each -> "
56 << (desiredVectorSize * elementalBitwidth)
// Infer per-dimension thread counts; the outcome is recorded in `status`.
// NOTE(review): the trailing call arguments (presumably favorPredication)
// are not visible in this excerpt.
59 status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize,
64 LDBG() <<
"--copy: " << llvm::interleaved(copySizes) <<
"\n"
65 <<
"--numThreads: " << llvm::interleaved(this->
numThreads) <<
"\n"
// inferNumThreads must have produced one thread count per copy dimension.
67 assert(this->
numThreads.size() == copySizes.size() &&
68 "compute copy mapping expected same number of threads and copy sizes");
// Per-dimension tile size = ceildiv(copy size, inferred thread count).
// NOTE(review): the assignment target of this map_range (presumably
// smallestBoundingTileSizes) is not visible in this excerpt, nor is the
// unpacking of `pair` into `size`/`numThreads`.
72 llvm::map_range(llvm::zip(copySizes, this->
numThreads), [](
auto &&pair) {
75 return llvm::divideCeilSigned(size,
numThreads);
// NOTE(review): selection of thread-mapping attributes; the surrounding
// expression is not visible here.
82 llvm::to_vector(
ArrayRef(allThreadMappings)
// Returns the maximal number of contiguous elements to move in one vector
// transfer: the gcd of the alignment expressed in elements and the number of
// contiguous elements available, bounded (via the missing enclosing
// expression, original line 94 -- presumably `return std::min(`) by the
// widest vector load, kMaxVectorLoadBitWidth / elementalBitwidth.
87 int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
88 int64_t desiredBitAlignment, int64_t numContiguousElements,
89 int64_t elementalBitwidth) {
// Preconditions: the element width must evenly divide both bounds so the
// divisions below are exact.
90 assert(kMaxVectorLoadBitWidth % elementalBitwidth == 0 &&
91 "elemental bitwidth does not divide kMaxVectorLoadBitWidth");
92 assert(desiredBitAlignment % elementalBitwidth == 0 &&
93 "elemental bitwidth does not divide desired bit alignment");
95 std::gcd(desiredBitAlignment / elementalBitwidth, numContiguousElements),
96 kMaxVectorLoadBitWidth / elementalBitwidth);
// Fragment of getFactors(val): collects all divisors of `val`, not just the
// prime factors (per the summary text at the end of this excerpt). The
// function header, the loop's skip statement (original line 105, presumably
// `continue;`), and the closing/return lines are missing here.
102 factors.reserve(val);
// NOTE(review): with the loop bound shown (`factor <= val`), the trailing
// push_back(val) below would append `val` a second time; the loop bound may
// have been garbled in extraction -- verify against the original file.
103 for (int64_t factor = 1; factor <= val; ++factor) {
104 if (val % factor != 0)
106 factors.push_back(factor);
108 factors.push_back(val);
// Fragment of product(vals): only the loop header survives in this excerpt.
// Presumably accumulates the product of all elements of `vals` -- TODO
// confirm against the original file.
114 for (
auto val : vals)
// Fragment of maximizeNumThreads: dimension-by-dimension greedy search that
// picks a factor of the current size so that the total number of threads used
// (product across remaining dimensions) is maximal without exceeding
// `maxNumThreads`. Several interior lines (base-case return, factor-list
// setup, the recursive call, best-tracking updates) are missing from this
// excerpt.
134 int64_t currentIndex,
135 int64_t maxNumThreads) {
136 assert(
static_cast<size_t>(currentIndex) < sizes.size() &&
137 "currentIndex out of bounds");
// Indent debug output proportionally to the recursion depth.
138 std::string indent(2 * currentIndex,
'-');
// Base case: the innermost dimension's size is mandated as-is (per the
// debug message below).
139 if (
static_cast<size_t>(currentIndex) == sizes.size() - 1) {
140 LDBG() << indent <<
"mandated globalBest: " << sizes[currentIndex];
145 int64_t s = sizes[currentIndex];
148 localThreadsPerDim.reserve(sizes.size());
149 LDBG() << indent <<
"maximizeNumThreads in " << s
150 <<
" with limit: " << maxNumThreads;
// Try each candidate factor for this dimension and combine with the best
// assignment for the remaining dimensions.
151 for (
auto factor : factors) {
// NOTE(review): the right-hand side (original line 153, presumably the
// recursive maximizeNumThreads call on currentIndex + 1) is missing from
// this excerpt.
152 auto nestedThreadsPerDim =
154 int64_t localBest = factor *
product(nestedThreadsPerDim);
// Keep the candidate only if it improves on `best` and stays within budget.
155 if (localBest > best && localBest <= maxNumThreads) {
156 LDBG() << indent <<
"new localBest: " << localBest;
157 LDBG() << indent <<
"nestedThreadsPerDim: "
158 << llvm::interleaved(nestedThreadsPerDim);
// Record the winning per-dimension thread counts: this factor first, then
// the nested assignment.
159 localThreadsPerDim.clear();
160 localThreadsPerDim.push_back(factor);
161 llvm::append_range(localThreadsPerDim, nestedThreadsPerDim);
166 LDBG() << indent <<
"found globalBest: " << best;
167 LDBG() << indent <<
"numThreads: " << llvm::interleaved(localThreadsPerDim);
168 return localThreadsPerDim;
// Fragment of inferNumThreads: unless predication is preferred, repeatedly
// halves the vector size until a mapping is definitively Success or Invalid;
// with favorPredication it goes straight to a single attempt at the desired
// vector size (which may yield RequiresPredication).
172 transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads,
174 int64_t desiredVectorSize,
175 bool favorPredication) {
177 if (!favorPredication) {
178 int64_t localVectorSize = desiredVectorSize;
// Shrink the vector size by powers of two until the mapping fits.
179 for (; localVectorSize >= 1; localVectorSize /= 2) {
// NOTE(review): the handling of this call's result (original lines 180-187)
// is missing from this excerpt; `status` is presumably assigned here.
188 inferNumThreadsImpl(totalNumThreads, sizes, localVectorSize);
// A definitive outcome ends the search; only RequiresPredication retries.
189 if (status == Status::Success || status == Status::Invalid)
192 LDBG() <<
"requires predication, try reducing vector size to "
193 << (localVectorSize / 2);
// Predication-tolerant path: single attempt at the desired vector size.
200 return inferNumThreadsImpl(totalNumThreads, sizes, desiredVectorSize);
// Fragment of inferNumThreadsImpl: scales the innermost copy size down by the
// vector size, validates the thread budget, and on success commits the
// inferred vectorSize/numThreads. Returns Success when the budget is used
// exactly, RequiresPredication when fewer threads are used, and Invalid when
// the budget is exceeded (or the mapping degenerates to zero threads).
204 transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
206 int64_t desiredVectorSize) {
// The innermost size must be evenly divisible so each thread moves whole
// vectors.
207 assert(sizes.back() % desiredVectorSize == 0 &&
208 "most-minor size not divisible by actualVectorSize");
210 LDBG() <<
"inferNumThreadsImpl with totalNumThreads: " << totalNumThreads
211 <<
" and vectorSize: " << desiredVectorSize;
// Each thread transfers `desiredVectorSize` contiguous elements, so the
// innermost dimension shrinks accordingly. NOTE(review): the setup of
// `scaledSizes` (original lines 212-216) is missing from this excerpt.
217 scaledSizes.back() /= desiredVectorSize;
// Even the innermost dimension alone may exceed the thread budget.
218 if (scaledSizes.back() > totalNumThreads) {
219 LDBG() <<
"--Too few threads given the required vector size -> FAIL";
220 return Status::Invalid;
// NOTE(review): computation of `inferredNumThreads` (original lines 221-224)
// is missing from this excerpt.
225 LDBG() <<
"inferred numThreads: " << llvm::interleaved(inferredNumThreads);
226 LDBG() <<
"computed actualVectorSize: " << desiredVectorSize;
231 int64_t totalNumThreadsUsed =
product(inferredNumThreads);
232 LDBG() <<
"--totalNumThreadsUsed: " << totalNumThreadsUsed;
// Reject degenerate (zero) or over-budget thread usage.
233 if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) {
234 LDBG() <<
"--Too few threads given the required vector size -> FAIL";
235 return Status::Invalid;
// Commit the inferred mapping to the members.
238 this->vectorSize = desiredVectorSize;
239 this->numThreads = inferredNumThreads;
// Exact fit -> Success; under-subscription -> RequiresPredication.
240 if (totalNumThreadsUsed == totalNumThreads)
241 return Status::Success;
243 return Status::RequiresPredication;
// Fragment of a printing routine for CopyMappingInfo (the enclosing function
// signature and stream setup are not visible in this excerpt). Streams the
// validity of `status`, the vector size, per-dimension thread counts,
// smallest bounding tile sizes, and the thread-mapping attributes.
248 <<
"CopyMappingInfo: " <<
"valid: " << (status != Status::Invalid) <<
", "
249 <<
"vectorSize: " << vectorSize <<
", numThreads: {"
250 << llvm::interleaved(numThreads) <<
"}, smallestBoundingTileSizes: {"
251 << llvm::interleaved(smallestBoundingTileSizes) <<
"}, threadMapping: {"
252 << llvm::interleaved(threadMapping) <<
"}}";
static SmallVector< int64_t > maximizeNumThreads(ArrayRef< int64_t > sizes, int64_t currentIndex, int64_t maxNumThreads)
Extract result from sizes with the following constraints:
static Attribute linearId1(MLIRContext *ctx)
static int64_t product(ArrayRef< int64_t > vals)
static SmallVector< int64_t > getFactors(int64_t val)
Get the list of all factors that divide val, not just the prime factors.
static Attribute linearId0(MLIRContext *ctx)
static Attribute linearId2(MLIRContext *ctx)
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction.