42 auto getEndomorphismOpOperand = [](
Operation *op) {
43 auto allReduceOp = llvm::cast<AllReduceOp>(op);
44 return &allReduceOp.getInputMutable();
46 auto getEndomorphismOpResult = [](
Operation *op) {
47 auto allReduceOp = llvm::cast<AllReduceOp>(op);
48 return allReduceOp->getResult(0);
50 auto getAlgebraicOpOperands = [](
Operation *op,
52 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
53 llvm::append_range(operands,
54 llvm::make_pointer_range(algebraicOp->getOpOperands()));
56 auto getAlgebraicOpResult = [](
Operation *op) {
57 auto algebraicOp = llvm::cast<AlgebraicOp>(op);
58 return algebraicOp->getResult(0);
60 auto isEndomorphismOp = [reduction](
Operation *op,
61 std::optional<Operation *> referenceOp) {
62 auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65 auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
66 auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
67 if (inType.getElementType() != outType.getElementType() ||
68 allReduceOp.getReduction() != reduction) {
77 if (!allReduceOp->hasOneUse()) {
85 auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86 auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
87 return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
88 inType.getElementType() == refType.getElementType();
90 auto isAlgebraicOp = [](
Operation *op) {
return isa<AlgebraicOp>(op); };
// Template arguments for the `ConcreteEndomorphismSimplification` type that
// is instantiated in the `patterns.add` call below. Each argument is the
// decayed closure type of one callback lambda defined above, listed in the
// same order the callbacks are passed to its constructor.
// NOTE(review): the opening line of this alias declaration (the
// `using ConcreteEndomorphismSimplification = ...<` head naming the pattern
// template) is not present here. Each line also carries a stray leading
// number, so this fragment does not compile as written. TODO: restore the
// alias head and drop the numeric prefixes.
93 std::decay_t<
decltype(getEndomorphismOpOperand)>,
94 std::decay_t<
decltype(getEndomorphismOpResult)>,
95 std::decay_t<
decltype(getAlgebraicOpOperands)>,
96 std::decay_t<
decltype(getAlgebraicOpResult)>,
97 std::decay_t<
decltype(isEndomorphismOp)>,
98 std::decay_t<
decltype(isAlgebraicOp)>>;
99 patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
100 std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
101 std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
102 std::move(isEndomorphismOp), std::move(isAlgebraicOp),
103 AlgebraicOp::getOperationName(), 1,
patterns.getContext()));