9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
15 #include "llvm/Support/Casting.h"
23 class SymbolTableCollection;
39 template <
typename AlgebraicOp>
42 auto getEndomorphismOpOperand = [](
Operation *op) {
43 auto allReduceOp = llvm::cast<AllReduceOp>(op);
44 return &allReduceOp.getInputMutable();
46 auto getEndomorphismOpResult = [](
Operation *op) {
47 auto allReduceOp = llvm::cast<AllReduceOp>(op);
48 return allReduceOp->getResult(0);
50 auto getAlgebraicOpOperands = [](
Operation *op,
52 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53 std::transform(algebraicOp->getOpOperands().begin(),
54 algebraicOp->getOpOperands().end(),
55 std::back_inserter(operands),
56 [](
OpOperand &operand) { return &operand; });
58 auto getAlgebraicOpResult = [](
Operation *op) {
59 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
60 return algebraicOp->getResult(0);
62 auto isEndomorphismOp = [reduction](
Operation *op,
63 std::optional<Operation *> referenceOp) {
64 auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
66 allReduceOp.getInput().getType().getElementType() !=
67 allReduceOp.getResult().getType().getElementType() ||
68 allReduceOp.getReduction() != reduction) {
77 if (!allReduceOp->hasOneUse()) {
85 auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86 return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
87 allReduceOp.getInput().getType().getElementType() ==
88 refAllReduceOp.getInput().getType().getElementType();
91 return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
95 std::decay_t<decltype(getEndomorphismOpOperand)>,
96 std::decay_t<decltype(getEndomorphismOpResult)>,
97 std::decay_t<decltype(getAlgebraicOpOperands)>,
98 std::decay_t<decltype(getAlgebraicOpResult)>,
99 std::decay_t<decltype(isEndomorphismOp)>,
100 std::decay_t<decltype(isAlgebraicOp)>>;
101 patterns.
add(std::make_unique<ConcreteEndomorphismSimplification>(
102 std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
103 std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
104 std::move(isEndomorphismOp), std::move(isAlgebraicOp),
105 AlgebraicOp::getOperationName(), 1, patterns.
getContext()));
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
mesh::ReductionKind ReductionKind
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.