36 LogicalResult matchAndRewrite(math::PowFOp op,
37 PatternRewriter &rewriter)
const final;
42PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
44 Location loc = op.getLoc();
45 Value x = op.getLhs();
47 FloatAttr scalarExponent;
48 DenseFPElementsAttr vectorExponent;
54 auto isExponentValue = [&](
double value) ->
bool {
56 return scalarExponent.getValue().isExactlyValue(value);
58 if (isVector && vectorExponent.isSplat())
59 return vectorExponent.getSplatValue<FloatAttr>()
61 .isExactlyValue(value);
67 auto bcast = [&](Value value) -> Value {
68 if (
auto vec = dyn_cast<VectorType>(op.getType()))
69 return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
74 if (isExponentValue(1.0)) {
80 if (isExponentValue(2.0)) {
86 if (isExponentValue(3.0)) {
88 arith::MulFOp::create(rewriter, op.getLoc(),
ValueRange({x, x}));
94 if (isExponentValue(-1.0)) {
95 Value one = arith::ConstantOp::create(
103 if (isExponentValue(0.5)) {
109 if (isExponentValue(-0.5)) {
115 if (isExponentValue(0.75)) {
116 Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
117 Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
131template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
132struct PowIStrengthReduction :
public OpRewritePattern<PowIOpTy> {
134 unsigned exponentThreshold;
137 PowIStrengthReduction(MLIRContext *context,
unsigned exponentThreshold = 3,
138 PatternBenefit benefit = 1,
139 ArrayRef<StringRef> generatedNames = {})
140 : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
141 exponentThreshold(exponentThreshold) {}
143 LogicalResult matchAndRewrite(PowIOpTy op,
144 PatternRewriter &rewriter)
const final;
148template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
150PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
151 PowIOpTy op, PatternRewriter &rewriter)
const {
152 Location loc = op.getLoc();
153 Value base = op.getLhs();
155 IntegerAttr scalarExponent;
156 DenseIntElementsAttr vectorExponent;
162 int64_t exponentValue = 0;
164 exponentValue = scalarExponent.getInt();
165 else if (isVector && vectorExponent.isSplat())
166 exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
171 auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
172 if (
auto vec = dyn_cast<VectorType>(op.getType()))
173 return vector::BroadcastOp::create(rewriter, loc, vec, value);
179 if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
180 one = arith::ConstantOp::create(rewriter, loc,
182 }
else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
183 auto complexTy = cast<ComplexType>(opType);
184 Type elementType = complexTy.getElementType();
185 auto realPart = rewriter.
getFloatAttr(elementType, 1.0);
186 auto imagPart = rewriter.
getFloatAttr(elementType, 0.0);
187 one = complex::ConstantOp::create(
188 rewriter, loc, complexTy, rewriter.
getArrayAttr({realPart, imagPart}));
190 one = arith::ConstantOp::create(rewriter, loc,
195 if (exponentValue == 0) {
200 bool exponentIsNegative =
false;
201 if (exponentValue < 0) {
202 exponentIsNegative =
true;
207 if (exponentValue > exponentThreshold)
220 auto buildMul = [&](Value
lhs, Value
rhs) {
221 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
222 return MulOpTy::create(rewriter, loc, op.getType(),
lhs,
rhs,
223 op.getFastmathAttr());
225 return MulOpTy::create(rewriter, loc,
lhs,
rhs);
227 for (
unsigned i = 1; i < exponentValue; ++i)
232 if (exponentIsNegative) {
233 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
234 result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one),
result,
235 op.getFastmathAttr());
237 result = DivOpTy::create(rewriter, loc, bcast(one),
result);
249 PowFStrengthReduction,
250 PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
251 PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
252 PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...