MLIR 23.0.0git
MPIOps.cpp
Go to the documentation of this file.
1//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
14
15using namespace mlir;
16using namespace mlir::mpi;
17
18//===----------------------------------------------------------------------===//
19// Verifiers
20//===----------------------------------------------------------------------===//
21
22LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() {
23 if (getSendbuf().getType().getElementType() !=
24 getRecvbuf().getType().getElementType())
25 return emitOpError("sendbuf and recvbuf must have the same element type");
26 return success();
27}
28
29namespace {
30
31//===----------------------------------------------------------------------===//
32// Canonicalization patterns
33//===----------------------------------------------------------------------===//
34
35// If input memref has dynamic shape and is a cast and if the cast's input has
36// static shape, fold the cast's static input into the given operation.
37template <typename OpT>
38struct FoldCast final : public mlir::OpRewritePattern<OpT> {
39 using mlir::OpRewritePattern<OpT>::OpRewritePattern;
40
41 LogicalResult matchAndRewrite(OpT op,
42 mlir::PatternRewriter &b) const override {
43 auto mRef = op.getRef();
44 if (mRef.getType().hasStaticShape()) {
45 return mlir::failure();
46 }
47 auto defOp = mRef.getDefiningOp();
48 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
49 return mlir::failure();
50 }
51 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
52 if (!src.getType().hasStaticShape()) {
53 return mlir::failure();
54 }
55 op.getRefMutable().assign(src);
56 return mlir::success();
57 }
58};
59
60struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
61 using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
62 LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
63 mlir::PatternRewriter &b) const override {
64 return FoldToDLTIConst(op, "MPI:comm_world_rank", b);
65 }
66};
67
68struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
69 using mlir::OpRewritePattern<mlir::mpi::CommSizeOp>::OpRewritePattern;
70
71 LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp op,
72 mlir::PatternRewriter &b) const override {
73 return FoldToDLTIConst(op, "MPI:comm_world_size", b);
74 }
75};
76} // namespace
77
78void mlir::mpi::SendOp::getCanonicalizationPatterns(
79 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
80 results.add<FoldCast<mlir::mpi::SendOp>>(context);
81}
82
83void mlir::mpi::RecvOp::getCanonicalizationPatterns(
84 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
85 results.add<FoldCast<mlir::mpi::RecvOp>>(context);
86}
87
88void mlir::mpi::ISendOp::getCanonicalizationPatterns(
89 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
90 results.add<FoldCast<mlir::mpi::ISendOp>>(context);
91}
92
93void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
94 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
95 results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
96}
97
98void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
99 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
100 results.add<FoldRank>(context);
101}
102
103void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
104 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
105 results.add<FoldSize>(context);
106}
107
108//===----------------------------------------------------------------------===//
109// TableGen'd op method definitions
110//===----------------------------------------------------------------------===//
111
112#define GET_OP_CLASSES
113#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
LogicalResult FoldToDLTIConst(OpT op, const char *key, mlir::PatternRewriter &b)
Definition Utils.h:19
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...