MLIR  15.0.0git
AlgebraicSimplification.cpp
Go to the documentation of this file.
1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc.) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include <climits>
23 
24 using namespace mlir;
25 
26 //----------------------------------------------------------------------------//
27 // PowFOp strength reduction.
28 //----------------------------------------------------------------------------//
29 
30 namespace {
31 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32 public:
34 
35  LogicalResult matchAndRewrite(math::PowFOp op,
36  PatternRewriter &rewriter) const final;
37 };
38 } // namespace
39 
41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42  PatternRewriter &rewriter) const {
43  Location loc = op.getLoc();
44  Value x = op.getLhs();
45 
46  FloatAttr scalarExponent;
47  DenseFPElementsAttr vectorExponent;
48 
49  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
50  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51 
52  // Returns true if exponent is a constant equal to `value`.
53  auto isExponentValue = [&](double value) -> bool {
54  if (isScalar)
55  return scalarExponent.getValue().isExactlyValue(value);
56 
57  if (isVector && vectorExponent.isSplat())
58  return vectorExponent.getSplatValue<FloatAttr>()
59  .getValue()
60  .isExactlyValue(value);
61 
62  return false;
63  };
64 
65  // Maybe broadcasts scalar value into vector type compatible with `op`.
66  auto bcast = [&](Value value) -> Value {
67  if (auto vec = op.getType().dyn_cast<VectorType>())
68  return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69  return value;
70  };
71 
72  // Replace `pow(x, 1.0)` with `x`.
73  if (isExponentValue(1.0)) {
74  rewriter.replaceOp(op, x);
75  return success();
76  }
77 
78  // Replace `pow(x, 2.0)` with `x * x`.
79  if (isExponentValue(2.0)) {
80  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81  return success();
82  }
83 
84  // Replace `pow(x, 3.0)` with `x * x * x`.
85  if (isExponentValue(3.0)) {
86  Value square =
87  rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89  return success();
90  }
91 
92  // Replace `pow(x, -1.0)` with `1.0 / x`.
93  if (isExponentValue(-1.0)) {
94  Value one = rewriter.create<arith::ConstantOp>(
95  loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97  return success();
98  }
99 
100  // Replace `pow(x, 0.5)` with `sqrt(x)`.
101  if (isExponentValue(0.5)) {
102  rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103  return success();
104  }
105 
106  // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107  if (isExponentValue(-0.5)) {
108  rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109  return success();
110  }
111 
112  return failure();
113 }
114 
115 //----------------------------------------------------------------------------//
116 
118  RewritePatternSet &patterns) {
119  patterns.add<PowFStrengthReduction>(patterns.getContext());
120 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
An attribute that represents a reference to a dense float vector or tensor object.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:259
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:333
This class provides an abstraction over the different types of ranges over Values.
MLIRContext * getContext() const