9 #ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
10 #define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
15 #include "llvm/Support/Casting.h"
23 class SymbolTableCollection;
39 template <
typename AlgebraicOp>
42 auto getEndomorphismOpOperand = [](
Operation *op) {
43 auto allReduceOp = llvm::cast<AllReduceOp>(op);
44 return &allReduceOp.getInputMutable();
46 auto getEndomorphismOpResult = [](
Operation *op) {
47 auto allReduceOp = llvm::cast<AllReduceOp>(op);
48 return allReduceOp->getResult(0);
50 auto getAlgebraicOpOperands = [](
Operation *op,
52 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53 std::transform(algebraicOp->getOpOperands().begin(),
54 algebraicOp->getOpOperands().end(),
55 std::back_inserter(operands),
56 [](
OpOperand &operand) { return &operand; });
58 auto getAlgebraicOpResult = [](
Operation *op) {
59 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
60 return algebraicOp->getResult(0);
62 auto isEndomorphismOp = [reduction](
Operation *op,
63 std::optional<Operation *> referenceOp) {
64 auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
67 auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
68 auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
69 if (inType.getElementType() != outType.getElementType() ||
70 allReduceOp.getReduction() != reduction) {
79 if (!allReduceOp->hasOneUse()) {
87 auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
88 auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
89 return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
90 inType.getElementType() == refType.getElementType();
92 auto isAlgebraicOp = [](
Operation *op) {
return isa<AlgebraicOp>(op); };
95 std::decay_t<decltype(getEndomorphismOpOperand)>,
96 std::decay_t<decltype(getEndomorphismOpResult)>,
97 std::decay_t<decltype(getAlgebraicOpOperands)>,
98 std::decay_t<decltype(getAlgebraicOpResult)>,
99 std::decay_t<decltype(isEndomorphismOp)>,
100 std::decay_t<decltype(isAlgebraicOp)>>;
101 patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
102 std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
103 std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
104 std::move(isEndomorphismOp), std::move(isAlgebraicOp),
105 AlgebraicOp::getOperationName(), 1,
patterns.getContext()));
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
shard::ReductionKind ReductionKind
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllReduceEndomorphismSimplificationPatterns(RewritePatternSet &patterns, ReductionKind reduction)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns