MLIR 23.0.0git
MPIOps.cpp
Go to the documentation of this file.
1//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
14
15using namespace mlir;
16using namespace mlir::mpi;
17
18namespace {
19
20// If input memref has dynamic shape and is a cast and if the cast's input has
21// static shape, fold the cast's static input into the given operation.
22template <typename OpT>
23struct FoldCast final : public mlir::OpRewritePattern<OpT> {
24 using mlir::OpRewritePattern<OpT>::OpRewritePattern;
25
26 LogicalResult matchAndRewrite(OpT op,
27 mlir::PatternRewriter &b) const override {
28 auto mRef = op.getRef();
29 if (mRef.getType().hasStaticShape()) {
30 return mlir::failure();
31 }
32 auto defOp = mRef.getDefiningOp();
33 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
34 return mlir::failure();
35 }
36 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
37 if (!src.getType().hasStaticShape()) {
38 return mlir::failure();
39 }
40 op.getRefMutable().assign(src);
41 return mlir::success();
42 }
43};
44
45struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
46 using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
47 LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
48 mlir::PatternRewriter &b) const override {
49 return FoldToDLTIConst(op, "MPI:comm_world_rank", b);
50 }
51};
52
53struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
54 using mlir::OpRewritePattern<mlir::mpi::CommSizeOp>::OpRewritePattern;
55
56 LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp op,
57 mlir::PatternRewriter &b) const override {
58 return FoldToDLTIConst(op, "MPI:comm_world_size", b);
59 }
60};
61} // namespace
62
63void mlir::mpi::SendOp::getCanonicalizationPatterns(
64 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
65 results.add<FoldCast<mlir::mpi::SendOp>>(context);
66}
67
68void mlir::mpi::RecvOp::getCanonicalizationPatterns(
69 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
70 results.add<FoldCast<mlir::mpi::RecvOp>>(context);
71}
72
73void mlir::mpi::ISendOp::getCanonicalizationPatterns(
74 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
75 results.add<FoldCast<mlir::mpi::ISendOp>>(context);
76}
77
78void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
79 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
80 results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
81}
82
83void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
84 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
85 results.add<FoldRank>(context);
86}
87
88void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
89 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
90 results.add<FoldSize>(context);
91}
92
93//===----------------------------------------------------------------------===//
94// TableGen'd op method definitions
95//===----------------------------------------------------------------------===//
96
97#define GET_OP_CLASSES
98#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
LogicalResult FoldToDLTIConst(OpT op, const char *key, mlir::PatternRewriter &b)
Definition Utils.h:19
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...