MLIR 22.0.0git
MPIOps.cpp
Go to the documentation of this file.
1//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
15
16using namespace mlir;
17using namespace mlir::mpi;
18
19namespace {
20
21// If input memref has dynamic shape and is a cast and if the cast's input has
22// static shape, fold the cast's static input into the given operation.
23template <typename OpT>
24struct FoldCast final : public mlir::OpRewritePattern<OpT> {
25 using mlir::OpRewritePattern<OpT>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(OpT op,
28 mlir::PatternRewriter &b) const override {
29 auto mRef = op.getRef();
30 if (mRef.getType().hasStaticShape()) {
31 return mlir::failure();
32 }
33 auto defOp = mRef.getDefiningOp();
34 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
35 return mlir::failure();
36 }
37 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
38 if (!src.getType().hasStaticShape()) {
39 return mlir::failure();
40 }
41 op.getRefMutable().assign(src);
42 return mlir::success();
43 }
44};
45
46struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
47 using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
48
49 LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
50 mlir::PatternRewriter &b) const override {
51 auto comm = op.getComm();
52 if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
53 return mlir::failure();
54
55 // Try to get DLTI attribute for MPI:comm_world_rank
56 // If found, set worldRank to the value of the attribute.
57 auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
58 if (failed(dltiAttr))
59 return mlir::failure();
60 if (!isa<IntegerAttr>(dltiAttr.value()))
61 return op->emitError()
62 << "Expected an integer attribute for MPI:comm_world_rank";
64 b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
65 if (Value retVal = op.getRetval())
66 b.replaceOp(op, {retVal, res});
67 else
68 b.replaceOp(op, res);
69 return mlir::success();
70 }
71};
72
73} // namespace
74
75void mlir::mpi::SendOp::getCanonicalizationPatterns(
76 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
77 results.add<FoldCast<mlir::mpi::SendOp>>(context);
78}
79
80void mlir::mpi::RecvOp::getCanonicalizationPatterns(
81 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
82 results.add<FoldCast<mlir::mpi::RecvOp>>(context);
83}
84
85void mlir::mpi::ISendOp::getCanonicalizationPatterns(
86 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
87 results.add<FoldCast<mlir::mpi::ISendOp>>(context);
88}
89
90void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
91 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
92 results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
93}
94
95void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
96 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
97 results.add<FoldRank>(context);
98}
99
100//===----------------------------------------------------------------------===//
101// TableGen'd op method definitions
102//===----------------------------------------------------------------------===//
103
104#define GET_OP_CLASSES
105#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition DLTI.cpp:537
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...