12 #include "llvm/ADT/ArrayRef.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/CommandLine.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/InterleavedRange.h"
17 #include "llvm/Support/MathExtras.h"
18 #include "llvm/Support/raw_ostream.h"
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
40 int64_t desiredBitAlignment,
42 bool favorPredication,
43 int64_t elementalBitwidth) {
44 assert(!copySizes.empty() && copySizes.size() <= 3 &&
45 "only 1,2,3-D copies are supported for now");
47 LDBG(
"START CopyMappingInfo, favorPredication: " << favorPredication);
48 LLVM_DEBUG(
DBGS() <<
"--copy shape: " << llvm::interleaved(copySizes)
54 int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer(
55 desiredBitAlignment, copySizes.back(), elementalBitwidth);
57 LDBG(
"--greedily determined vectorSize: "
58 << desiredVectorSize <<
" elements of " << elementalBitwidth
59 <<
"b each -> " << (desiredVectorSize * elementalBitwidth)
62 status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize,
67 LLVM_DEBUG(
DBGS() <<
"--copy: " << llvm::interleaved(copySizes) <<
"\n"
68 <<
"--numThreads: " << llvm::interleaved(this->
numThreads)
70 <<
"--vectorSize: " << this->
vectorSize <<
"\n");
71 assert(this->
numThreads.size() == copySizes.size() &&
72 "compute copy mapping expected same number of threads and copy sizes");
76 llvm::map_range(llvm::zip(copySizes, this->
numThreads), [](
auto &&pair) {
79 return llvm::divideCeilSigned(size,
numThreads);
86 llvm::to_vector(
ArrayRef(allThreadMappings)
88 LLVM_DEBUG(this->
print(
DBGS()); llvm::dbgs() <<
"\n");
91 int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
92 int64_t desiredBitAlignment, int64_t numContiguousElements,
93 int64_t elementalBitwidth) {
94 assert(kMaxVectorLoadBitWidth % elementalBitwidth == 0 &&
95 "elemental bitwidth does not divide kMaxVectorLoadBitWidth");
96 assert(desiredBitAlignment % elementalBitwidth == 0 &&
97 "elemental bitwidth does not divide desired bit alignment");
99 std::gcd(desiredBitAlignment / elementalBitwidth, numContiguousElements),
100 kMaxVectorLoadBitWidth / elementalBitwidth);
106 factors.reserve(val);
107 for (int64_t factor = 1; factor <= val; ++factor) {
108 if (val % factor != 0)
110 factors.push_back(factor);
112 factors.push_back(val);
118 for (
auto val : vals)
138 int64_t currentIndex,
139 int64_t maxNumThreads) {
140 assert(
static_cast<size_t>(currentIndex) < sizes.size() &&
141 "currentIndex out of bounds");
142 std::string indent(2 * currentIndex,
'-');
143 if (
static_cast<size_t>(currentIndex) == sizes.size() - 1) {
144 LDBG(indent <<
"mandated globalBest: " << sizes[currentIndex]);
149 int64_t s = sizes[currentIndex];
152 localThreadsPerDim.reserve(sizes.size());
153 LDBG(indent <<
"maximizeNumThreads in " << s
154 <<
" with limit: " << maxNumThreads);
155 for (
auto factor : factors) {
156 auto nestedThreadsPerDim =
158 int64_t localBest = factor *
product(nestedThreadsPerDim);
159 if (localBest > best && localBest <= maxNumThreads) {
160 LDBG(indent <<
"new localBest: " << localBest);
161 LDBG(indent <<
"nestedThreadsPerDim: "
162 << llvm::interleaved(nestedThreadsPerDim));
163 localThreadsPerDim.clear();
164 localThreadsPerDim.push_back(factor);
165 llvm::append_range(localThreadsPerDim, nestedThreadsPerDim);
170 LDBG(indent <<
"found globalBest: " << best);
171 LDBG(indent <<
"numThreads: " << llvm::interleaved(localThreadsPerDim));
172 return localThreadsPerDim;
176 transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads,
178 int64_t desiredVectorSize,
179 bool favorPredication) {
181 if (!favorPredication) {
182 int64_t localVectorSize = desiredVectorSize;
183 for (; localVectorSize >= 1; localVectorSize /= 2) {
192 inferNumThreadsImpl(totalNumThreads, sizes, localVectorSize);
193 if (status == Status::Success || status == Status::Invalid)
196 LDBG(
"requires predication, try reducing vector size to "
197 << (localVectorSize / 2));
204 return inferNumThreadsImpl(totalNumThreads, sizes, desiredVectorSize);
208 transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
210 int64_t desiredVectorSize) {
211 assert(sizes.back() % desiredVectorSize == 0 &&
212 "most-minor size not divisible by actualVectorSize");
214 LDBG(
"inferNumThreadsImpl with totalNumThreads: "
215 << totalNumThreads <<
" and vectorSize: " << desiredVectorSize);
221 scaledSizes.back() /= desiredVectorSize;
222 if (scaledSizes.back() > totalNumThreads) {
223 LDBG(
"--Too few threads given the required vector size -> FAIL");
224 return Status::Invalid;
229 LDBG(
"inferred numThreads: " << llvm::interleaved(inferredNumThreads));
230 LDBG(
"computed actualVectorSize: " << desiredVectorSize);
235 int64_t totalNumThreadsUsed =
product(inferredNumThreads);
236 LDBG(
"--totalNumThreadsUsed: " << totalNumThreadsUsed);
237 if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) {
238 LDBG(
"--Too few threads given the required vector size -> FAIL");
239 return Status::Invalid;
242 this->vectorSize = desiredVectorSize;
243 this->numThreads = inferredNumThreads;
244 if (totalNumThreadsUsed == totalNumThreads)
245 return Status::Success;
247 return Status::RequiresPredication;
252 <<
"CopyMappingInfo: " <<
"valid: " << (status != Status::Invalid) <<
", "
253 <<
"vectorSize: " << vectorSize <<
", numThreads: {"
254 << llvm::interleaved(numThreads) <<
"}, smallestBoundingTileSizes: {"
255 << llvm::interleaved(smallestBoundingTileSizes) <<
"}, threadMapping: {"
256 << llvm::interleaved(threadMapping) <<
"}}";
static SmallVector< int64_t > maximizeNumThreads(ArrayRef< int64_t > sizes, int64_t currentIndex, int64_t maxNumThreads)
Extract result from sizes with the following constraints:
static Attribute linearId1(MLIRContext *ctx)
static int64_t product(ArrayRef< int64_t > vals)
static SmallVector< int64_t > getFactors(int64_t val)
Get the list of all factors that divide val, not just the prime factors.
static Attribute linearId0(MLIRContext *ctx)
static Attribute linearId2(MLIRContext *ctx)
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.