MLIR  22.0.0git
SincosFusion.cpp
Go to the documentation of this file.
1 //===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/PatternMatch.h"
13 
14 using namespace mlir;
15 using namespace mlir::math;
16 
17 namespace {
18 
19 /// Fuse a math.sin and math.cos in the same block that use the same operand and
20 /// have identical fastmath flags into a single math.sincos.
21 struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
22  using Base::Base;
23 
24  LogicalResult matchAndRewrite(math::SinOp sinOp,
25  PatternRewriter &rewriter) const override {
26  Value operand = sinOp.getOperand();
27  mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
28 
29  math::CosOp cosOp = nullptr;
30  sinOp->getBlock()->walk([&](math::CosOp op) {
31  if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
32  cosOp = op;
33  return WalkResult::interrupt();
34  }
35  return WalkResult::advance();
36  });
37 
38  if (!cosOp)
39  return failure();
40 
41  Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
42  : cosOp.getOperation();
43  rewriter.setInsertionPoint(firstOp);
44 
45  Type elemType = sinOp.getType();
46  auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
47  TypeRange{elemType, elemType}, operand,
48  sinOp.getFastmathAttr());
49 
50  rewriter.replaceOp(sinOp, sincos.getSin());
51  rewriter.replaceOp(cosOp, sincos.getCos());
52  return success();
53  }
54 };
55 
56 } // namespace
57 
58 namespace mlir::math {
59 #define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
60 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
61 } // namespace mlir::math
62 
63 namespace {
64 
65 struct MathSincosFusionPass final
66  : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
67  using MathSincosFusionPassBase::MathSincosFusionPassBase;
68 
69  void runOnOperation() override {
71  patterns.add<SincosFusionPattern>(&getContext());
72 
74  if (failed(
75  applyPatternsGreedily(getOperation(), std::move(patterns), config)))
76  return signalPassFailure();
77  }
78 };
79 
80 } // namespace
static MLIRContext * getContext(OpFoldResult val)
This class allows control over how the GreedyPatternRewriteDriver works.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Definition: Operation.cpp:385
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
static WalkResult advance()
Definition: WalkResult.h:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314