MLIR  21.0.0git
MPIOps.cpp
Go to the documentation of this file.
1 //===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/PatternMatch.h"
15 
16 using namespace mlir;
17 using namespace mlir::mpi;
18 
19 namespace {
20 
21 // If input memref has dynamic shape and is a cast and if the cast's input has
22 // static shape, fold the cast's static input into the given operation.
23 template <typename OpT>
24 struct FoldCast final : public mlir::OpRewritePattern<OpT> {
26 
27  LogicalResult matchAndRewrite(OpT op,
28  mlir::PatternRewriter &b) const override {
29  auto mRef = op.getRef();
30  if (mRef.getType().hasStaticShape()) {
31  return mlir::failure();
32  }
33  auto defOp = mRef.getDefiningOp();
34  if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
35  return mlir::failure();
36  }
37  auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
38  if (!src.getType().hasStaticShape()) {
39  return mlir::failure();
40  }
41  op.getRefMutable().assign(src);
42  return mlir::success();
43  }
44 };
45 
46 struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
48 
49  LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
50  mlir::PatternRewriter &b) const override {
51  auto comm = op.getComm();
52  if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
53  return mlir::failure();
54 
55  // Try to get DLTI attribute for MPI:comm_world_rank
56  // If found, set worldRank to the value of the attribute.
57  auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
58  if (failed(dltiAttr))
59  return mlir::failure();
60  if (!isa<IntegerAttr>(dltiAttr.value()))
61  return op->emitError()
62  << "Expected an integer attribute for MPI:comm_world_rank";
63  Value res = b.create<arith::ConstantIndexOp>(
64  op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
65  if (Value retVal = op.getRetval())
66  b.replaceOp(op, {retVal, res});
67  else
68  b.replaceOp(op, res);
69  return mlir::success();
70  }
71 };
72 
73 } // namespace
74 
75 void mlir::mpi::SendOp::getCanonicalizationPatterns(
76  mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
77  results.add<FoldCast<mlir::mpi::SendOp>>(context);
78 }
79 
80 void mlir::mpi::RecvOp::getCanonicalizationPatterns(
81  mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
82  results.add<FoldCast<mlir::mpi::RecvOp>>(context);
83 }
84 
85 void mlir::mpi::ISendOp::getCanonicalizationPatterns(
86  mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
87  results.add<FoldCast<mlir::mpi::ISendOp>>(context);
88 }
89 
90 void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
91  mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
92  results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
93 }
94 
95 void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
96  mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
97  results.add<FoldRank>(context);
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // TableGen'd op method definitions
102 //===----------------------------------------------------------------------===//
103 
104 #define GET_OP_CLASSES
105 #include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition: DLTI.cpp:539
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314