27 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
31 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
40 auto *maybeInferredRange =
42 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
45 maybeInferredRange->getValue().getValue();
52 "Can't copy integer ranges between different types");
67 if (!maybeConstValue.has_value())
74 maybeDefiningOp ? maybeDefiningOp->
getDialect()
78 if (
auto shaped = dyn_cast<ShapedType>(type)) {
105 void notifyOperationErased(
Operation *op)
override {
123 LogicalResult match(
Operation *op)
const override {
127 auto needsReplacing = [&](
Value v) {
130 bool hasConstantResults = llvm::any_of(op->
getResults(), needsReplacing);
132 return success(hasConstantResults);
133 bool hasConstantRegionArgs =
false;
135 for (
Block &block : region.getBlocks()) {
136 hasConstantRegionArgs |=
137 llvm::any_of(block.getArguments(), needsReplacing);
140 return success(hasConstantResults || hasConstantRegionArgs);
154 PatternRewriter::InsertionGuard guard(rewriter);
156 for (
Block &block : region.getBlocks()) {
169 template <
typename RemOp>
174 LogicalResult matchAndRewrite(RemOp op,
176 Value lhs = op.getOperand(0);
177 Value rhs = op.getOperand(1);
179 if (!maybeModulus.has_value())
181 int64_t modulus = *maybeModulus;
185 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
188 const APInt &
min = isa<RemUIOp>(op) ? lhsRange.
umin() : lhsRange.
smin();
189 const APInt &
max = isa<RemUIOp>(op) ? lhsRange.
umax() : lhsRange.
smax();
192 if (
min.isNegative() ||
min.uge(modulus))
194 if (
max.isNegative() ||
max.uge(modulus))
214 for (
Value val : values) {
215 auto *maybeInferredRange =
217 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
221 maybeInferredRange->getValue().getValue();
222 ranges.push_back(inferredRange);
229 static Type getTargetType(
Type srcType,
unsigned targetBitwidth) {
231 if (
auto shaped = dyn_cast<ShapedType>(srcType))
232 return shaped.clone(dstType);
251 unsigned targetWidth) {
252 unsigned srcWidth = range.
smin().getBitWidth();
253 if (srcWidth <= targetWidth)
255 unsigned removedWidth = srcWidth - targetWidth;
259 bool canTruncateSigned =
260 range.
smin().getNumSignBits() >= (removedWidth + 1) &&
261 range.
smax().getNumSignBits() >= (removedWidth + 1);
262 bool canTruncateUnsigned = range.
umin().countLeadingZeros() >= removedWidth &&
263 range.
umax().countLeadingZeros() >= removedWidth;
264 if (canTruncateSigned && canTruncateUnsigned)
265 return CastKind::Both;
266 if (canTruncateSigned)
268 if (canTruncateUnsigned)
269 return CastKind::Unsigned;
276 if (lhs == CastKind::Both)
278 if (rhs == CastKind::Both)
288 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
289 "Mixing vector and non-vector types");
290 assert(castKind !=
CastKind::None &&
"Can't cast when casting isn't allowed");
293 assert(srcElemType.
isIntOrIndex() &&
"Invalid src type");
294 assert(dstElemType.
isIntOrIndex() &&
"Invalid dst type");
295 if (srcType == dstType)
298 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
300 return builder.
create<arith::IndexCastOp>(loc, dstType, src);
301 return builder.
create<arith::IndexCastUIOp>(loc, dstType, src);
304 auto srcInt = cast<IntegerType>(srcElemType);
305 auto dstInt = cast<IntegerType>(dstElemType);
306 if (dstInt.getWidth() < srcInt.getWidth())
307 return builder.
create<arith::TruncIOp>(loc, dstType, src);
310 return builder.
create<arith::ExtSIOp>(loc, dstType, src);
311 return builder.
create<arith::ExtUIOp>(loc, dstType, src);
320 LogicalResult matchAndRewrite(
Operation *op,
326 if (failed(collectRanges(solver, op->
getOperands(), ranges)))
328 if (failed(collectRanges(solver, op->
getResults(), ranges)))
336 [=](
Type t) { return t == srcType; }))
338 op,
"no operands or operand types don't match result type");
340 for (
unsigned targetBitwidth : targetBitwidths) {
343 castKind = mergeCastKinds(castKind,
344 checkTruncatability(range, targetBitwidth));
350 Type targetType = getTargetType(srcType, targetBitwidth);
351 if (targetType == srcType)
356 for (
auto [arg, argRange] : llvm::zip_first(op->
getOperands(), ranges)) {
361 argCastKind = CastKind::Both;
362 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
363 mapping.
map(arg, newArg);
369 res.setType(targetType);
373 for (
auto [newRes, oldRes] :
375 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
377 newResults.push_back(castBack);
395 LogicalResult matchAndRewrite(arith::CmpIOp op,
397 Value lhs = op.getLhs();
398 Value rhs = op.getRhs();
401 if (failed(collectRanges(solver, op.getOperands(), ranges)))
407 for (
unsigned targetBitwidth : targetBitwidths) {
408 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
409 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
410 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
415 Type targetType = getTargetType(srcType, targetBitwidth);
416 if (targetType == srcType)
421 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
422 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
423 mapping.
map(lhs, lhsCast);
424 mapping.
map(rhs, rhsCast);
442 template <
typename CastOp>
447 LogicalResult matchAndRewrite(CastOp op,
449 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
453 Value src = srcOp.getIn();
454 if (src.
getType() != op.getType())
457 if (!srcOp.getType().isIndex())
460 auto intType = dyn_cast<IntegerType>(op.getType());
461 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
472 struct IntRangeOptimizationsPass final
473 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
475 void runOnOperation()
override {
482 return signalPassFailure();
484 DataFlowListener listener(solver);
490 config.listener = &listener;
497 struct IntRangeNarrowingPass final
498 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
499 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
501 void runOnOperation()
override {
508 return signalPassFailure();
510 DataFlowListener listener(solver);
518 config.useTopDownTraversal =
false;
519 config.listener = &listener;
529 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
530 DeleteTrivialRem<RemUIOp>>(
patterns.getContext(), solver);
538 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
539 FoldIndexCastChain<arith::IndexCastOp>>(
patterns.getContext(),
544 return std::make_unique<IntRangeOptimizationsPass>();
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, PatternRewriter &rewriter, Value value)
Patterned after SCCP.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass that performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...