23template <
typename OpT>
25 using mlir::OpRewritePattern<OpT>::OpRewritePattern;
27 LogicalResult matchAndRewrite(OpT op,
28 mlir::PatternRewriter &
b)
const override {
29 auto mRef = op.getRef();
30 if (mRef.getType().hasStaticShape()) {
31 return mlir::failure();
33 auto defOp = mRef.getDefiningOp();
34 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
35 return mlir::failure();
37 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
38 if (!src.getType().hasStaticShape()) {
39 return mlir::failure();
41 op.getRefMutable().assign(src);
42 return mlir::success();
47 using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
49 LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
50 mlir::PatternRewriter &
b)
const override {
51 auto comm = op.getComm();
52 if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
53 return mlir::failure();
57 auto dltiAttr =
dlti::query(op, {
"MPI:comm_world_rank"},
false);
59 return mlir::failure();
60 if (!isa<IntegerAttr>(dltiAttr.value()))
61 return op->emitError()
62 <<
"Expected an integer attribute for MPI:comm_world_rank";
64 b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
65 if (Value retVal = op.getRetval())
66 b.replaceOp(op, {retVal, res});
69 return mlir::success();
75void mlir::mpi::SendOp::getCanonicalizationPatterns(
77 results.
add<FoldCast<mlir::mpi::SendOp>>(context);
80void mlir::mpi::RecvOp::getCanonicalizationPatterns(
82 results.
add<FoldCast<mlir::mpi::RecvOp>>(context);
85void mlir::mpi::ISendOp::getCanonicalizationPatterns(
87 results.
add<FoldCast<mlir::mpi::ISendOp>>(context);
90void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
92 results.
add<FoldCast<mlir::mpi::IRecvOp>>(context);
95void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
97 results.
add<FoldRank>(context);
104#define GET_OP_CLASSES
105#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
MLIRContext is the top-level object for a collection of MLIR operations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing attributes, starting from the attribute obtained from op.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.