MLIR  20.0.0git
UpliftToFMA.cpp
Go to the documentation of this file.
1 //===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements uplifting from arith ops to math.fma.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 namespace mlir::math {
20 #define GEN_PASS_DEF_MATHUPLIFTTOFMA
21 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
22 } // namespace mlir::math
23 
24 using namespace mlir;
25 
26 template <typename Op>
27 static bool isValidForFMA(Op op) {
28  return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract);
29 }
30 
31 namespace {
32 
33 struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
35 
36  LogicalResult matchAndRewrite(arith::AddFOp op,
37  PatternRewriter &rewriter) const override {
38  if (!isValidForFMA(op))
39  return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
40 
41  Value c;
42  arith::MulFOp ab;
43  if ((ab = op.getLhs().getDefiningOp<arith::MulFOp>())) {
44  c = op.getRhs();
45  } else if ((ab = op.getRhs().getDefiningOp<arith::MulFOp>())) {
46  c = op.getLhs();
47  } else {
48  return rewriter.notifyMatchFailure(op, "no mulf op");
49  }
50 
51  if (!isValidForFMA(ab))
52  return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma");
53 
54  Value a = ab.getLhs();
55  Value b = ab.getRhs();
56  arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
57  rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
58  return success();
59  }
60 };
61 
62 struct MathUpliftToFMA final
63  : math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
64  using MathUpliftToFMABase::MathUpliftToFMABase;
65 
66  void runOnOperation() override {
69  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
70  return signalPassFailure();
71  }
72 };
73 
74 } // namespace
75 
77  patterns.insert<UpliftFma>(patterns.getContext());
78 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static bool isValidForFMA(Op op)
Definition: UpliftToFMA.cpp:27
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateUpliftToFMAPatterns(RewritePatternSet &patterns)
Definition: UpliftToFMA.cpp:76
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362