/// Checks that every attribute in the mapping is one of the expected mapping
/// attributes and that no attribute maps two loops to the same processor.
static DiagnosedSilenceableFailure checkAttributeType(
    ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
    const std::optional<ArrayAttr> &foreachMapping,
    std::optional<TransformOpInterface> transformOp) {
  if (!foreachMapping.has_value())
    return transformOp->emitSilenceableError() << "mapping must be present";

  DenseSet<Attribute> seen;
  for (Attribute map : foreachMapping->getValue()) {
    if (!llvm::is_contained(threadMappingAttributes, map)) {
      return transformOp->emitDefiniteFailure()
             << "mapping must be one of " << threadMappingAttributes;
    }
    if (llvm::is_contained(seen, map)) {
      return transformOp->emitDefiniteFailure()
             << map
             << " is duplicated, cannot map different "
                "loops to the same processor";
    }
    seen.insert(map);
  }

  return DiagnosedSilenceableFailure::success();
}
static DiagnosedSilenceableFailure checkGpuLimits(
    TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {
  static constexpr int maxTotalBlockdim = 1024;
  static constexpr int maxBlockdimx = 1024;
  static constexpr int maxBlockdimy = 1024;
  static constexpr int maxBlockdimz = 64;
  static constexpr int maxTotalGriddim = 2147483647;
  static constexpr int maxGriddimx = 2147483647;
  static constexpr int maxGriddimy = 65535;
  static constexpr int maxGriddimz = 65535;
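  // These bounds match CUDA's documented per-launch limits; dimensions that
  // were not provided default to 1 through value_or(1) below.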
  if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
          maxTotalBlockdim ||
      (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
          maxTotalGriddim ||
      blockDimX.value_or(1) > maxBlockdimx ||
      blockDimY.value_or(1) > maxBlockdimy ||
      blockDimZ.value_or(1) > maxBlockdimz ||
      gridDimY.value_or(1) > maxGriddimy ||
      gridDimZ.value_or(1) > maxGriddimz ||
      gridDimX.value_or(1) > maxGriddimx) {
    return transformOp.emitSilenceableError()
           << "Trying to launch a GPU kernel with gridDim = ("
           << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
           << gridDimZ.value_or(1) << ") blockDim = (" << blockDimX.value_or(1)
           << ", " << blockDimY.value_or(1) << ", " << blockDimZ.value_or(1)
           << "). It is larger than the limits.";
  }
  return DiagnosedSilenceableFailure::success();
}
static DiagnosedSilenceableFailure
createGpuLaunch(RewriterBase &rewriter, Location loc,
                TransformOpInterface transformOp, LaunchOp &launchOp,
                std::optional<int64_t> gridDimX = std::nullopt,
                std::optional<int64_t> gridDimY = std::nullopt,
                std::optional<int64_t> gridDimZ = std::nullopt,
                std::optional<int64_t> blockDimX = std::nullopt,
                std::optional<int64_t> blockDimY = std::nullopt,
                std::optional<int64_t> blockDimZ = std::nullopt) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;
  auto createConst = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
  };
  OpBuilder::InsertionGuard guard(rewriter);
  Value one = createConst(1);
  Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
  Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
  Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
  Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
  Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
  Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
  launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
                                       blkSizeX, blkSizeY, blkSizeZ);
  rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
  rewriter.create<TerminatorOp>(loc);
  return DiagnosedSilenceableFailure::success();
}
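/// Alters the gridDim/blockDim operands of an existing gpu.launch op in place;
/// only the dimensions that are provided are overwritten.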
static DiagnosedSilenceableFailure
alterGpuLaunch(RewriterBase &rewriter, LaunchOp gpuLaunch,
               TransformOpInterface transformOp,
               std::optional<int64_t> gridDimX = std::nullopt,
               std::optional<int64_t> gridDimY = std::nullopt,
               std::optional<int64_t> gridDimZ = std::nullopt,
               std::optional<int64_t> blockDimX = std::nullopt,
               std::optional<int64_t> blockDimY = std::nullopt,
               std::optional<int64_t> blockDimZ = std::nullopt) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;
  KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfterValue(currentBlockdim.x);
  auto createConstValue = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                   dim);
  };
  if (gridDimX.has_value())
    gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
  if (gridDimY.has_value())
    gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
  if (gridDimZ.has_value())
    gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
  if (blockDimX.has_value())
    gpuLaunch.getBlockSizeXMutable().assign(
        createConstValue(blockDimX.value()));
  if (blockDimY.has_value())
    gpuLaunch.getBlockSizeYMutable().assign(
        createConstValue(blockDimY.value()));
  if (blockDimZ.has_value())
    gpuLaunch.getBlockSizeZMutable().assign(
        createConstValue(blockDimZ.value()));
  return DiagnosedSilenceableFailure::success();
}
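/// Rewrites a top-level scf.foreach_thread that is mapped to blocks into
/// gpu.block_id based indexing. Illustrative input (attribute syntax assumed):
///
///   scf.foreach_thread (%i, %j) in (%c7, %c9) {
///     ...
///   } {mapping = [#gpu.block<x>, #gpu.block<y>]}
///
/// The induction variables %i and %j are replaced by gpu.block_id x and y,
/// the body is inlined before the loop, and the grid dimensions (7, 9, 1) are
/// reported back through `gridDims`.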
DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                      SmallVectorImpl<Value> &)>
        blockIdGenerator,
    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
  Location loc = foreachThreadOp->getLoc();

  if (foreachThreadOp.getNumResults() > 0)
    return transformOp.emitSilenceableError()
           << "only bufferized scf.foreach_thread lowers to "
              "gpu.block_id";
  if (foreachThreadOp.getNumThreads().size() > 3)
    return transformOp.emitSilenceableError()
           << "scf.foreach_thread with rank > 3 does not lower to "
              "gpu.block_id";
  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
        return !v.getDefiningOp<arith::ConstantIndexOp>();
      })) {
    return transformOp.emitSilenceableError()
           << "unsupported dynamic griddim size";
  }
  SmallVector<Attribute> blockMapping =
      llvm::to_vector(foreachThreadOp.getMapping()->getValue());

  // Complete the blockMapping to a full x/y/z mapping, padding missing
  // dimensions with a constant 1 block count.
  SmallVector<Value> numBlocks =
      llvm::to_vector(foreachThreadOp.getNumThreads());
  Value one;
  for (auto attr : mappingAttributes) {
    if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
        blockMapping.end()) {
      blockMapping.push_back(attr);
      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
      numBlocks.push_back(one);
    }
  }

  // Sort the block sizes by their DeviceMappingAttrInterface mapping id.
  auto comparator = [&](DeviceMappingAttrInterface a,
                        DeviceMappingAttrInterface b) -> bool {
    return a.getMappingId() < b.getMappingId();
  };
  SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
      blockMapping, numBlocks, comparator);
  for (Value v : gridDimValues)
    gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());

  // Generate the gpu.block_id ops and map each loop index to the block id
  // selected by its mapping attribute.
  SmallVector<Value> blockOps;
  blockIdGenerator(rewriter, foreachThreadOp, blockOps);
  IRMapping bvm;
  for (auto [blockIdx, blockDim] :
       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
    bvm.map(blockIdx,
            blockOps[static_cast<int64_t>(
                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
  }
  // Move the body of foreachThreadOp before the op itself. Erase the
  // terminator first, it is not used since we operate on buffers.
  rewriter.eraseOp(foreachThreadOp.getTerminator());
  Block *targetBlock = foreachThreadOp->getBlock();
  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
  Block &sourceBlock = foreachThreadOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // RAUW the loop indices with the mapped block ids, then erase the old op.
  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
    Value blockIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, blockIdx);
  }
  rewriter.eraseOp(foreachThreadOp);
  return DiagnosedSilenceableFailure::success();
}
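/// Walks `target` and reports the unique top-level scf.foreach_thread through
/// the output argument, failing if zero or several candidates are found.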
static DiagnosedSilenceableFailure
findTopLevelForeachThreadOp(Operation *target,
                            scf::ForeachThreadOp &topLevelForeachThreadOp,
                            TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
      return WalkResult::advance();
    if (topLevelForeachThreadOp)
      // A top-level scf.foreach_thread was already found, fail.
      return WalkResult::interrupt();
    topLevelForeachThreadOp = foreachThreadOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.foreach_thread";
  return DiagnosedSilenceableFailure::success();
}
static void generateGpuBlockIds(RewriterBase &rewriter,
                                scf::ForeachThreadOp foreachOp,
                                SmallVectorImpl<Value> &blockOps) {
  Location loc = foreachOp->getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(foreachOp);
  IndexType indexType = rewriter.getIndexType();
  blockOps = SmallVector<Value>{
      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
}
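/// Entry point of the MapForeachToBlocks transform op: maps the unique
/// top-level scf.foreach_thread of the target to GPU blocks, optionally
/// creating the surrounding gpu.launch first.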
DiagnosedSilenceableFailure
transform::MapForeachToBlocks::applyToOne(Operation *target,
                                          ApplyToEachResultList &results,
                                          transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  scf::ForeachThreadOp topLevelForeachThreadOp;
  DiagnosedSilenceableFailure diag = findTopLevelForeachThreadOp(
      target, topLevelForeachThreadOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  SmallVector<int64_t> gridDim;
  IRRewriter rewriter(getContext());

  if (getGenerateGpuLaunch()) {
    // Create a surrounding gpu.launch (with unit sizes for now) and clone the
    // top-level scf.foreach_thread into its body.
    diag = createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded()) {
      return diag;
    }
    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
    rewriter.eraseOp(topLevelForeachThreadOp);
    topLevelForeachThreadOp = cast<scf::ForeachThreadOp>(newForeachThreadOp);
  }
  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};

  diag = checkAttributeType(blockMappingAttributes,
                            topLevelForeachThreadOp.getMapping(), transformOp);
  if (diag.succeeded())
    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
        rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
        transformOp, blockMappingAttributes);

  if (diag.succeeded()) {
    diag = alterGpuLaunch(rewriter, gpuLaunch,
                          cast<TransformOpInterface>(getOperation()),
                          gridDim[0], gridDim[1], gridDim[2]);
  }

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
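/// Rewrites one thread-mapped scf.foreach_thread into gpu.thread_id based
/// indexing, guarding the body with an scf.if when the loop extent is smaller
/// than the launch blockDim and optionally inserting a trailing gpu.barrier.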
static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
    const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
    std::optional<TransformOpInterface> transformOp,
    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
  // Emit a silenceable error through the transform op when one is provided,
  // and a definite failure on the loop otherwise.
  auto failureHelper =
      [&](const Twine &message) -> DiagnosedSilenceableFailure {
    if (transformOp.has_value()) {
      return transformOp->emitSilenceableError() << message;
    }
    return emitDefiniteFailure(foreachThreadOp, message);
  };

  Location loc = foreachThreadOp->getLoc();
  if (foreachThreadOp.getNumResults() > 0)
    return failureHelper(
        "only bufferized scf.foreach_thread lowers to gpu.thread_id");
  if (foreachThreadOp.getNumThreads().size() > 3)
    return failureHelper(
        "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
        return !v.getDefiningOp<arith::ConstantIndexOp>();
      })) {
    return failureHelper("unsupported dynamic blockdim size");
  }
  if (!foreachThreadOp.getMapping().has_value())
    return failureHelper("mapping must be present");
  SmallVector<Attribute> threadMapping =
      llvm::to_vector(foreachThreadOp.getMapping()->getValue());

  // Complete the threadMapping to a full x/y/z mapping, padding missing
  // dimensions with a constant 1 thread count.
  SmallVector<Value> numThreads =
      llvm::to_vector(foreachThreadOp.getNumThreads());
  Value one;
  for (auto attr : threadMappingAttributes) {
    if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
        threadMapping.end()) {
      threadMapping.push_back(attr);
      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
      numThreads.push_back(one);
    }
  }
  // Sort the thread counts by their DeviceMappingAttrInterface mapping id.
  auto comparator = [&](DeviceMappingAttrInterface a,
                        DeviceMappingAttrInterface b) -> bool {
    return a.getMappingId() < b.getMappingId();
  };
  SmallVector<Value> blockDimValues =
      scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
                                                 comparator);
  SmallVector<int64_t> blockDims =
      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
        return v.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  // Create the gpu.thread_id ops that will replace the loop indices.
  IndexType indexType = rewriter.getIndexType();
  SmallVector<Value> threadOps{
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
  // A launch dimension of 1 implies the corresponding thread id is always 0.
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
    if (globalBlockDims[i] == 1)
      threadOps[i] = zero;
  }
  // Map each loop index to the thread id selected by its mapping attribute.
  IRMapping bvm;
  for (auto [blockIdx, blockDim] :
       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
    bvm.map(blockIdx,
            threadOps[blockDim.cast<DeviceMappingAttrInterface>()
                          .getMappingId()]);
  }
  // Build the predicate that deactivates the threads that fall outside of the
  // loop extent when the loop is smaller than the launch blockDim.
  Value predicate;
  for (auto [threadId, blockDim, globalBlockDim] :
       llvm::zip(threadOps, blockDims, globalBlockDims)) {
    if (blockDim > globalBlockDim) {
      return failureHelper(
          "The requested GPU threads are fewer than the number of loop trip "
          "counts. Try to tile scf.foreach_thread before mapping or set "
          "small blockDim.");
    }
    if (blockDim == globalBlockDim)
      continue;
    Value blockIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
    Value tmpPredicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ult, threadId, blockIdx);
    predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                           tmpPredicate)
                          : tmpPredicate;
  }

  // Move the body of foreachThreadOp. Erase the terminator first, it will not
  // be used since we are on buffers.
  rewriter.eraseOp(foreachThreadOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // If predicated, move the body at the beginning of the scf.if block.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    targetBlock = foreachThreadOp->getBlock();
    insertionPoint = Block::iterator(foreachThreadOp);
  }
  Block &sourceBlock = foreachThreadOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // RAUW the loop indices with the mapped thread ids.
  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  rewriter.eraseOp(foreachThreadOp);
  return DiagnosedSilenceableFailure::success();
}
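/// Walks all scf.foreach_thread ops nested under `target` and rewrites each
/// one that uses the given thread mapping attributes to gpu.thread_id based
/// code.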
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
    RewriterBase &rewriter, Operation *target,
    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
    std::optional<TransformOpInterface> transformOp,
    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
    diag = checkAttributeType(threadMappingAttributes,
                              foreachThreadOp.getMapping(), transformOp);
    if (diag.succeeded()) {
      rewriter.setInsertionPoint(foreachThreadOp);
      diag = rewriteOneForeachThreadToGpuThreads(
          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
          threadMappingAttributes);
    }
  });
  return diag;
}
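/// Entry point of the MapNestedForeachToThreads transform op: distributes the
/// nested scf.foreach_thread ops of a gpu.launch onto threads and updates the
/// launch blockDim accordingly.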
DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
    Operation *target, ApplyToEachResultList &results,
    transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not gpu.launch";

  SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
  blockDim.resize(/*size=*/3, /*value=*/1);

  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDim[0], blockDim[1], blockDim[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large";
    return diag;
  }
  MLIRContext *ctx = getContext();
  IRRewriter rewriter(ctx);
  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
      GPUThreadMappingAttr::get(ctx, Threads::DimX),
      GPUThreadMappingAttr::get(ctx, Threads::DimY),
      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};

  diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
      threadMappingAttributes);

  if (diag.succeeded()) {
    diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                          std::nullopt, std::nullopt, blockDim[0], blockDim[1],
                          blockDim[2]);
  }

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
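/// Transform dialect extension that registers the GPU mapping transform ops
/// and declares the dialects they may generate.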
namespace {
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  GPUTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace
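// Registration entry point; the enclosing namespace is assumed to match the
// declaration in GPUTransformOps.h.
void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}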
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"