MLIR 22.0.0git
SubgroupIdRewriter.cpp
Go to the documentation of this file.
1//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
10// where:
11// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
12//
13//===----------------------------------------------------------------------===//
14
18#include "mlir/IR/Builders.h"
20
21using namespace mlir;
22
23namespace {
24struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
25 using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
28 PatternRewriter &rewriter) const override {
29 // Calculation of the thread's subgroup identifier.
30 //
31 // The process involves mapping the thread's 3D identifier within its
32 // block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
33 // This linearization assumes a layout where the x-dimension (w_dim.x)
34 // varies most rapidly (i.e., it is the innermost dimension).
35 //
36 // The formula for the linearized thread index is:
37 // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
38 //
39 // Subsequently, the range of linearized indices [0, N_threads-1] is
40 // divided into consecutive, non-overlapping segments, each representing
41 // a subgroup of size 'subgroup_size'.
42 //
43 // Example Partitioning (N = subgroup_size):
44 // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
45 // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
46 //
47 // The subgroup identifier is obtained via integer division of the
48 // linearized thread index by the predefined 'subgroup_size'.
49 //
50 // subgroup_id = floor( L / subgroup_size )
51 // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
52 // subgroup_size
53
54 Location loc = op->getLoc();
55 Type indexType = rewriter.getIndexType();
56
57 Value dimX = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x);
58 Value dimY = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::y);
59 Value tidX = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
60 Value tidY = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::y);
61 Value tidZ = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::z);
62
63 Value dimYxIdZ =
64 arith::MulIOp::create(rewriter, loc, indexType, dimY, tidZ);
65 Value dimYxIdZPlusIdY =
66 arith::AddIOp::create(rewriter, loc, indexType, dimYxIdZ, tidY);
67 Value dimYxIdZPlusIdYTimesDimX =
68 arith::MulIOp::create(rewriter, loc, indexType, dimX, dimYxIdZPlusIdY);
69 Value IdXPlusDimYxIdZPlusIdYTimesDimX = arith::AddIOp::create(
70 rewriter, loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX);
71 Value subgroupSize = gpu::SubgroupSizeOp::create(
72 rewriter, loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
73 Value subgroupIdOp =
74 arith::DivUIOp::create(rewriter, loc, indexType,
75 IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
76 rewriter.replaceOp(op, {subgroupIdOp});
77 return success();
78 }
79};
80
81} // namespace
82
84 patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
85}
return success()
IndexType getIndexType()
Definition Builders.cpp:51
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite SubgroupIdOp op within the GPU dialect.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...