24 LogicalResult matchAndRewrite(math::SinOp sinOp,
25 PatternRewriter &rewriter)
const override {
26 Value operand = sinOp.getOperand();
27 mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
29 math::CosOp cosOp =
nullptr;
30 for (
auto op : sinOp->getBlock()->getOps<math::CosOp>())
31 if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
39 Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
40 : cosOp.getOperation();
43 Type elemType = sinOp.getType();
44 auto sincos = math::SincosOp::create(rewriter, firstOp->
getLoc(),
46 sinOp.getFastmathAttr());
48 rewriter.
replaceOp(sinOp, sincos.getSin());
49 rewriter.
replaceOp(cosOp, sincos.getCos());
57#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
58#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
63struct MathSincosFusionPass final
64 : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
65 using MathSincosFusionPassBase::MathSincosFusionPassBase;
67 void runOnOperation()
override {
71 GreedyRewriteConfig config;
74 return signalPassFailure();
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Location getLoc()
The source location the operation was defined or derived from.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...