26 using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
28 LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
29 PatternRewriter &rewriter)
const override {
55 Location loc = op->getLoc();
58 auto asMaybeIndexAttr = [&](std::optional<uint32_t> bound) -> IntegerAttr {
61 return IntegerAttr::get(
62 indexType,
static_cast<int64_t
>(
static_cast<uint64_t
>(*bound)));
65 IntegerAttr maybeKnownDimX =
67 op, gpu::DimensionKind::Block, gpu::Dimension::x));
68 IntegerAttr maybeKnownDimY =
70 op, gpu::DimensionKind::Block, gpu::Dimension::y));
71 IntegerAttr maybeKnownDimZ =
73 op, gpu::DimensionKind::Block, gpu::Dimension::z));
77 dimX = arith::ConstantOp::create(rewriter, loc, maybeKnownDimX);
79 dimX = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x);
81 dimY = arith::ConstantOp::create(rewriter, loc, maybeKnownDimY);
83 dimY = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::y);
85 Value tidX = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x,
87 Value tidY = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::y,
89 Value tidZ = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::z,
96 arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw;
98 arith::MulIOp::create(rewriter, loc, indexType, dimY, tidZ, flags);
99 Value dimYxIdZPlusIdY =
100 arith::AddIOp::create(rewriter, loc, indexType, dimYxIdZ, tidY, flags);
101 Value dimYxIdZPlusIdYTimesDimX = arith::MulIOp::create(
102 rewriter, loc, indexType, dimX, dimYxIdZPlusIdY, flags);
103 Value idXPlusDimYxIdZPlusIdYTimesDimX = arith::AddIOp::create(
104 rewriter, loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX, flags);
105 Value subgroupSize = gpu::SubgroupSizeOp::create(
108 arith::DivUIOp::create(rewriter, loc, indexType,
109 idXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::optional< uint32_t > getKnownDimensionSizeAround(Operation *op, DimensionKind kind, Dimension dim)
Retrieve the constant bounds for a given dimension and dimension kind from the context surrounding op...
Include the generated interface declarations.
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite SubgroupIdOp op within the GPU dialect.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...