/// Rewrite hook for `math::PowFOp` (defined out of line as
/// `PowFStrengthReduction::matchAndRewrite`). The visible definition
/// matches a constant exponent (scalar `FloatAttr` or splat
/// `DenseFPElementsAttr`) equal to 1, 2, 3, -1, 0.5, -0.5 or 0.75 and
/// rewrites the power into cheaper ops (e.g. `arith.mulf`, `math.sqrt`),
/// broadcasting scalar constants when the result is a vector.
/// NOTE(review): presumably returns failure() for any other exponent;
/// the fall-through path is not visible here, so confirm.
36 LogicalResult matchAndRewrite(math::PowFOp op,
37 PatternRewriter &rewriter)
const final;
42PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
44 Location loc = op.getLoc();
45 Value x = op.getLhs();
46 arith::FastMathFlags fmf = op.getFastmathAttr().getValue();
48 FloatAttr scalarExponent;
49 DenseFPElementsAttr vectorExponent;
55 auto isExponentValue = [&](
double value) ->
bool {
57 return scalarExponent.getValue().isExactlyValue(value);
59 if (isVector && vectorExponent.isSplat())
60 return vectorExponent.getSplatValue<FloatAttr>()
62 .isExactlyValue(value);
68 auto bcast = [&](Value value) -> Value {
69 if (
auto vec = dyn_cast<VectorType>(op.getType()))
70 return vector::BroadcastOp::create(rewriter, loc, vec, value);
75 if (isExponentValue(1.0)) {
81 if (isExponentValue(2.0)) {
87 if (isExponentValue(3.0)) {
88 Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
94 if (isExponentValue(-1.0)) {
95 Value one = arith::ConstantOp::create(
103 if (isExponentValue(0.5)) {
109 if (isExponentValue(-0.5)) {
115 if (isExponentValue(0.75)) {
116 Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
117 Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
130template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
131struct PowIStrengthReduction :
public OpRewritePattern<PowIOpTy> {
133 unsigned exponentThreshold;
136 PowIStrengthReduction(MLIRContext *context,
unsigned exponentThreshold = 3,
137 PatternBenefit benefit = 1,
138 ArrayRef<StringRef> generatedNames = {})
139 : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
140 exponentThreshold(exponentThreshold) {}
142 LogicalResult matchAndRewrite(PowIOpTy op,
143 PatternRewriter &rewriter)
const final;
147template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
149PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
150 PowIOpTy op, PatternRewriter &rewriter)
const {
151 Location loc = op.getLoc();
152 Value base = op.getLhs();
154 IntegerAttr scalarExponent;
155 DenseIntElementsAttr vectorExponent;
161 int64_t exponentValue = 0;
163 exponentValue = scalarExponent.getInt();
164 else if (isVector && vectorExponent.isSplat())
165 exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
170 auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
171 if (
auto vec = dyn_cast<VectorType>(op.getType()))
172 return vector::BroadcastOp::create(rewriter, loc, vec, value);
178 if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
179 one = arith::ConstantOp::create(rewriter, loc,
181 }
else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
182 auto complexTy = cast<ComplexType>(opType);
183 Type elementType = complexTy.getElementType();
184 auto realPart = rewriter.
getFloatAttr(elementType, 1.0);
185 auto imagPart = rewriter.
getFloatAttr(elementType, 0.0);
186 one = complex::ConstantOp::create(
187 rewriter, loc, complexTy, rewriter.
getArrayAttr({realPart, imagPart}));
189 one = arith::ConstantOp::create(rewriter, loc,
194 if (exponentValue == 0) {
199 bool exponentIsNegative =
false;
200 if (exponentValue < 0) {
201 exponentIsNegative =
true;
206 if (exponentValue > exponentThreshold)
219 auto buildMul = [&](Value
lhs, Value
rhs) {
220 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
221 return MulOpTy::create(rewriter, loc, op.getType(),
lhs,
rhs,
222 op.getFastmathAttr());
224 return MulOpTy::create(rewriter, loc,
lhs,
rhs);
226 for (
unsigned i = 1; i < exponentValue; ++i)
231 if (exponentIsNegative) {
232 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
233 result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one),
result,
234 op.getFastmathAttr());
236 result = DivOpTy::create(rewriter, loc, bcast(one),
result);
248 PowFStrengthReduction,
249 PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
250 PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
251 PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...