28 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
29 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
31 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
32 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
41 auto *maybeInferredRange =
43 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
46 maybeInferredRange->getValue().getValue();
53 "Can't copy integer ranges between different types");
68 if (!maybeConstValue.has_value())
75 maybeDefiningOp ? maybeDefiningOp->
getDialect()
79 if (
auto shaped = dyn_cast<ShapedType>(type)) {
110 void notifyOperationErased(
Operation *op)
override {
129 LogicalResult matchAndRewrite(
Operation *op,
134 auto needsReplacing = [&](
Value v) {
137 bool hasConstantResults = llvm::any_of(op->
getResults(), needsReplacing);
139 if (!hasConstantResults)
141 bool hasConstantRegionArgs =
false;
143 for (
Block &block : region.getBlocks()) {
144 hasConstantRegionArgs |=
145 llvm::any_of(block.getArguments(), needsReplacing);
148 if (!hasConstantResults && !hasConstantRegionArgs)
161 PatternRewriter::InsertionGuard guard(rewriter);
163 for (
Block &block : region.getBlocks()) {
178 template <
typename RemOp>
183 LogicalResult matchAndRewrite(RemOp op,
185 Value lhs = op.getOperand(0);
186 Value rhs = op.getOperand(1);
188 if (!maybeModulus.has_value())
190 int64_t modulus = *maybeModulus;
194 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
197 const APInt &
min = isa<RemUIOp>(op) ? lhsRange.
umin() : lhsRange.
smin();
198 const APInt &
max = isa<RemUIOp>(op) ? lhsRange.
umax() : lhsRange.
smax();
201 if (
min.isNegative() ||
min.uge(modulus))
203 if (
max.isNegative() ||
max.uge(modulus))
223 for (
Value val : values) {
224 auto *maybeInferredRange =
226 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
230 maybeInferredRange->getValue().getValue();
231 ranges.push_back(inferredRange);
238 static Type getTargetType(
Type srcType,
unsigned targetBitwidth) {
240 if (
auto shaped = dyn_cast<ShapedType>(srcType))
241 return shaped.clone(dstType);
260 unsigned targetWidth) {
261 unsigned srcWidth = range.
smin().getBitWidth();
262 if (srcWidth <= targetWidth)
264 unsigned removedWidth = srcWidth - targetWidth;
268 bool canTruncateSigned =
269 range.
smin().getNumSignBits() >= (removedWidth + 1) &&
270 range.
smax().getNumSignBits() >= (removedWidth + 1);
271 bool canTruncateUnsigned = range.
umin().countLeadingZeros() >= removedWidth &&
272 range.
umax().countLeadingZeros() >= removedWidth;
273 if (canTruncateSigned && canTruncateUnsigned)
274 return CastKind::Both;
275 if (canTruncateSigned)
277 if (canTruncateUnsigned)
278 return CastKind::Unsigned;
285 if (lhs == CastKind::Both)
287 if (rhs == CastKind::Both)
297 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
298 "Mixing vector and non-vector types");
299 assert(castKind !=
CastKind::None &&
"Can't cast when casting isn't allowed");
302 assert(srcElemType.
isIntOrIndex() &&
"Invalid src type");
303 assert(dstElemType.
isIntOrIndex() &&
"Invalid dst type");
304 if (srcType == dstType)
307 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
309 return arith::IndexCastOp::create(builder, loc, dstType, src);
310 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
313 auto srcInt = cast<IntegerType>(srcElemType);
314 auto dstInt = cast<IntegerType>(dstElemType);
315 if (dstInt.getWidth() < srcInt.getWidth())
316 return arith::TruncIOp::create(builder, loc, dstType, src);
319 return arith::ExtSIOp::create(builder, loc, dstType, src);
320 return arith::ExtUIOp::create(builder, loc, dstType, src);
329 LogicalResult matchAndRewrite(
Operation *op,
345 [=](
Type t) { return t == srcType; }))
347 op,
"no operands or operand types don't match result type");
349 for (
unsigned targetBitwidth : targetBitwidths) {
352 castKind = mergeCastKinds(castKind,
353 checkTruncatability(range, targetBitwidth));
359 Type targetType = getTargetType(srcType, targetBitwidth);
360 if (targetType == srcType)
365 for (
auto [arg, argRange] : llvm::zip_first(op->
getOperands(), ranges)) {
370 argCastKind = CastKind::Both;
371 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
372 mapping.
map(arg, newArg);
378 res.setType(targetType);
382 for (
auto [newRes, oldRes] :
384 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
386 newResults.push_back(castBack);
404 LogicalResult matchAndRewrite(arith::CmpIOp op,
406 Value lhs = op.getLhs();
407 Value rhs = op.getRhs();
410 if (
failed(collectRanges(solver, op.getOperands(), ranges)))
416 for (
unsigned targetBitwidth : targetBitwidths) {
417 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
418 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
419 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
424 Type targetType = getTargetType(srcType, targetBitwidth);
425 if (targetType == srcType)
430 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
431 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
432 mapping.
map(lhs, lhsCast);
433 mapping.
map(rhs, rhsCast);
451 template <
typename CastOp>
456 LogicalResult matchAndRewrite(CastOp op,
458 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
462 Value src = srcOp.getIn();
463 if (src.
getType() != op.getType())
466 if (!srcOp.getType().isIndex())
469 auto intType = dyn_cast<IntegerType>(op.getType());
470 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
481 struct IntRangeOptimizationsPass final
482 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
484 void runOnOperation()
override {
492 return signalPassFailure();
494 DataFlowListener listener(solver);
506 struct IntRangeNarrowingPass final
507 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
508 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
510 void runOnOperation()
override {
517 return signalPassFailure();
519 DataFlowListener listener(solver);
537 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538 DeleteTrivialRem<RemUIOp>>(
patterns.getContext(), solver);
546 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547 FoldIndexCastChain<arith::IndexCastOp>>(
patterns.getContext(),
552 return std::make_unique<IntRangeOptimizationsPass>();
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This is a value defined by a result of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...