27 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
31 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
40 auto *maybeInferredRange =
42 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
45 maybeInferredRange->getValue().getValue();
52 "Can't copy integer ranges between different types");
67 if (!maybeConstValue.has_value())
74 maybeDefiningOp ? maybeDefiningOp->
getDialect()
78 if (
auto shaped = dyn_cast<ShapedType>(type)) {
109 void notifyOperationErased(
Operation *op)
override {
128 LogicalResult matchAndRewrite(
Operation *op,
133 auto needsReplacing = [&](
Value v) {
136 bool hasConstantResults = llvm::any_of(op->
getResults(), needsReplacing);
138 if (!hasConstantResults)
140 bool hasConstantRegionArgs =
false;
142 for (
Block &block : region.getBlocks()) {
143 hasConstantRegionArgs |=
144 llvm::any_of(block.getArguments(), needsReplacing);
147 if (!hasConstantResults && !hasConstantRegionArgs)
160 PatternRewriter::InsertionGuard guard(rewriter);
162 for (
Block &block : region.getBlocks()) {
177 template <
typename RemOp>
182 LogicalResult matchAndRewrite(RemOp op,
184 Value lhs = op.getOperand(0);
185 Value rhs = op.getOperand(1);
187 if (!maybeModulus.has_value())
189 int64_t modulus = *maybeModulus;
193 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
196 const APInt &
min = isa<RemUIOp>(op) ? lhsRange.
umin() : lhsRange.
smin();
197 const APInt &
max = isa<RemUIOp>(op) ? lhsRange.
umax() : lhsRange.
smax();
200 if (
min.isNegative() ||
min.uge(modulus))
202 if (
max.isNegative() ||
max.uge(modulus))
222 for (
Value val : values) {
223 auto *maybeInferredRange =
225 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
229 maybeInferredRange->getValue().getValue();
230 ranges.push_back(inferredRange);
237 static Type getTargetType(
Type srcType,
unsigned targetBitwidth) {
239 if (
auto shaped = dyn_cast<ShapedType>(srcType))
240 return shaped.clone(dstType);
259 unsigned targetWidth) {
260 unsigned srcWidth = range.
smin().getBitWidth();
261 if (srcWidth <= targetWidth)
263 unsigned removedWidth = srcWidth - targetWidth;
267 bool canTruncateSigned =
268 range.
smin().getNumSignBits() >= (removedWidth + 1) &&
269 range.
smax().getNumSignBits() >= (removedWidth + 1);
270 bool canTruncateUnsigned = range.
umin().countLeadingZeros() >= removedWidth &&
271 range.
umax().countLeadingZeros() >= removedWidth;
272 if (canTruncateSigned && canTruncateUnsigned)
273 return CastKind::Both;
274 if (canTruncateSigned)
276 if (canTruncateUnsigned)
277 return CastKind::Unsigned;
284 if (lhs == CastKind::Both)
286 if (rhs == CastKind::Both)
296 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
297 "Mixing vector and non-vector types");
298 assert(castKind !=
CastKind::None &&
"Can't cast when casting isn't allowed");
301 assert(srcElemType.
isIntOrIndex() &&
"Invalid src type");
302 assert(dstElemType.
isIntOrIndex() &&
"Invalid dst type");
303 if (srcType == dstType)
306 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
308 return builder.
create<arith::IndexCastOp>(loc, dstType, src);
309 return builder.
create<arith::IndexCastUIOp>(loc, dstType, src);
312 auto srcInt = cast<IntegerType>(srcElemType);
313 auto dstInt = cast<IntegerType>(dstElemType);
314 if (dstInt.getWidth() < srcInt.getWidth())
315 return builder.
create<arith::TruncIOp>(loc, dstType, src);
318 return builder.
create<arith::ExtSIOp>(loc, dstType, src);
319 return builder.
create<arith::ExtUIOp>(loc, dstType, src);
328 LogicalResult matchAndRewrite(
Operation *op,
334 if (failed(collectRanges(solver, op->
getOperands(), ranges)))
336 if (failed(collectRanges(solver, op->
getResults(), ranges)))
344 [=](
Type t) { return t == srcType; }))
346 op,
"no operands or operand types don't match result type");
348 for (
unsigned targetBitwidth : targetBitwidths) {
351 castKind = mergeCastKinds(castKind,
352 checkTruncatability(range, targetBitwidth));
358 Type targetType = getTargetType(srcType, targetBitwidth);
359 if (targetType == srcType)
364 for (
auto [arg, argRange] : llvm::zip_first(op->
getOperands(), ranges)) {
369 argCastKind = CastKind::Both;
370 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
371 mapping.
map(arg, newArg);
377 res.setType(targetType);
381 for (
auto [newRes, oldRes] :
383 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
385 newResults.push_back(castBack);
403 LogicalResult matchAndRewrite(arith::CmpIOp op,
405 Value lhs = op.getLhs();
406 Value rhs = op.getRhs();
409 if (failed(collectRanges(solver, op.getOperands(), ranges)))
415 for (
unsigned targetBitwidth : targetBitwidths) {
416 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
417 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
418 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
423 Type targetType = getTargetType(srcType, targetBitwidth);
424 if (targetType == srcType)
429 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
430 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
431 mapping.
map(lhs, lhsCast);
432 mapping.
map(rhs, rhsCast);
450 template <
typename CastOp>
455 LogicalResult matchAndRewrite(CastOp op,
457 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
461 Value src = srcOp.getIn();
462 if (src.
getType() != op.getType())
465 if (!srcOp.getType().isIndex())
468 auto intType = dyn_cast<IntegerType>(op.getType());
469 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
480 struct IntRangeOptimizationsPass final
481 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
483 void runOnOperation()
override {
490 return signalPassFailure();
492 DataFlowListener listener(solver);
498 config.listener = &listener;
505 struct IntRangeNarrowingPass final
506 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
507 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
509 void runOnOperation()
override {
516 return signalPassFailure();
518 DataFlowListener listener(solver);
526 config.useTopDownTraversal =
false;
527 config.listener = &listener;
537 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538 DeleteTrivialRem<RemUIOp>>(
patterns.getContext(), solver);
546 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547 FoldIndexCastChain<arith::IndexCastOp>>(
patterns.getContext(),
552 return std::make_unique<IntRangeOptimizationsPass>();
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which does optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...