MLIR  19.0.0git
UpliftToFMA.cpp
Go to the documentation of this file.
1 //===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements uplifting from arith ops to math.fma.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 namespace mlir::math {
20 #define GEN_PASS_DEF_MATHUPLIFTTOFMA
21 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
22 } // namespace mlir::math
23 
24 using namespace mlir;
25 
26 template <typename Op>
27 static bool isValidForFMA(Op op) {
28  return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract);
29 }
30 
31 namespace {
32 
33 struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
35 
36  LogicalResult matchAndRewrite(arith::AddFOp op,
37  PatternRewriter &rewriter) const override {
38  if (!isValidForFMA(op))
39  return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
40 
41  Value c;
42  arith::MulFOp ab;
43  if ((ab = op.getLhs().getDefiningOp<arith::MulFOp>())) {
44  c = op.getRhs();
45  } else if ((ab = op.getRhs().getDefiningOp<arith::MulFOp>())) {
46  c = op.getLhs();
47  } else {
48  return rewriter.notifyMatchFailure(op, "no mulf op");
49  }
50 
51  if (!isValidForFMA(ab))
52  return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma");
53 
54  Value a = ab.getLhs();
55  Value b = ab.getRhs();
56  arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
57  rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
58  return success();
59  }
60 };
61 
62 struct MathUpliftToFMA final
63  : math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
64  using MathUpliftToFMABase::MathUpliftToFMABase;
65 
66  void runOnOperation() override {
67  RewritePatternSet patterns(&getContext());
69  if (failed(
70  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
71  return signalPassFailure();
72  }
73 };
74 
75 } // namespace
76 
78  patterns.insert<UpliftFma>(patterns.getContext());
79 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static bool isValidForFMA(Op op)
Definition: UpliftToFMA.cpp:27
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
MLIRContext * getContext() const
Definition: PatternMatch.h:822
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
void populateUpliftToFMAPatterns(RewritePatternSet &patterns)
Definition: UpliftToFMA.cpp:77
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362