MLIR  19.0.0git
GlobalIdRewriter.cpp
Go to the documentation of this file.
1 //===- GlobalIdRewriter.cpp - Implementation of GlobalId rewriting -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements in-dialect rewriting of the global_id op for archs
10 // where global_id.x = threadId.x + blockId.x * blockDim.x
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
26 
27  LogicalResult matchAndRewrite(gpu::GlobalIdOp op,
28  PatternRewriter &rewriter) const override {
29  auto loc = op.getLoc();
30  auto dim = op.getDimension();
31  auto blockId = rewriter.create<gpu::BlockIdOp>(loc, dim);
32  auto blockDim = rewriter.create<gpu::BlockDimOp>(loc, dim);
33  // Compute blockId.x * blockDim.x
34  auto tmp = rewriter.create<index::MulOp>(op.getLoc(), blockId, blockDim);
35  auto threadId = rewriter.create<gpu::ThreadIdOp>(loc, dim);
36  // Compute threadId.x + blockId.x * blockDim.x
37  rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, tmp);
38  return success();
39  }
40 };
41 } // namespace
42 
44  patterns.add<GpuGlobalIdRewriter>(patterns.getContext());
45 }
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Include the generated interface declarations.
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358