12 #include "llvm/ADT/ArrayRef.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/CommandLine.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/MathExtras.h"
17 #include "llvm/Support/raw_ostream.h"
23 #define DEBUG_TYPE "linalg-transforms"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
39 int64_t desiredBitAlignment,
41 bool favorPredication,
42 int64_t elementalBitwidth) {
43 assert(!copySizes.empty() && copySizes.size() <= 3 &&
44 "only 1,2,3-D copies are supported for now");
46 LDBG(
"START CopyMappingInfo, favorPredication: " << favorPredication);
47 LLVM_DEBUG(llvm::interleaveComma(copySizes,
DBGS() <<
"--copy shape: ");
48 llvm::dbgs() <<
"\n";);
53 int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer(
54 desiredBitAlignment, copySizes.back(), elementalBitwidth);
56 LDBG(
"--greedily determined vectorSize: "
57 << desiredVectorSize <<
" elements of " << elementalBitwidth
58 <<
"b each -> " << (desiredVectorSize * elementalBitwidth)
61 status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize,
66 LLVM_DEBUG(llvm::interleaveComma(copySizes,
DBGS() <<
"--copy: ");
67 llvm::dbgs() <<
"\n"; llvm::interleaveComma(
69 llvm::dbgs() <<
"\n";);
71 assert(this->
numThreads.size() == copySizes.size() &&
72 "compute copy mapping expected same number of threads and copy sizes");
76 llvm::map_range(llvm::zip(copySizes, this->
numThreads), [](
auto &&pair) {
79 return llvm::divideCeilSigned(size,
numThreads);
86 llvm::to_vector(
ArrayRef(allThreadMappings)
88 LLVM_DEBUG(this->
print(
DBGS()); llvm::dbgs() <<
"\n");
91 int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
92 int64_t desiredBitAlignment, int64_t numContiguousElements,
93 int64_t elementalBitwidth) {
94 assert(kMaxVectorLoadBitWidth % elementalBitwidth == 0 &&
95 "elemental bitwidth does not divide kMaxVectorLoadBitWidth");
96 assert(desiredBitAlignment % elementalBitwidth == 0 &&
97 "elemental bitwidth does not divide desired bit alignment");
99 std::gcd(desiredBitAlignment / elementalBitwidth, numContiguousElements),
100 kMaxVectorLoadBitWidth / elementalBitwidth);
106 factors.reserve(val);
107 for (int64_t factor = 1; factor <= val; ++factor) {
108 if (val % factor != 0)
110 factors.push_back(factor);
112 factors.push_back(val);
118 for (
auto val : vals)
138 int64_t currentIndex,
139 int64_t maxNumThreads) {
140 assert(
static_cast<size_t>(currentIndex) < sizes.size() &&
141 "currentIndex out of bounds");
142 std::string indent(2 * currentIndex,
'-');
143 if (
static_cast<size_t>(currentIndex) == sizes.size() - 1) {
144 LDBG(indent <<
"mandated globalBest: " << sizes[currentIndex]);
149 int64_t s = sizes[currentIndex];
152 localThreadsPerDim.reserve(sizes.size());
153 LDBG(indent <<
"maximizeNumThreads in " << s
154 <<
" with limit: " << maxNumThreads);
155 for (
auto factor : factors) {
156 auto nestedThreadsPerDim =
158 int64_t localBest = factor *
product(nestedThreadsPerDim);
159 if (localBest > best && localBest <= maxNumThreads) {
160 LDBG(indent <<
"new localBest: " << localBest);
162 llvm::interleaveComma(nestedThreadsPerDim,
163 DBGS() << indent <<
"nestedThreadsPerDim: ");
164 llvm::dbgs() <<
"\n";);
165 localThreadsPerDim.clear();
166 localThreadsPerDim.push_back(factor);
167 llvm::append_range(localThreadsPerDim, nestedThreadsPerDim);
172 LDBG(indent <<
"found globalBest: " << best);
173 LLVM_DEBUG(llvm::interleaveComma(localThreadsPerDim,
174 DBGS() << indent <<
"numThreads: ");
175 llvm::dbgs() <<
"\n";);
177 return localThreadsPerDim;
181 transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads,
183 int64_t desiredVectorSize,
184 bool favorPredication) {
186 if (!favorPredication) {
187 int64_t localVectorSize = desiredVectorSize;
188 for (; localVectorSize >= 1; localVectorSize /= 2) {
197 inferNumThreadsImpl(totalNumThreads, sizes, localVectorSize);
198 if (status == Status::Success || status == Status::Invalid)
201 LDBG(
"requires predication, try reducing vector size to "
202 << (localVectorSize / 2));
209 return inferNumThreadsImpl(totalNumThreads, sizes, desiredVectorSize);
213 transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
215 int64_t desiredVectorSize) {
216 assert(sizes.back() % desiredVectorSize == 0 &&
217 "most-minor size not divisible by actualVectorSize");
219 LDBG(
"inferNumThreadsImpl with totalNumThreads: "
220 << totalNumThreads <<
" and vectorSize: " << desiredVectorSize);
226 scaledSizes.back() /= desiredVectorSize;
227 if (scaledSizes.back() > totalNumThreads) {
228 LDBG(
"--Too few threads given the required vector size -> FAIL");
229 return Status::Invalid;
234 LLVM_DEBUG(llvm::interleaveComma(inferredNumThreads,
235 DBGS() <<
"inferred numThreads: ");
236 llvm::dbgs() <<
"\n";
237 LDBG(
"computed actualVectorSize: " << desiredVectorSize););
242 int64_t totalNumThreadsUsed =
product(inferredNumThreads);
243 LDBG(
"--totalNumThreadsUsed: " << totalNumThreadsUsed);
244 if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) {
245 LDBG(
"--Too few threads given the required vector size -> FAIL");
246 return Status::Invalid;
249 this->vectorSize = desiredVectorSize;
250 this->numThreads = inferredNumThreads;
251 if (totalNumThreadsUsed == totalNumThreads)
252 return Status::Success;
254 return Status::RequiresPredication;
258 os <<
"MappingInfo{";
259 os <<
"CopyMappingInfo: ";
260 os <<
"valid: " << (status != Status::Invalid) <<
", ";
261 os <<
"vectorSize: " << vectorSize <<
", ";
262 llvm::interleaveComma(numThreads, os <<
", numThreads: {");
263 llvm::interleaveComma(smallestBoundingTileSizes,
264 os <<
"}, smallestBoundingTileSizes: {");
265 llvm::interleaveComma(threadMapping, os <<
"}, threadMapping: {");
static SmallVector< int64_t > maximizeNumThreads(ArrayRef< int64_t > sizes, int64_t currentIndex, int64_t maxNumThreads)
Extract result from sizes with the following constraints:
static Attribute linearId1(MLIRContext *ctx)
static int64_t product(ArrayRef< int64_t > vals)
static SmallVector< int64_t > getFactors(int64_t val)
Get the list of all factors that divide val, not just the prime factors.
static Attribute linearId0(MLIRContext *ctx)
static Attribute linearId2(MLIRContext *ctx)
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...