27 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
31 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
40 auto *maybeInferredRange =
42 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
45 maybeInferredRange->getValue().getValue();
56 if (!maybeConstValue.has_value())
63 maybeDefiningOp ? maybeDefiningOp->
getDialect()
67 if (
auto shaped = dyn_cast<ShapedType>(type)) {
93 void notifyOperationErased(
Operation *op)
override {
111 LogicalResult match(
Operation *op)
const override {
115 auto needsReplacing = [&](
Value v) {
118 bool hasConstantResults = llvm::any_of(op->
getResults(), needsReplacing);
120 return success(hasConstantResults);
121 bool hasConstantRegionArgs =
false;
123 for (
Block &block : region.getBlocks()) {
124 hasConstantRegionArgs |=
125 llvm::any_of(block.getArguments(), needsReplacing);
128 return success(hasConstantResults || hasConstantRegionArgs);
142 PatternRewriter::InsertionGuard guard(rewriter);
144 for (
Block &block : region.getBlocks()) {
157 template <
typename RemOp>
162 LogicalResult matchAndRewrite(RemOp op,
164 Value lhs = op.getOperand(0);
165 Value rhs = op.getOperand(1);
167 if (!maybeModulus.has_value())
169 int64_t modulus = *maybeModulus;
173 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
176 const APInt &
min = isa<RemUIOp>(op) ? lhsRange.
umin() : lhsRange.
smin();
177 const APInt &
max = isa<RemUIOp>(op) ? lhsRange.
umax() : lhsRange.
smax();
180 if (
min.isNegative() ||
min.uge(modulus))
182 if (
max.isNegative() ||
max.uge(modulus))
199 static LogicalResult checkIntType(
Type type,
unsigned targetBitwidth) {
201 if (isa<IndexType>(elemType))
204 if (
auto intType = dyn_cast<IntegerType>(elemType))
205 if (intType.getWidth() > targetBitwidth)
213 static LogicalResult checkElementwiseOpType(
Operation *op,
214 unsigned targetBitwidth) {
221 type = val.getType();
225 if (type != val.getType())
229 return checkIntType(type, targetBitwidth);
233 static std::optional<ConstantIntRanges> getOperandsRange(
DataFlowSolver &solver,
235 std::optional<ConstantIntRanges> ret;
236 for (
Value value : operands) {
237 auto *maybeInferredRange =
239 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
243 maybeInferredRange->getValue().getValue();
245 ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
252 static Type getTargetType(
Type srcType,
unsigned targetBitwidth) {
254 if (
auto shaped = dyn_cast<ShapedType>(srcType))
255 return shaped.clone(dstType);
263 APInt smax, APInt umin, APInt umax) {
264 auto sge = [](APInt val1, APInt val2) ->
bool {
265 unsigned width =
std::max(val1.getBitWidth(), val2.getBitWidth());
266 val1 = val1.sext(width);
267 val2 = val2.sext(width);
268 return val1.sge(val2);
270 auto sle = [](APInt val1, APInt val2) ->
bool {
271 unsigned width =
std::max(val1.getBitWidth(), val2.getBitWidth());
272 val1 = val1.sext(width);
273 val2 = val2.sext(width);
274 return val1.sle(val2);
276 auto uge = [](APInt val1, APInt val2) ->
bool {
277 unsigned width =
std::max(val1.getBitWidth(), val2.getBitWidth());
278 val1 = val1.zext(width);
279 val2 = val2.zext(width);
280 return val1.uge(val2);
282 auto ule = [](APInt val1, APInt val2) ->
bool {
283 unsigned width =
std::max(val1.getBitWidth(), val2.getBitWidth());
284 val1 = val1.zext(width);
285 val2 = val2.zext(width);
286 return val1.ule(val2);
288 return success(
sge(range.
smin(), smin) &&
sle(range.
smax(), smax) &&
294 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
295 "Mixing vector and non-vector types");
298 assert(srcElemType.
isIntOrIndex() &&
"Invalid src type");
299 assert(dstElemType.
isIntOrIndex() &&
"Invalid dst type");
300 if (srcType == dstType)
303 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
304 return builder.
create<arith::IndexCastUIOp>(loc, dstType, src);
306 auto srcInt = cast<IntegerType>(srcElemType);
307 auto dstInt = cast<IntegerType>(dstElemType);
308 if (dstInt.getWidth() < srcInt.getWidth())
309 return builder.
create<arith::TruncIOp>(loc, dstType, src);
311 return builder.
create<arith::ExtUIOp>(loc, dstType, src);
320 LogicalResult matchAndRewrite(
Operation *op,
322 std::optional<ConstantIntRanges> range =
327 for (
unsigned targetBitwidth : targetBitwidths) {
328 if (failed(checkElementwiseOpType(op, targetBitwidth)))
338 auto smax = APInt::getSignedMaxValue(targetBitwidth);
339 auto umin = APInt::getMinValue(targetBitwidth);
340 auto umax = APInt::getMaxValue(targetBitwidth);
341 if (failed(checkRange(*range, smin, smax, umin, umax)))
344 Type targetType = getTargetType(srcType, targetBitwidth);
345 if (targetType == srcType)
351 Value newArg = doCast(rewriter, loc, arg, targetType);
352 mapping.
map(arg, newArg);
358 res.setType(targetType);
363 newResults.emplace_back(doCast(rewriter, loc, res, srcType));
380 LogicalResult matchAndRewrite(arith::CmpIOp op,
382 Value lhs = op.getLhs();
383 Value rhs = op.getRhs();
385 std::optional<ConstantIntRanges> range =
386 getOperandsRange(solver, {lhs, rhs});
390 for (
unsigned targetBitwidth : targetBitwidths) {
392 if (failed(checkIntType(srcType, targetBitwidth)))
395 auto smin = APInt::getSignedMinValue(targetBitwidth);
396 auto smax = APInt::getSignedMaxValue(targetBitwidth);
397 auto umin = APInt::getMinValue(targetBitwidth);
398 auto umax = APInt::getMaxValue(targetBitwidth);
399 if (failed(checkRange(*range, smin, smax, umin, umax)))
402 Type targetType = getTargetType(srcType, targetBitwidth);
403 if (targetType == srcType)
408 for (
Value arg : op->getOperands()) {
409 Value newArg = doCast(rewriter, loc, arg, targetType);
410 mapping.
map(arg, newArg);
432 LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
434 auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
438 Value src = srcOp.getIn();
439 if (src.
getType() != op.getType())
442 auto intType = dyn_cast<IntegerType>(op.getType());
443 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
454 struct IntRangeOptimizationsPass final
455 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
457 void runOnOperation()
override {
464 return signalPassFailure();
466 DataFlowListener listener(solver);
479 struct IntRangeNarrowingPass final
480 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
481 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
483 void runOnOperation()
override {
490 return signalPassFailure();
492 DataFlowListener listener(solver);
511 patterns.
add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
512 DeleteTrivialRem<RemUIOp>>(patterns.
getContext(), solver);
518 patterns.
add<NarrowElementwise, NarrowCmpI>(patterns.
getContext(), solver,
520 patterns.
add<FoldIndexCastChain>(patterns.
getContext(), bitwidthsSupported);
524 return std::make_unique<IntRangeOptimizationsPass>();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, PatternRewriter &rewriter, Value value)
Patterned after SCCP.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which do optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...