MLIR 23.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "llvm/ADT/TypeSwitch.h"
12
17
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/Matchers.h"
30
31namespace mlir::arith {
32#define GEN_PASS_DEF_ARITHINTRANGEOPTS
33#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
34
35#define GEN_PASS_DEF_ARITHINTRANGENARROWING
36#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
37} // namespace mlir::arith
38
39using namespace mlir;
40using namespace mlir::arith;
41using namespace mlir::dataflow;
42
/// Returns the single integer that the integer range analysis in `solver` has
/// inferred for `value`, or std::nullopt when `value` has no initialized range
/// or its range does not pin down one constant (as decided by
/// ConstantIntRanges::getConstantValue()).
43static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
44 Value value) {
45 auto *maybeInferredRange =
// NOTE(review): source line 46, which completes this initializer (presumably
// a lattice lookup for `value` in `solver`), is missing from this listing.
47 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
48 return std::nullopt;
49 const ConstantIntRanges &inferredRange =
50 maybeInferredRange->getValue().getValue();
51 return inferredRange.getConstantValue();
52}
53
/// Propagates the integer range recorded in `solver` for `oldVal` to `newVal`.
/// Callers use this so that a replacement value keeps the analysis result of
/// the value it replaces. Does nothing when `oldVal` has no recorded state.
54static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
55 Value newVal) {
56 auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
57 if (!oldState)
58 return;
// NOTE(review): source line 59, which starts this statement, is missing from
// this listing; presumably it merges `*oldState` into the lattice state of
// `newVal` -- confirm against the full file.
60 *oldState);
61}
62
63namespace mlir::dataflow {
/// Replaces every use of `value` with a newly materialized constant when the
/// integer range analysis in `solver` has pinned `value` to a single integer.
///
/// Returns failure, without creating anything, when `value` has no uses, no
/// constant is known for it, its (element) type is not integer or index, or
/// no dialect can materialize the constant. The attribute is a
/// DenseIntElementsAttr for shaped types and an IntegerAttr otherwise. The
/// constant op is created at `rewriter`'s current insertion point, first via
/// `valueDialect` and then, as a fallback, via the arith dialect. Its range is
/// copied from `value` before the uses are replaced.
64/// Patterned after SCCP
// NOTE(review): source line 65, which starts this function's signature, is
// missing from this listing.
66 RewriterBase &rewriter, Value value) {
67 if (value.use_empty())
68 return failure();
69 std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
70 if (!maybeConstValue.has_value())
71 return failure();
72
73 Type type = value.getType();
74 // If the type or element type is non-integral, the attribute constructor
75 // will crash, so eagerly check for an integer type to avoid this.
76 if (!getElementTypeOrSelf(type).isIntOrIndex())
77 return failure();
78 Location loc = value.getLoc();
79 Operation *maybeDefiningOp = value.getDefiningOp();
// For op results, use the defining op's dialect. The alternative used for
// block arguments is on source line 82, which is missing from this listing.
// NOTE(review): if that alternative can yield a null dialect (e.g. an
// unregistered parent op), the `materializeConstant` call below dereferences
// null -- confirm.
80 Dialect *valueDialect =
81 maybeDefiningOp ? maybeDefiningOp->getDialect()
83
84 Attribute constAttr;
85 if (auto shaped = dyn_cast<ShapedType>(type)) {
86 constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
87 } else {
88 constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
89 }
90 Operation *constOp =
91 valueDialect->materializeConstant(rewriter, constAttr, type, loc);
92 // Fall back to arith.constant if the dialect materializer doesn't know what
93 // to do with an integer constant.
94 if (!constOp)
95 constOp = rewriter.getContext()
96 ->getLoadedDialect<ArithDialect>()
97 ->materializeConstant(rewriter, constAttr, type, loc);
98 if (!constOp)
99 return failure();
100
101 OpResult res = constOp->getResult(0);
// Drop whatever state the solver already holds for the constant's result
// before seeding it from `value`. NOTE(review): source line 102 (presumably a
// guard on this erase) is missing from this listing.
103 solver.eraseState(res);
104 copyIntegerRange(solver, value, res);
105 rewriter.replaceAllUsesWith(value, res);
106 return success();
107}
108} // namespace mlir::dataflow
109
110namespace {
111class DataFlowListener : public RewriterBase::Listener {
112public:
113 DataFlowListener(DataFlowSolver &s) : s(s) {}
114
115protected:
116 void notifyOperationErased(Operation *op) override {
117 s.eraseState(s.getProgramPointAfter(op));
118 for (Value res : op->getResults())
119 s.eraseState(res);
120 }
121
122 DataFlowSolver &s;
123};
124
125/// Rewrite any results of `op` that were inferred to be constant integers to
126/// and replace their uses with that constant. Return success() if all results
127/// where thus replaced and the operation is erased. Also replace any block
128/// arguments with their constant values.
129struct MaterializeKnownConstantValues : public RewritePattern {
130 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
131 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
132 /*benefit=*/1, context),
133 solver(s) {}
134
135 LogicalResult matchAndRewrite(Operation *op,
136 PatternRewriter &rewriter) const override {
137 if (matchPattern(op, m_Constant()))
138 return failure();
139
140 // We need to check isIntOrIndex() here as well to avoid infinite loops in
141 // the greedy pattern rewriter. If we only check it in
142 // maybeReplaceWithConstant, this lambda might still return true for
143 // non-integral types, causing the pattern to match and claim success
144 // without making any changes, leading to non-convergence.
145 auto needsReplacing = [&](Value v) {
146 return getElementTypeOrSelf(v.getType()).isIntOrIndex() &&
147 getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
148 };
149 bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
150 if (op->getNumRegions() == 0)
151 if (!hasConstantResults)
152 return failure();
153 bool hasConstantRegionArgs = false;
154 for (Region &region : op->getRegions()) {
155 for (Block &block : region.getBlocks()) {
156 hasConstantRegionArgs |=
157 llvm::any_of(block.getArguments(), needsReplacing);
158 }
159 }
160 if (!hasConstantResults && !hasConstantRegionArgs)
161 return failure();
162
163 bool replacedAll = (op->getNumResults() != 0);
164 for (Value v : op->getResults())
165 replacedAll &=
166 (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
167 v.use_empty());
168 if (replacedAll && isOpTriviallyDead(op)) {
169 rewriter.eraseOp(op);
170 return success();
171 }
172
173 PatternRewriter::InsertionGuard guard(rewriter);
174 for (Region &region : op->getRegions()) {
175 for (Block &block : region.getBlocks()) {
176 rewriter.setInsertionPointToStart(&block);
177 for (BlockArgument &arg : block.getArguments()) {
178 (void)maybeReplaceWithConstant(solver, rewriter, arg);
179 }
180 }
181 }
182
183 return success();
184 }
185
186private:
187 DataFlowSolver &solver;
188};
189
190template <typename RemOp>
191struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
192 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
193 : OpRewritePattern<RemOp>(context), solver(s) {}
194
195 LogicalResult matchAndRewrite(RemOp op,
196 PatternRewriter &rewriter) const override {
197 Value lhs = op.getOperand(0);
198 Value rhs = op.getOperand(1);
199 auto maybeModulus = getConstantIntValue(rhs);
200 if (!maybeModulus.has_value())
201 return failure();
202 int64_t modulus = *maybeModulus;
203 if (modulus <= 0)
204 return failure();
205 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
206 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
207 return failure();
208 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
209 const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
210 const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
211 // The minima and maxima here are given as closed ranges, we must be
212 // strictly less than the modulus.
213 if (min.isNegative() || min.uge(modulus))
214 return failure();
215 if (max.isNegative() || max.uge(modulus))
216 return failure();
217 if (!min.ule(max))
218 return failure();
219
220 // With all those conditions out of the way, we know thas this invocation of
221 // a remainder is a noop because the input is strictly within the range
222 // [0, modulus), so get rid of it.
223 rewriter.replaceOp(op, ValueRange{lhs});
224 return success();
225 }
226
227private:
228 DataFlowSolver &solver;
229};
230
231/// Gather ranges for all the values in `values`. Appends to the existing
232/// vector.
/// Fails as soon as one value has no initialized range in `solver`. In that
/// case the ranges already gathered for earlier values remain appended.
233static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
// NOTE(review): source line 234 (the remainder of the signature, which
// declares the output vector `ranges`) is missing from this listing.
235 for (Value val : values) {
236 auto *maybeInferredRange =
// NOTE(review): source line 237, which completes this initializer (presumably
// a lattice lookup for `val`), is missing from this listing.
238 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
239 return failure();
240
241 const ConstantIntRanges &inferredRange =
242 maybeInferredRange->getValue().getValue();
243 ranges.push_back(inferredRange);
244 }
245 return success();
246}
247
248/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
249/// return shaped type as well.
250static Type getTargetType(Type srcType, unsigned targetBitwidth) {
251 auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
252 if (auto shaped = dyn_cast<ShapedType>(srcType))
253 return shaped.clone(dstType);
254
255 assert(srcType.isIntOrIndex() && "Invalid src type");
256 return dstType;
257}
258
namespace {
/// Which flavour of truncation (and matching re-extension) may be used to
/// narrow an operation without changing the values it works on, if any.
enum class CastKind : uint8_t {
  None,     ///< No narrowing is possible.
  Signed,   ///< Values survive narrowing under the signed interpretation.
  Unsigned, ///< Values survive narrowing under the unsigned interpretation.
  Both,     ///< Values survive narrowing under either interpretation.
};
} // namespace
264
265/// If the values within `range` can be represented using only `width` bits,
266/// return the kind of truncation needed to preserve that property.
267///
268/// This check relies on the fact that the signed and unsigned ranges are both
269/// always correct, but that one might be an approximation of the other,
270/// so we want to use the correct truncation operation.
271static CastKind checkTruncatability(const ConstantIntRanges &range,
272 unsigned targetWidth) {
273 unsigned srcWidth = range.smin().getBitWidth();
274 if (srcWidth <= targetWidth)
275 return CastKind::None;
276 unsigned removedWidth = srcWidth - targetWidth;
277 // The sign bits need to extend into the sign bit of the target width. For
278 // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
279 // bits.
280 bool canTruncateSigned =
281 range.smin().getNumSignBits() >= (removedWidth + 1) &&
282 range.smax().getNumSignBits() >= (removedWidth + 1);
283 bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
284 range.umax().countLeadingZeros() >= removedWidth;
285 if (canTruncateSigned && canTruncateUnsigned)
286 return CastKind::Both;
287 if (canTruncateSigned)
288 return CastKind::Signed;
289 if (canTruncateUnsigned)
290 return CastKind::Unsigned;
291 return CastKind::None;
292}
293
294static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
295 if (lhs == CastKind::None || rhs == CastKind::None)
296 return CastKind::None;
297 if (lhs == CastKind::Both)
298 return rhs;
299 if (rhs == CastKind::Both)
300 return lhs;
301 if (lhs == rhs)
302 return lhs;
303 return CastKind::None;
304}
305
306static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
307 CastKind castKind) {
308 Type srcType = src.getType();
309 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
310 "Mixing vector and non-vector types");
311 assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
312 Type srcElemType = getElementTypeOrSelf(srcType);
313 Type dstElemType = getElementTypeOrSelf(dstType);
314 assert(srcElemType.isIntOrIndex() && "Invalid src type");
315 assert(dstElemType.isIntOrIndex() && "Invalid dst type");
316 if (srcType == dstType)
317 return src;
318
319 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
320 if (castKind == CastKind::Signed)
321 return arith::IndexCastOp::create(builder, loc, dstType, src);
322 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
323 }
324
325 auto srcInt = cast<IntegerType>(srcElemType);
326 auto dstInt = cast<IntegerType>(dstElemType);
327 if (dstInt.getWidth() < srcInt.getWidth())
328 return arith::TruncIOp::create(builder, loc, dstType, src);
329
330 if (castKind == CastKind::Signed)
331 return arith::ExtSIOp::create(builder, loc, dstType, src);
332 return arith::ExtUIOp::create(builder, loc, dstType, src);
333}
334
/// Narrows an op with the `Elementwise` trait to the first bitwidth in
/// `targetBitwidths` at which the integer range analysis in `solver` shows
/// that every operand and every result value fits.
///
/// The op must have at least one result and one operand. All results must
/// share one type, and every operand must have that same type. The operands
/// are cast to the narrow type, the op is cloned with its result types set to
/// the narrow type, and the cloned results are cast back to the original type
/// to replace the op. Arith ops with signed semantics (divsi, ceildivsi,
/// floordivsi, remsi, maxsi, minsi, shrsi) may only be narrowed with signed
/// casts.
335struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
336 NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
// NOTE(review): source line 337 (the rest of the parameter list, declaring
// `target`) is missing from this listing.
338 : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
341 LogicalResult matchAndRewrite(Operation *op,
342 PatternRewriter &rewriter) const override {
343 if (op->getNumResults() == 0)
344 return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
345
346 // Inline size chosen empirically based on compilation profiling.
347 // Profiled: 2.6M calls, avg=1.7+-1.3. N=4 covers >95% of cases inline.
// NOTE(review): source line 348, which declares `ranges`, is missing from
// this listing.
// `ranges` holds the operand ranges first, then the result ranges.
349 if (failed(collectRanges(solver, op->getOperands(), ranges)))
350 return rewriter.notifyMatchFailure(op, "input without specified range");
351 if (failed(collectRanges(solver, op->getResults(), ranges)))
352 return rewriter.notifyMatchFailure(op, "output without specified range");
353
354 Type srcType = op->getResult(0).getType();
355 if (!llvm::all_equal(op->getResultTypes()))
356 return rewriter.notifyMatchFailure(op, "mismatched result types");
357 if (op->getNumOperands() == 0 ||
358 !llvm::all_of(op->getOperandTypes(),
359 [=](Type t) { return t == srcType; }))
360 return rewriter.notifyMatchFailure(
361 op, "no operands or operand types don't match result type");
362
363 for (unsigned targetBitwidth : targetBitwidths) {
// A single cast kind must fit every operand and result range at this width.
364 CastKind castKind = CastKind::Both;
365 for (const ConstantIntRanges &range : ranges) {
366 castKind = mergeCastKinds(castKind,
367 checkTruncatability(range, targetBitwidth));
368 if (castKind == CastKind::None)
369 break;
370 }
371 // For operations that explicitly treat the values as signed, we should
372 // only do signed casts, if those are deemed possible as such based on the
373 // value range.
// NOTE(review): source line 375, which starts the type switch on `op`
// (presumably llvm::TypeSwitch), is missing from this listing.
374 auto castKindForOp =
376 .Case<arith::DivSIOp, arith::CeilDivSIOp, arith::FloorDivSIOp,
377 arith::RemSIOp, arith::MaxSIOp, arith::MinSIOp,
378 arith::ShRSIOp>([](auto) { return CastKind::Signed; })
379 .Default(CastKind::Both);
380 castKind = mergeCastKinds(castKind, castKindForOp);
381 if (castKind == CastKind::None)
382 continue;
383 Type targetType = getTargetType(srcType, targetBitwidth);
384 if (targetType == srcType)
385 continue;
386
387 Location loc = op->getLoc();
388 IRMapping mapping;
// zip_first stops after the operands, pairing each operand with its own
// range (operand ranges were appended to `ranges` first).
389 for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
390 CastKind argCastKind = castKind;
391 // When dealing with `index` values, preserve non-negativity in the
392 // index_casts since we can't recover this in unsigned when equivalent.
393 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
394 argCastKind = CastKind::Both;
395 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
396 mapping.map(arg, newArg);
// NOTE(review): source line 397 (presumably the `}` closing this loop) is
// missing from this listing.
398
399 Operation *newOp = rewriter.clone(*op, mapping);
400 rewriter.modifyOpInPlace(newOp, [&]() {
401 for (OpResult res : newOp->getResults()) {
402 res.setType(targetType);
403 }
404 });
// Cast results back with the op-wide kind and carry the old ranges over to
// the replacement values.
405 SmallVector<Value> newResults;
406 for (auto [newRes, oldRes] :
407 llvm::zip_equal(newOp->getResults(), op->getResults())) {
408 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
409 copyIntegerRange(solver, oldRes, castBack);
410 newResults.push_back(castBack);
411 }
412
413 rewriter.replaceOp(op, newResults);
414 return success();
415 }
416 return failure();
417 }
418
419private:
420 DataFlowSolver &solver;
421 SmallVector<unsigned, 4> targetBitwidths;
// NOTE(review): source line 422 (presumably the `};` closing this struct) is
// missing from this listing.
423
/// Narrows the operands of an `arith.cmpi` to the first bitwidth in
/// `targetBitwidths` at which the integer range analysis in `solver` shows
/// both operands fit. The comparison is cloned on the narrowed operands and
/// replaces the original. Signed predicates (sge, sgt, sle, slt) only allow
/// narrowing when both operands fit under the signed interpretation.
424struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
425 NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
426 : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
428 LogicalResult matchAndRewrite(arith::CmpIOp op,
429 PatternRewriter &rewriter) const override {
430 Value lhs = op.getLhs();
431 Value rhs = op.getRhs();
432
// NOTE(review): source line 433, which declares `ranges`, is missing from
// this listing.
434 if (failed(collectRanges(solver, op.getOperands(), ranges)))
435 return failure();
436 const ConstantIntRanges &lhsRange = ranges[0];
437 const ConstantIntRanges &rhsRange = ranges[1];
438
439 auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) -> bool {
440 return pred == arith::CmpIPredicate::sge ||
441 pred == arith::CmpIPredicate::sgt ||
442 pred == arith::CmpIPredicate::sle ||
443 pred == arith::CmpIPredicate::slt;
444 };
445 // If we're to narrow the input values via a cast, we should preserve the
446 // sign.
447 CastKind predicateBasedCastRestriction =
448 isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed
449 : CastKind::Both;
450
451 Type srcType = lhs.getType();
452 for (unsigned targetBitwidth : targetBitwidths) {
453 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
454 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
455 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
456 castKind = mergeCastKinds(castKind, predicateBasedCastRestriction);
457 // Note: this includes target width >= src width, as well as the unsigned
458 // truncatability & signed predicate scenario.
459 if (castKind == CastKind::None)
460 continue;
461
462 Type targetType = getTargetType(srcType, targetBitwidth);
463 if (targetType == srcType)
464 continue;
465
466 Location loc = op->getLoc();
467 IRMapping mapping;
// Each operand is cast with its own kind rather than the merged `castKind`.
// For integer sources a narrowing doCast always emits arith.trunci, so the
// kind does not matter there. NOTE(review): for `index` sources the kind
// selects index_cast vs index_castui -- confirm both truncate identically.
468 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
469 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
470 mapping.map(lhs, lhsCast);
471 mapping.map(rhs, rhsCast);
472
473 Operation *newOp = rewriter.clone(*op, mapping);
474 copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
475 rewriter.replaceOp(op, newOp->getResults());
476 return success();
477 }
478 return failure();
479 }
480
481private:
482 DataFlowSolver &solver;
483 SmallVector<unsigned, 4> targetBitwidths;
484};
485
486/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
487/// This pattern assumes all passed `targetBitwidths` are not wider than index
488/// type.
489template <typename CastOp>
490struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
491 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
492 : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
493
494 LogicalResult matchAndRewrite(CastOp op,
495 PatternRewriter &rewriter) const override {
496 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
497 if (!srcOp)
498 return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
499
500 Value src = srcOp.getIn();
501 if (src.getType() != op.getType())
502 return rewriter.notifyMatchFailure(op, "outer types don't match");
503
504 if (!srcOp.getType().isIndex())
505 return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
506
507 auto intType = dyn_cast<IntegerType>(op.getType());
508 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
509 return failure();
510
511 rewriter.replaceOp(op, src);
512 return success();
513 }
514
515private:
516 SmallVector<unsigned, 4> targetBitwidths;
517};
518
519struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
520 NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
521 ArrayRef<unsigned> target)
522 : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
523 targetBitwidths(target),
524 boundsNarrowingFailedAttr(
525 StringAttr::get(context, "arith.bounds_narrowing_failed")) {}
526
527 LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
528 PatternRewriter &rewriter) const override {
529 // Skip ops where bounds narrowing previously failed.
530 if (loopLike->hasAttr(boundsNarrowingFailedAttr))
531 return rewriter.notifyMatchFailure(loopLike,
532 "bounds narrowing previously failed");
533
534 std::optional<SmallVector<Value>> inductionVars =
535 loopLike.getLoopInductionVars();
536 if (!inductionVars.has_value() || inductionVars->empty())
537 return rewriter.notifyMatchFailure(loopLike, "no induction variables");
538
539 std::optional<SmallVector<OpFoldResult>> lowerBounds =
540 loopLike.getLoopLowerBounds();
541 std::optional<SmallVector<OpFoldResult>> upperBounds =
542 loopLike.getLoopUpperBounds();
543 std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps();
544
545 if (!lowerBounds.has_value() || !upperBounds.has_value() ||
546 !steps.has_value())
547 return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps");
548
549 if (lowerBounds->size() != inductionVars->size() ||
550 upperBounds->size() != inductionVars->size() ||
551 steps->size() != inductionVars->size())
552 return rewriter.notifyMatchFailure(loopLike,
553 "mismatched bounds/steps count");
554
555 Location loc = loopLike->getLoc();
556 SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
557 SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
558 SmallVector<OpFoldResult> newSteps(*steps);
559 SmallVector<std::tuple<size_t, Type, CastKind>> narrowings;
560
561 // Check each (indVar, lb, ub, step) tuple.
562 for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
563 llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
564
565 // Only process value operands, skip attributes.
566 auto maybeLb = dyn_cast<Value>(lbOFR);
567 auto maybeUb = dyn_cast<Value>(ubOFR);
568 auto maybeStep = dyn_cast<Value>(stepOFR);
569
570 if (!maybeLb || !maybeUb || !maybeStep)
571 continue;
572
573 // Collect ranges for (lb, ub, step, indVar).
574 SmallVector<ConstantIntRanges> ranges;
575 if (failed(collectRanges(
576 solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
577 continue;
578
579 const ConstantIntRanges &stepRange = ranges[2];
580 const ConstantIntRanges &indVarRange = ranges[3];
581
582 Type srcType = maybeLb.getType();
583
584 // Try each target bitwidth.
585 for (unsigned targetBitwidth : targetBitwidths) {
586 Type targetType = getTargetType(srcType, targetBitwidth);
587 if (targetType == srcType)
588 continue;
589
590 // Check if the target type is valid for this loop's induction
591 // variables.
592 if (!loopLike.isValidInductionVarType(targetType))
593 continue;
594
595 // Check if all values in this tuple can be truncated.
596 CastKind castKind = CastKind::Both;
597 for (const ConstantIntRanges &range : ranges) {
598 castKind = mergeCastKinds(castKind,
599 checkTruncatability(range, targetBitwidth));
600 if (castKind == CastKind::None)
601 break;
602 }
603
604 if (castKind == CastKind::None)
605 continue;
606
607 // Check if indVar + step fits in the narrowed type.
608 // This is critical for loop correctness: the loop computes
609 // iv_next = iv_current + step in the narrowed type, then compares
610 // iv_next < ub. If iv_current + step overflows, the comparison may
611 // produce incorrect results and break loop termination.
612 // Both signed and unsigned interpretations must fit because loop
613 // semantics are unknown (integer types are signless).
614 ConstantIntRanges indVarPlusStepRange(
615 indVarRange.smin().sadd_sat(stepRange.smin()),
616 indVarRange.smax().sadd_sat(stepRange.smax()),
617 indVarRange.umin().uadd_sat(stepRange.umin()),
618 indVarRange.umax().uadd_sat(stepRange.umax()));
619
620 if (checkTruncatability(indVarPlusStepRange, targetBitwidth) !=
621 CastKind::Both)
622 continue;
623
624 // Narrow the bounds and step values.
625 Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
626 Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
627 Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
628
629 newLowerBounds[idx] = newLb;
630 newUpperBounds[idx] = newUb;
631 newSteps[idx] = newStep;
632 narrowings.push_back({idx, targetType, castKind});
633 break;
634 }
635 }
636
637 if (narrowings.empty())
638 return rewriter.notifyMatchFailure(loopLike, "no narrowings found");
639
640 // Save original types before modifying.
641 SmallVector<Type> origTypes;
642 for (auto [idx, targetType, castKind] : narrowings) {
643 Value indVar = (*inductionVars)[idx];
644 origTypes.push_back(indVar.getType());
645 }
646
647 // Attempt to update bounds and induction variable types.
648 // If this fails, mark the op so we don't try again.
649 bool updateFailed = false;
650 rewriter.modifyOpInPlace(loopLike, [&]() {
651 // Update the loop bounds and steps.
652 if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
653 failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
654 failed(loopLike.setLoopSteps(newSteps))) {
655 // Mark op to prevent future attempts. IR was modified (attribute
656 // added), so we must return success() from the pattern.
657 loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr());
658 updateFailed = true;
659 return;
660 }
661
662 // Update induction variable types.
663 for (auto [idx, targetType, castKind] : narrowings) {
664 Value indVar = (*inductionVars)[idx];
665 auto blockArg = cast<BlockArgument>(indVar);
666
667 // Change the block argument type.
668 blockArg.setType(targetType);
669 }
670 });
671
672 if (updateFailed)
673 return success();
674
675 // Insert casts back to original type for uses.
676 for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
677 auto [idx, targetType, castKind] = narrowingInfo;
678 Value indVar = (*inductionVars)[idx];
679 auto blockArg = cast<BlockArgument>(indVar);
680 Type origType = origTypes[narrowingIdx];
681
682 OpBuilder::InsertionGuard guard(rewriter);
683 rewriter.setInsertionPointToStart(blockArg.getOwner());
684 Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
685 copyIntegerRange(solver, blockArg, casted);
686
687 // Replace all uses of the narrowed indVar with the casted value.
688 rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp());
689 }
690
691 return success();
692 }
693
694private:
695 DataFlowSolver &solver;
696 SmallVector<unsigned, 4> targetBitwidths;
697 StringAttr boundsNarrowingFailedAttr;
698};
699
/// Pass that runs the integer range analysis on the current op and then
/// greedily applies `patterns` (populated on a line missing from this
/// listing). A DataFlowListener keeps the solver's state in sync with
/// erasures during the rewrite.
700struct IntRangeOptimizationsPass final
701 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
702
703 void runOnOperation() override {
704 Operation *op = getOperation();
705 MLIRContext *ctx = op->getContext();
706 DataFlowSolver solver;
707 loadBaselineAnalyses(solver);
708 solver.load<IntegerRangeAnalysis>();
// The analysis runs to completion before any rewriting; if it fails, the pass
// fails.
709 if (failed(solver.initializeAndRun(op)))
710 return signalPassFailure();
711
712 DataFlowListener listener(solver);
713
714 RewritePatternSet patterns(ctx);
// NOTE(review): source line 715 (presumably the call that populates
// `patterns`) is missing from this listing.
716
717 // Disable folding and region simplification to avoid breaking the solver
718 // state. Both can remove block arguments (folding via control-flow
719 // simplification, region simplification via dead-arg elimination), which
720 // frees their underlying storage. A subsequent allocation may reuse the
721 // same address for a different block argument, causing stale solver state
722 // to be associated with the new argument and producing incorrect constants.
// A failed greedy rewrite (e.g. non-convergence) is reported as pass failure.
723 if (failed(
724 applyPatternsGreedily(op, std::move(patterns),
725 GreedyRewriteConfig()
726 .enableFolding(false)
727 .setRegionSimplificationLevel(
728 GreedySimplifyRegionLevel::Disabled)
729 .setListener(&listener))))
730 signalPassFailure();
731 }
732};
733
/// Pass that runs the integer range analysis on the current op and then
/// greedily applies the narrowing patterns for `bitwidthsSupported` (the
/// pass's option). A DataFlowListener keeps the solver's state in sync with
/// erasures during the rewrite.
734struct IntRangeNarrowingPass final
735 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
736 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
737
738 void runOnOperation() override {
739 Operation *op = getOperation();
740 MLIRContext *ctx = op->getContext();
741 DataFlowSolver solver;
742 loadBaselineAnalyses(solver);
743 solver.load<IntegerRangeAnalysis>();
744 if (failed(solver.initializeAndRun(op)))
745 return signalPassFailure();
746
747 DataFlowListener listener(solver);
748
749 RewritePatternSet patterns(ctx);
750 populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
// NOTE(review): source line 751, which starts the call completed on the next
// line (presumably populating further patterns), is missing from this listing.
752 bitwidthsSupported);
753
754 // We specifically need bottom-up traversal as cmpi pattern needs range
755 // data, attached to its original argument values.
// NOTE(review): source line 756, which opens this greedy-rewrite call, is
// missing from this listing. Unlike ArithIntRangeOpts, this config does not
// disable folding or region simplification. The stale-block-argument-state
// hazard described for that pass may apply here too -- confirm.
757 op, std::move(patterns),
758 GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
759 &listener))))
760 signalPassFailure();
761 }
762};
763} // namespace
764
766 RewritePatternSet &patterns, DataFlowSolver &solver) {
// Registers constant materialization plus trivial signed/unsigned remainder
// removal; all three patterns query `solver`. (The signature's first line,
// source line 765, is missing from this listing.)
767 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
768 DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
769}
770
772 RewritePatternSet &patterns, DataFlowSolver &solver,
773 ArrayRef<unsigned> bitwidthsSupported) {
// Registers the range-driven narrowing patterns (elementwise ops and cmpi),
// plus index-cast round-trip folding for both signed and unsigned casts.
// (The signature's first line, source line 771, is missing from this
// listing.)
774 patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
775 bitwidthsSupported);
// FoldIndexCastChain does not consult the solver; it relies on the supported
// bitwidths not being wider than `index`.
776 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
777 FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
778 bitwidthsSupported);
779}
780
782 RewritePatternSet &patterns, DataFlowSolver &solver,
783 ArrayRef<unsigned> bitwidthsSupported) {
// Registers only the loop-bounds narrowing pattern. (The signature's first
// line, source line 781, is missing from this listing, so this function's
// name is not visible here.)
784 patterns.add<NarrowLoopBounds>(patterns.getContext(), solver,
785 bitwidthsSupported);
786}
787
// Factory body: returns a fresh IntRangeOptimizationsPass. (The signature,
// source line 788, is missing from this listing.)
789 return std::make_unique<IntRangeOptimizationsPass>();
790}
return success()
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
lhs
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
MLIRContext * getContext() const
Definition Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
LogicalResult initializeAndRun(Operation *top, llvm::function_ref< bool(DataFlowAnalysis &)> analysisFilter=nullptr)
Initialize analyses starting from the provided top-level operation and run the analysis until fixpoin...
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This is a value defined by a result of an operation.
Definition Value.h:454
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
operand_type_range getOperandTypes()
Definition Operation.h:423
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in 'rhs' into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass that performs optimizations based on integer range analysis.
void populateControlFlowValuesNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for narrowing control flow values (loop bounds, steps, etc.) based on int range analysis...
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...