MLIR  21.0.0git
SubgroupIdRewriter.cpp
Go to the documentation of this file.
1 //===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
10 // where:
11 // subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
12 //
13 //===----------------------------------------------------------------------===//
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
27 
28  LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
29  PatternRewriter &rewriter) const override {
30  // Calculation of the thread's subgroup identifier.
31  //
32  // The process involves mapping the thread's 3D identifier within its
33  // block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
34  // This linearization assumes a layout where the x-dimension (w_dim.x)
35  // varies most rapidly (i.e., it is the innermost dimension).
36  //
37  // The formula for the linearized thread index is:
38  // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
39  //
40  // Subsequently, the range of linearized indices [0, N_threads-1] is
41  // divided into consecutive, non-overlapping segments, each representing
42  // a subgroup of size 'subgroup_size'.
43  //
44  // Example Partitioning (N = subgroup_size):
45  // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
46  // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
47  //
48  // The subgroup identifier is obtained via integer division of the
49  // linearized thread index by the predefined 'subgroup_size'.
50  //
51  // subgroup_id = floor( L / subgroup_size )
52  // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
53  // subgroup_size
54 
55  Location loc = op->getLoc();
56  Type indexType = rewriter.getIndexType();
57 
58  Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
59  Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
60  Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
61  Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
62  Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
63 
64  Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, indexType, dimY, tidZ);
65  Value dimYxIdZPlusIdY =
66  rewriter.create<arith::AddIOp>(loc, indexType, dimYxIdZ, tidY);
67  Value dimYxIdZPlusIdYTimesDimX =
68  rewriter.create<arith::MulIOp>(loc, indexType, dimX, dimYxIdZPlusIdY);
69  Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create<arith::AddIOp>(
70  loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX);
71  Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
72  loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
73  Value subgroupIdOp = rewriter.create<arith::DivUIOp>(
74  loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
75  rewriter.replaceOp(op, {subgroupIdOp});
76  return success();
77  }
78 };
79 
80 } // namespace
81 
83  patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
84 }
constexpr unsigned subgroupSize
HW dependent constants.
IndexType getIndexType()
Definition: Builders.cpp:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite SubgroupIdOp op within the GPU dialect.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314