36 LogicalResult matchAndRewrite(math::PowFOp op,
42 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
45 Value x = op.getLhs();
47 FloatAttr scalarExponent;
54 auto isExponentValue = [&](
double value) ->
bool {
56 return scalarExponent.getValue().isExactlyValue(value);
58 if (isVector && vectorExponent.isSplat())
59 return vectorExponent.getSplatValue<FloatAttr>()
61 .isExactlyValue(value);
68 if (
auto vec = dyn_cast<VectorType>(op.getType()))
69 return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
74 if (isExponentValue(1.0)) {
80 if (isExponentValue(2.0)) {
86 if (isExponentValue(3.0)) {
88 arith::MulFOp::create(rewriter, op.getLoc(),
ValueRange({x, x}));
94 if (isExponentValue(-1.0)) {
95 Value one = arith::ConstantOp::create(
103 if (isExponentValue(0.5)) {
109 if (isExponentValue(-0.5)) {
115 if (isExponentValue(0.75)) {
116 Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
117 Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
131 template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
134 unsigned exponentThreshold;
137 PowIStrengthReduction(
MLIRContext *context,
unsigned exponentThreshold = 3,
141 exponentThreshold(exponentThreshold) {}
143 LogicalResult matchAndRewrite(PowIOpTy op,
148 template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
150 PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
153 Value base = op.getLhs();
155 IntegerAttr scalarExponent;
162 int64_t exponentValue = 0;
164 exponentValue = scalarExponent.getInt();
165 else if (isVector && vectorExponent.isSplat())
166 exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
171 auto bcast = [&loc, &op, &rewriter](
Value value) ->
Value {
172 if (
auto vec = dyn_cast<VectorType>(op.getType()))
173 return vector::BroadcastOp::create(rewriter, loc, vec, value);
179 if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
180 one = arith::ConstantOp::create(rewriter, loc,
182 }
else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
183 auto complexTy = cast<ComplexType>(opType);
184 Type elementType = complexTy.getElementType();
185 auto realPart = rewriter.
getFloatAttr(elementType, 1.0);
186 auto imagPart = rewriter.
getFloatAttr(elementType, 0.0);
187 one = complex::ConstantOp::create(
188 rewriter, loc, complexTy, rewriter.
getArrayAttr({realPart, imagPart}));
190 one = arith::ConstantOp::create(rewriter, loc,
195 if (exponentValue == 0) {
200 bool exponentIsNegative =
false;
201 if (exponentValue < 0) {
202 exponentIsNegative =
true;
207 if (exponentValue > exponentThreshold)
221 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
222 return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
223 op.getFastmathAttr());
225 return MulOpTy::create(rewriter, loc, lhs, rhs);
227 for (
unsigned i = 1; i < exponentValue; ++i)
228 result = buildMul(result, base);
232 if (exponentIsNegative) {
233 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
234 result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result,
235 op.getFastmathAttr());
237 result = DivOpTy::create(rewriter, loc, bcast(one), result);
249 PowFStrengthReduction,
250 PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
251 PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
252 PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
An attribute that represents a reference to a dense float vector or tensor object.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...