29#define GEN_PASS_DEF_ARITHINTRANGEOPTS
30#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
32#define GEN_PASS_DEF_ARITHINTRANGENARROWING
33#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
42 auto *maybeInferredRange =
44 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
47 maybeInferredRange->getValue().getValue();
54 "Can't copy integer ranges between different types");
69 if (!maybeConstValue.has_value())
76 maybeDefiningOp ? maybeDefiningOp->
getDialect()
80 if (
auto shaped = dyn_cast<ShapedType>(type)) {
111 void notifyOperationErased(Operation *op)
override {
125 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
126 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
130 LogicalResult matchAndRewrite(Operation *op,
131 PatternRewriter &rewriter)
const override {
135 auto needsReplacing = [&](Value v) {
138 bool hasConstantResults = llvm::any_of(op->
getResults(), needsReplacing);
140 if (!hasConstantResults)
142 bool hasConstantRegionArgs =
false;
144 for (
Block &block : region.getBlocks()) {
145 hasConstantRegionArgs |=
146 llvm::any_of(block.getArguments(), needsReplacing);
149 if (!hasConstantResults && !hasConstantRegionArgs)
162 PatternRewriter::InsertionGuard guard(rewriter);
164 for (
Block &block : region.getBlocks()) {
166 for (BlockArgument &arg : block.getArguments()) {
176 DataFlowSolver &solver;
179template <
typename RemOp>
181 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
182 : OpRewritePattern<RemOp>(context), solver(s) {}
184 LogicalResult matchAndRewrite(RemOp op,
185 PatternRewriter &rewriter)
const override {
186 Value
lhs = op.getOperand(0);
187 Value
rhs = op.getOperand(1);
189 if (!maybeModulus.has_value())
191 int64_t modulus = *maybeModulus;
194 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(
lhs);
195 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
197 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
198 const APInt &
min = isa<RemUIOp>(op) ? lhsRange.
umin() : lhsRange.
smin();
199 const APInt &
max = isa<RemUIOp>(op) ? lhsRange.
umax() : lhsRange.
smax();
202 if (
min.isNegative() ||
min.uge(modulus))
204 if (
max.isNegative() ||
max.uge(modulus))
217 DataFlowSolver &solver;
224 for (
Value val : values) {
225 auto *maybeInferredRange =
227 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
231 maybeInferredRange->getValue().getValue();
232 ranges.push_back(inferredRange);
239static Type getTargetType(
Type srcType,
unsigned targetBitwidth) {
240 auto dstType = IntegerType::get(srcType.
getContext(), targetBitwidth);
241 if (
auto shaped = dyn_cast<ShapedType>(srcType))
242 return shaped.clone(dstType);
261 unsigned targetWidth) {
262 unsigned srcWidth = range.
smin().getBitWidth();
263 if (srcWidth <= targetWidth)
264 return CastKind::None;
265 unsigned removedWidth = srcWidth - targetWidth;
269 bool canTruncateSigned =
270 range.
smin().getNumSignBits() >= (removedWidth + 1) &&
271 range.
smax().getNumSignBits() >= (removedWidth + 1);
272 bool canTruncateUnsigned = range.
umin().countLeadingZeros() >= removedWidth &&
273 range.
umax().countLeadingZeros() >= removedWidth;
274 if (canTruncateSigned && canTruncateUnsigned)
275 return CastKind::Both;
276 if (canTruncateSigned)
277 return CastKind::Signed;
278 if (canTruncateUnsigned)
279 return CastKind::Unsigned;
280 return CastKind::None;
283static CastKind mergeCastKinds(CastKind
lhs, CastKind
rhs) {
284 if (
lhs == CastKind::None ||
rhs == CastKind::None)
285 return CastKind::None;
286 if (
lhs == CastKind::Both)
288 if (
rhs == CastKind::Both)
292 return CastKind::None;
298 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
299 "Mixing vector and non-vector types");
300 assert(castKind != CastKind::None &&
"Can't cast when casting isn't allowed");
303 assert(srcElemType.
isIntOrIndex() &&
"Invalid src type");
304 assert(dstElemType.
isIntOrIndex() &&
"Invalid dst type");
305 if (srcType == dstType)
308 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
309 if (castKind == CastKind::Signed)
310 return arith::IndexCastOp::create(builder, loc, dstType, src);
311 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
314 auto srcInt = cast<IntegerType>(srcElemType);
315 auto dstInt = cast<IntegerType>(dstElemType);
316 if (dstInt.getWidth() < srcInt.getWidth())
317 return arith::TruncIOp::create(builder, loc, dstType, src);
319 if (castKind == CastKind::Signed)
320 return arith::ExtSIOp::create(builder, loc, dstType, src);
321 return arith::ExtUIOp::create(builder, loc, dstType, src);
325 NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
326 ArrayRef<unsigned>
target)
327 : OpTraitRewritePattern(context), solver(s), targetBitwidths(
target) {}
330 LogicalResult matchAndRewrite(Operation *op,
331 PatternRewriter &rewriter)
const override {
335 SmallVector<ConstantIntRanges> ranges;
346 [=](Type t) { return t == srcType; }))
348 op,
"no operands or operand types don't match result type");
350 for (
unsigned targetBitwidth : targetBitwidths) {
351 CastKind castKind = CastKind::Both;
352 for (
const ConstantIntRanges &range : ranges) {
353 castKind = mergeCastKinds(castKind,
354 checkTruncatability(range, targetBitwidth));
355 if (castKind == CastKind::None)
358 if (castKind == CastKind::None)
360 Type targetType = getTargetType(srcType, targetBitwidth);
361 if (targetType == srcType)
364 Location loc = op->
getLoc();
366 for (
auto [arg, argRange] : llvm::zip_first(op->
getOperands(), ranges)) {
367 CastKind argCastKind = castKind;
370 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
371 argCastKind = CastKind::Both;
372 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
373 mapping.
map(arg, newArg);
376 Operation *newOp = rewriter.
clone(*op, mapping);
379 res.setType(targetType);
382 SmallVector<Value> newResults;
383 for (
auto [newRes, oldRes] :
385 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
387 newResults.push_back(castBack);
397 DataFlowSolver &solver;
398 SmallVector<unsigned, 4> targetBitwidths;
402 NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned>
target)
403 : OpRewritePattern(context), solver(s), targetBitwidths(
target) {}
405 LogicalResult matchAndRewrite(arith::CmpIOp op,
406 PatternRewriter &rewriter)
const override {
407 Value
lhs = op.getLhs();
408 Value
rhs = op.getRhs();
410 SmallVector<ConstantIntRanges> ranges;
411 if (
failed(collectRanges(solver, op.getOperands(), ranges)))
413 const ConstantIntRanges &lhsRange = ranges[0];
414 const ConstantIntRanges &rhsRange = ranges[1];
416 Type srcType =
lhs.getType();
417 for (
unsigned targetBitwidth : targetBitwidths) {
418 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
419 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
420 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
422 if (castKind == CastKind::None)
425 Type targetType = getTargetType(srcType, targetBitwidth);
426 if (targetType == srcType)
429 Location loc = op->getLoc();
431 Value lhsCast = doCast(rewriter, loc,
lhs, targetType, lhsCastKind);
432 Value rhsCast = doCast(rewriter, loc,
rhs, targetType, rhsCastKind);
433 mapping.
map(
lhs, lhsCast);
434 mapping.
map(
rhs, rhsCast);
436 Operation *newOp = rewriter.
clone(*op, mapping);
445 DataFlowSolver &solver;
446 SmallVector<unsigned, 4> targetBitwidths;
452template <
typename CastOp>
454 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned>
target)
455 : OpRewritePattern<CastOp>(context), targetBitwidths(
target) {}
457 LogicalResult matchAndRewrite(CastOp op,
458 PatternRewriter &rewriter)
const override {
459 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
463 Value src = srcOp.getIn();
464 if (src.
getType() != op.getType())
467 if (!srcOp.getType().isIndex())
470 auto intType = dyn_cast<IntegerType>(op.getType());
471 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
479 SmallVector<unsigned, 4> targetBitwidths;
482struct IntRangeOptimizationsPass final
483 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
485 void runOnOperation()
override {
486 Operation *op = getOperation();
488 DataFlowSolver solver;
490 solver.
load<IntegerRangeAnalysis>();
492 return signalPassFailure();
494 DataFlowListener listener(solver);
501 GreedyRewriteConfig().setListener(&listener))))
506struct IntRangeNarrowingPass final
507 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
508 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
510 void runOnOperation()
override {
511 Operation *op = getOperation();
513 DataFlowSolver solver;
515 solver.
load<IntegerRangeAnalysis>();
517 return signalPassFailure();
519 DataFlowListener listener(solver);
528 GreedyRewriteConfig().setUseTopDownTraversal(
false).setListener(
537 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538 DeleteTrivialRem<RemUIOp>>(
patterns.getContext(), solver);
546 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547 FoldIndexCastChain<arith::IndexCastOp>>(
patterns.getContext(),
552 return std::make_unique<IntRangeOptimizationsPass>();
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This is a value defined by a result of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which does optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...