11#include "llvm/ADT/TypeSwitch.h"
32#define GEN_PASS_DEF_ARITHINTRANGEOPTS
33#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
35#define GEN_PASS_DEF_ARITHINTRANGENARROWING
36#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
45 auto *maybeInferredRange =
47 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
50 maybeInferredRange->getValue().getValue();
70 if (!maybeConstValue.has_value())
81 maybeDefiningOp ? maybeDefiningOp->
getDialect()
85 if (
auto shaped = dyn_cast<ShapedType>(type)) {
116 void notifyOperationErased(Operation *op)
override {
130 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
131 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
135 LogicalResult matchAndRewrite(Operation *op,
136 PatternRewriter &rewriter)
const override {
145 auto needsReplacing = [&](Value v) {
149 bool hasConstantResults = llvm::any_of(op->
getResults(), needsReplacing);
151 if (!hasConstantResults)
153 bool hasConstantRegionArgs =
false;
155 for (
Block &block : region.getBlocks()) {
156 hasConstantRegionArgs |=
157 llvm::any_of(block.getArguments(), needsReplacing);
160 if (!hasConstantResults && !hasConstantRegionArgs)
173 PatternRewriter::InsertionGuard guard(rewriter);
175 for (
Block &block : region.getBlocks()) {
177 for (BlockArgument &arg : block.getArguments()) {
187 DataFlowSolver &solver;
190template <
typename RemOp>
192 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
193 : OpRewritePattern<RemOp>(context), solver(s) {}
195 LogicalResult matchAndRewrite(RemOp op,
196 PatternRewriter &rewriter)
const override {
197 Value
lhs = op.getOperand(0);
198 Value
rhs = op.getOperand(1);
200 if (!maybeModulus.has_value())
202 int64_t modulus = *maybeModulus;
205 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(
lhs);
206 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
208 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
209 const APInt &
min = isa<RemUIOp>(op) ? lhsRange.
umin() : lhsRange.
smin();
210 const APInt &
max = isa<RemUIOp>(op) ? lhsRange.
umax() : lhsRange.
smax();
213 if (
min.isNegative() ||
min.uge(modulus))
215 if (
max.isNegative() ||
max.uge(modulus))
228 DataFlowSolver &solver;
235 for (
Value val : values) {
236 auto *maybeInferredRange =
238 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
242 maybeInferredRange->getValue().getValue();
243 ranges.push_back(inferredRange);
250static Type getTargetType(
Type srcType,
unsigned targetBitwidth) {
251 auto dstType = IntegerType::get(srcType.
getContext(), targetBitwidth);
252 if (
auto shaped = dyn_cast<ShapedType>(srcType))
253 return shaped.clone(dstType);
272 unsigned targetWidth) {
273 unsigned srcWidth = range.
smin().getBitWidth();
274 if (srcWidth <= targetWidth)
275 return CastKind::None;
276 unsigned removedWidth = srcWidth - targetWidth;
280 bool canTruncateSigned =
281 range.
smin().getNumSignBits() >= (removedWidth + 1) &&
282 range.
smax().getNumSignBits() >= (removedWidth + 1);
283 bool canTruncateUnsigned = range.
umin().countLeadingZeros() >= removedWidth &&
284 range.
umax().countLeadingZeros() >= removedWidth;
285 if (canTruncateSigned && canTruncateUnsigned)
286 return CastKind::Both;
287 if (canTruncateSigned)
288 return CastKind::Signed;
289 if (canTruncateUnsigned)
290 return CastKind::Unsigned;
291 return CastKind::None;
294static CastKind mergeCastKinds(CastKind
lhs, CastKind
rhs) {
295 if (
lhs == CastKind::None ||
rhs == CastKind::None)
296 return CastKind::None;
297 if (
lhs == CastKind::Both)
299 if (
rhs == CastKind::Both)
303 return CastKind::None;
309 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
310 "Mixing vector and non-vector types");
311 assert(castKind != CastKind::None &&
"Can't cast when casting isn't allowed");
314 assert(srcElemType.
isIntOrIndex() &&
"Invalid src type");
315 assert(dstElemType.
isIntOrIndex() &&
"Invalid dst type");
316 if (srcType == dstType)
319 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
320 if (castKind == CastKind::Signed)
321 return arith::IndexCastOp::create(builder, loc, dstType, src);
322 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
325 auto srcInt = cast<IntegerType>(srcElemType);
326 auto dstInt = cast<IntegerType>(dstElemType);
327 if (dstInt.getWidth() < srcInt.getWidth())
328 return arith::TruncIOp::create(builder, loc, dstType, src);
330 if (castKind == CastKind::Signed)
331 return arith::ExtSIOp::create(builder, loc, dstType, src);
332 return arith::ExtUIOp::create(builder, loc, dstType, src);
349 if (failed(collectRanges(solver, op->
getOperands(), ranges)))
351 if (failed(collectRanges(solver, op->
getResults(), ranges)))
359 [=](
Type t) { return t == srcType; }))
361 op,
"no operands or operand types don't match result type");
363 for (
unsigned targetBitwidth : targetBitwidths) {
364 CastKind castKind = CastKind::Both;
366 castKind = mergeCastKinds(castKind,
367 checkTruncatability(range, targetBitwidth));
368 if (castKind == CastKind::None)
376 .Case<arith::DivSIOp, arith::CeilDivSIOp, arith::FloorDivSIOp,
377 arith::RemSIOp, arith::MaxSIOp, arith::MinSIOp,
378 arith::ShRSIOp>([](
auto) {
return CastKind::Signed; })
379 .Default(CastKind::Both);
380 castKind = mergeCastKinds(castKind, castKindForOp);
381 if (castKind == CastKind::None)
383 Type targetType = getTargetType(srcType, targetBitwidth);
384 if (targetType == srcType)
389 for (
auto [arg, argRange] : llvm::zip_first(op->
getOperands(), ranges)) {
390 CastKind argCastKind = castKind;
393 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
394 argCastKind = CastKind::Both;
395 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
396 mapping.
map(arg, newArg);
402 res.setType(targetType);
406 for (
auto [newRes, oldRes] :
408 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
410 newResults.push_back(castBack);
420 DataFlowSolver &solver;
421 SmallVector<unsigned, 4> targetBitwidths;
428 LogicalResult matchAndRewrite(arith::CmpIOp op,
434 if (failed(collectRanges(solver, op.getOperands(), ranges)))
439 auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) ->
bool {
440 return pred == arith::CmpIPredicate::sge ||
441 pred == arith::CmpIPredicate::sgt ||
442 pred == arith::CmpIPredicate::sle ||
443 pred == arith::CmpIPredicate::slt;
447 CastKind predicateBasedCastRestriction =
448 isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed
452 for (
unsigned targetBitwidth : targetBitwidths) {
453 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
454 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
455 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
456 castKind = mergeCastKinds(castKind, predicateBasedCastRestriction);
459 if (castKind == CastKind::None)
462 Type targetType = getTargetType(srcType, targetBitwidth);
463 if (targetType == srcType)
468 Value lhsCast = doCast(rewriter, loc,
lhs, targetType, lhsCastKind);
469 Value rhsCast = doCast(rewriter, loc,
rhs, targetType, rhsCastKind);
471 mapping.
map(
rhs, rhsCast);
489template <
typename CastOp>
491 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned>
target)
492 : OpRewritePattern<CastOp>(context), targetBitwidths(
target) {}
494 LogicalResult matchAndRewrite(CastOp op,
495 PatternRewriter &rewriter)
const override {
496 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
500 Value src = srcOp.getIn();
501 if (src.
getType() != op.getType())
504 if (!srcOp.getType().isIndex())
507 auto intType = dyn_cast<IntegerType>(op.getType());
508 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
516 SmallVector<unsigned, 4> targetBitwidths;
520 NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
521 ArrayRef<unsigned>
target)
522 : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
524 boundsNarrowingFailedAttr(
525 StringAttr::
get(context,
"arith.bounds_narrowing_failed")) {}
527 LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
528 PatternRewriter &rewriter)
const override {
530 if (loopLike->hasAttr(boundsNarrowingFailedAttr))
532 "bounds narrowing previously failed");
534 std::optional<SmallVector<Value>> inductionVars =
535 loopLike.getLoopInductionVars();
536 if (!inductionVars.has_value() || inductionVars->empty())
539 std::optional<SmallVector<OpFoldResult>> lowerBounds =
540 loopLike.getLoopLowerBounds();
541 std::optional<SmallVector<OpFoldResult>> upperBounds =
542 loopLike.getLoopUpperBounds();
543 std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps();
545 if (!lowerBounds.has_value() || !upperBounds.has_value() ||
549 if (lowerBounds->size() != inductionVars->size() ||
550 upperBounds->size() != inductionVars->size() ||
551 steps->size() != inductionVars->size())
553 "mismatched bounds/steps count");
555 Location loc = loopLike->getLoc();
556 SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
557 SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
558 SmallVector<OpFoldResult> newSteps(*steps);
559 SmallVector<std::tuple<size_t, Type, CastKind>> narrowings;
562 for (
auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
563 llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
566 auto maybeLb = dyn_cast<Value>(lbOFR);
567 auto maybeUb = dyn_cast<Value>(ubOFR);
568 auto maybeStep = dyn_cast<Value>(stepOFR);
570 if (!maybeLb || !maybeUb || !maybeStep)
574 SmallVector<ConstantIntRanges> ranges;
576 solver,
ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
579 const ConstantIntRanges &stepRange = ranges[2];
580 const ConstantIntRanges &indVarRange = ranges[3];
582 Type srcType = maybeLb.getType();
585 for (
unsigned targetBitwidth : targetBitwidths) {
586 Type targetType = getTargetType(srcType, targetBitwidth);
587 if (targetType == srcType)
592 if (!loopLike.isValidInductionVarType(targetType))
596 CastKind castKind = CastKind::Both;
597 for (
const ConstantIntRanges &range : ranges) {
598 castKind = mergeCastKinds(castKind,
599 checkTruncatability(range, targetBitwidth));
600 if (castKind == CastKind::None)
604 if (castKind == CastKind::None)
614 ConstantIntRanges indVarPlusStepRange(
615 indVarRange.
smin().sadd_sat(stepRange.
smin()),
616 indVarRange.
smax().sadd_sat(stepRange.
smax()),
617 indVarRange.
umin().uadd_sat(stepRange.
umin()),
618 indVarRange.
umax().uadd_sat(stepRange.
umax()));
620 if (checkTruncatability(indVarPlusStepRange, targetBitwidth) !=
625 Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
626 Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
627 Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
629 newLowerBounds[idx] = newLb;
630 newUpperBounds[idx] = newUb;
631 newSteps[idx] = newStep;
632 narrowings.push_back({idx, targetType, castKind});
637 if (narrowings.empty())
641 SmallVector<Type> origTypes;
642 for (
auto [idx, targetType, castKind] : narrowings) {
643 Value indVar = (*inductionVars)[idx];
644 origTypes.push_back(indVar.
getType());
649 bool updateFailed =
false;
652 if (
failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
653 failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
654 failed(loopLike.setLoopSteps(newSteps))) {
657 loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.
getUnitAttr());
663 for (
auto [idx, targetType, castKind] : narrowings) {
664 Value indVar = (*inductionVars)[idx];
665 auto blockArg = cast<BlockArgument>(indVar);
668 blockArg.setType(targetType);
676 for (
auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
677 auto [idx, targetType, castKind] = narrowingInfo;
678 Value indVar = (*inductionVars)[idx];
679 auto blockArg = cast<BlockArgument>(indVar);
680 Type origType = origTypes[narrowingIdx];
682 OpBuilder::InsertionGuard guard(rewriter);
684 Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
695 DataFlowSolver &solver;
696 SmallVector<unsigned, 4> targetBitwidths;
697 StringAttr boundsNarrowingFailedAttr;
700struct IntRangeOptimizationsPass final
703 void runOnOperation()
override {
704 Operation *op = getOperation();
706 DataFlowSolver solver;
708 solver.
load<IntegerRangeAnalysis>();
710 return signalPassFailure();
712 DataFlowListener listener(solver);
714 RewritePatternSet patterns(ctx);
725 GreedyRewriteConfig()
726 .enableFolding(
false)
727 .setRegionSimplificationLevel(
728 GreedySimplifyRegionLevel::Disabled)
729 .setListener(&listener))))
734struct IntRangeNarrowingPass final
736 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
738 void runOnOperation()
override {
739 Operation *op = getOperation();
741 DataFlowSolver solver;
743 solver.
load<IntegerRangeAnalysis>();
745 return signalPassFailure();
747 DataFlowListener listener(solver);
749 RewritePatternSet patterns(ctx);
757 op, std::move(patterns),
758 GreedyRewriteConfig().setUseTopDownTraversal(
false).setListener(
767 patterns.
add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
768 DeleteTrivialRem<RemUIOp>>(patterns.
getContext(), solver);
774 patterns.
add<NarrowElementwise, NarrowCmpI>(patterns.
getContext(), solver,
776 patterns.
add<FoldIndexCastChain<arith::IndexCastUIOp>,
777 FoldIndexCastChain<arith::IndexCastOp>>(patterns.
getContext(),
784 patterns.
add<NarrowLoopBounds>(patterns.
getContext(), solver,
789 return std::make_unique<IntRangeOptimizationsPass>();
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
LogicalResult initializeAndRun(Operation *top, llvm::function_ref< bool(DataFlowAnalysis &)> analysisFilter=nullptr)
Initialize analyses starting from the provided top-level operation and run the analysis until fixpoin...
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This is a value defined by a result of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in 'rhs' into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which does optimizations based on integer range analysis.
void populateControlFlowValuesNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for narrowing control flow values (loop bounds, steps, etc.) based on int range analysis...
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...