35 LogicalResult matchAndRewrite(math::PowFOp op,
41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
44 Value x = op.getLhs();
46 FloatAttr scalarExponent;
53 auto isExponentValue = [&](
double value) ->
bool {
55 return scalarExponent.getValue().isExactlyValue(value);
57 if (isVector && vectorExponent.isSplat())
58 return vectorExponent.getSplatValue<FloatAttr>()
60 .isExactlyValue(value);
67 if (
auto vec = dyn_cast<VectorType>(op.getType()))
68 return rewriter.
create<vector::BroadcastOp>(op.getLoc(), vec, value);
73 if (isExponentValue(1.0)) {
79 if (isExponentValue(2.0)) {
85 if (isExponentValue(3.0)) {
93 if (isExponentValue(-1.0)) {
101 if (isExponentValue(0.5)) {
107 if (isExponentValue(-0.5)) {
113 if (isExponentValue(0.75)) {
114 Value powHalf = rewriter.
create<math::SqrtOp>(op.getLoc(), x);
115 Value powQuarter = rewriter.
create<math::SqrtOp>(op.getLoc(), powHalf);
129 template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
132 unsigned exponentThreshold;
135 PowIStrengthReduction(
MLIRContext *context,
unsigned exponentThreshold = 3,
139 exponentThreshold(exponentThreshold) {}
141 LogicalResult matchAndRewrite(PowIOpTy op,
146 template <
typename PowIOpTy,
typename DivOpTy,
typename MulOpTy>
148 PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
151 Value base = op.getLhs();
153 IntegerAttr scalarExponent;
160 int64_t exponentValue = 0;
162 exponentValue = scalarExponent.getInt();
163 else if (isVector && vectorExponent.isSplat())
164 exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
169 auto bcast = [&loc, &op, &rewriter](
Value value) ->
Value {
170 if (
auto vec = dyn_cast<VectorType>(op.getType()))
171 return rewriter.
create<vector::BroadcastOp>(loc, vec, value);
177 if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
178 one = rewriter.
create<arith::ConstantOp>(
181 one = rewriter.
create<arith::ConstantOp>(
185 if (exponentValue == 0) {
190 bool exponentIsNegative =
false;
191 if (exponentValue < 0) {
192 exponentIsNegative =
true;
197 if (exponentValue > exponentThreshold)
202 if (exponentIsNegative)
203 base = rewriter.
create<DivOpTy>(loc, bcast(one), base);
215 for (
unsigned i = 1; i < exponentValue; ++i)
216 result = rewriter.
create<MulOpTy>(loc, result, base);
227 .
add<PowFStrengthReduction,
228 PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
229 PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
An attribute that represents a reference to a dense float vector or tensor object.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...