MLIR  22.0.0git
GlobalIdRewriter.cpp
Go to the documentation of this file.
1 //===- GlobalIdRewriter.cpp - Implementation of GlobalId rewriting -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the gpu.global_id op for
10 // architectures where global_id.x = threadId.x + blockId.x * blockDim.x.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
24 
25  LogicalResult matchAndRewrite(gpu::GlobalIdOp op,
26  PatternRewriter &rewriter) const override {
27  Location loc = op.getLoc();
28  auto dim = op.getDimension();
29  auto blockId = gpu::BlockIdOp::create(rewriter, loc, dim);
30  auto blockDim = gpu::BlockDimOp::create(rewriter, loc, dim);
31  // Compute blockId.x * blockDim.x
32  auto tmp = index::MulOp::create(rewriter, op.getLoc(), blockId, blockDim);
33  auto threadId = gpu::ThreadIdOp::create(rewriter, loc, dim);
34  // Compute threadId.x + blockId.x * blockDim.x
35  rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, tmp);
36  return success();
37  }
38 };
39 } // namespace
40 
42  patterns.add<GpuGlobalIdRewriter>(patterns.getContext());
43 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Include the generated interface declarations.
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns)
Collect a set of patterns that rewrite GlobalIdOp within the GPU dialect.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314