// FoldCast<OpT>: canonicalization pattern that is registered (further down in
// this file) for mpi::SendOp, mpi::RecvOp, mpi::ISendOp and mpi::IRecvOp.
// It rewires the op's `ref` operand to bypass a memref::CastOp when all of
// the following hold:
//   * the current `ref` operand's type does NOT have a static shape,
//   * `ref` is produced by a memref::CastOp, and
//   * that cast's source value DOES have a static shape.
// Any other input makes the pattern report failure (no match).
//
// NOTE(review): this excerpt is incomplete. The embedded source numbering
// skips lines 24-26, 28, 32, 36 and 40, so the struct header, the rewriter
// parameter of matchAndRewrite and the closing braces are not shown here.
23 template <
typename OpT>
27 LogicalResult matchAndRewrite(OpT op,
// Nothing to fold if `ref` is already statically shaped.
29 auto mRef = op.getRef();
30 if (mRef.getType().hasStaticShape()) {
31 return mlir::failure();
// `ref` must have a defining op (getDefiningOp() may return null, which is
// checked explicitly), and that op must be a memref::CastOp.
33 auto defOp = mRef.getDefiningOp();
34 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
35 return mlir::failure();
// Only fold when the value feeding the cast has a fully static shape.
37 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
38 if (!src.getType().hasStaticShape()) {
39 return mlir::failure();
// Replace the `ref` operand in place with the cast's source.
// NOTE(review): the operand is mutated directly on `op`, not through the
// pattern rewriter. MLIR's rewrite drivers expect in-place modifications to
// be announced to the rewriter (e.g. wrapped in modifyOpInPlace). Confirm
// against the rewriter parameter declared on the omitted line 28.
41 op.getRefMutable().assign(src);
42 return mlir::success();
// FoldRank::matchAndRewrite: canonicalization pattern that is registered
// (further down in this file) for mpi::CommRankOp. It only proceeds when:
//   * the op's communicator is produced by mpi::CommWorldOp, and
//   * a DLTI query for "MPI:comm_world_rank" at `op` yields an IntegerAttr.
// The integer value of that attribute is then passed, together with op's
// location, to a call whose beginning is not visible here (line 64).
//
// NOTE(review): this excerpt is incomplete. The embedded numbering skips
// lines 50, 54-56, 58, 63 and 66-68, so the following are not shown: the
// rewriter parameter, the guard of the second `return mlir::failure()`, the
// start of the value-building call, and the bodies of the final `if`.
49 LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
// The fold only applies when the communicator comes from mpi::CommWorldOp.
51 auto comm = op.getComm();
52 if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
53 return mlir::failure();
// Query DLTI for "MPI:comm_world_rank" starting at `op`. The trailing `false`
// is the `emitError` argument of dlti::query, presumably so that a missing
// entry is not reported as an error. Confirm.
57 auto dltiAttr =
dlti::query(op, {
"MPI:comm_world_rank"},
false);
// NOTE(review): the condition for this early return is on omitted line 58.
// It is presumably `failed(dltiAttr)`, since dlti::query returns
// FailureOr<Attribute>. Confirm.
59 return mlir::failure();
// The DLTI entry must be an IntegerAttr. Otherwise an error diagnostic is
// emitted on `op` and returned as this pattern's LogicalResult.
// NOTE(review): canonicalization patterns usually signal a non-match via
// notifyMatchFailure rather than emitting a user-facing error. Confirm that
// erroring out here is intended.
60 if (!isa<IntegerAttr>(dltiAttr.value()))
61 return op->emitError()
62 <<
"Expected an integer attribute for MPI:comm_world_rank";
// Tail of a call begun on omitted line 63. It receives op's location and the
// attribute's integer value, presumably to materialize the rank as an SSA
// constant. Confirm.
64 op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
// Branches on whether `op` has a `retval` result. Both branch bodies
// (lines 66-68) are omitted from this excerpt.
65 if (
Value retVal = op.getRetval())
69 return mlir::success();
// Canonicalization patterns for mpi::SendOp. This registers
// FoldCast<mlir::mpi::SendOp>, which bypasses a memref::CastOp feeding the
// op's `ref` operand when the cast's source is statically shaped.
// NOTE(review): the parameter list declaring `results` and `context`
// (line 76) and the closing brace (line 78) are omitted from this excerpt.
75 void mlir::mpi::SendOp::getCanonicalizationPatterns(
77 results.
add<FoldCast<mlir::mpi::SendOp>>(context);
// Canonicalization patterns for mpi::RecvOp. This registers
// FoldCast<mlir::mpi::RecvOp>, which bypasses a memref::CastOp feeding the
// op's `ref` operand when the cast's source is statically shaped.
// NOTE(review): the parameter list declaring `results` and `context`
// (line 81) and the closing brace (line 83) are omitted from this excerpt.
80 void mlir::mpi::RecvOp::getCanonicalizationPatterns(
82 results.
add<FoldCast<mlir::mpi::RecvOp>>(context);
// Canonicalization patterns for mpi::ISendOp. This registers
// FoldCast<mlir::mpi::ISendOp>, which bypasses a memref::CastOp feeding the
// op's `ref` operand when the cast's source is statically shaped.
// NOTE(review): the parameter list declaring `results` and `context`
// (line 86) and the closing brace (line 88) are omitted from this excerpt.
85 void mlir::mpi::ISendOp::getCanonicalizationPatterns(
87 results.
add<FoldCast<mlir::mpi::ISendOp>>(context);
// Canonicalization patterns for mpi::IRecvOp. This registers
// FoldCast<mlir::mpi::IRecvOp>, which bypasses a memref::CastOp feeding the
// op's `ref` operand when the cast's source is statically shaped.
// NOTE(review): the parameter list declaring `results` and `context`
// (line 91) and the closing brace (line 93) are omitted from this excerpt.
90 void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
92 results.
add<FoldCast<mlir::mpi::IRecvOp>>(context);
// Canonicalization patterns for mpi::CommRankOp. This registers FoldRank,
// which uses a DLTI "MPI:comm_world_rank" integer entry for ranks queried on
// the mpi::CommWorldOp communicator.
// NOTE(review): the parameter list declaring `results` and `context`
// (line 96) and the closing brace (line 98) are omitted from this excerpt.
95 void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
97 results.
add<FoldRank>(context);
104 #define GET_OP_CLASSES
105 #include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing attributes.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.