MLIR 23.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "llvm/ADT/TypeSwitch.h"
12
17
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/Matchers.h"
30
31namespace mlir::arith {
32#define GEN_PASS_DEF_ARITHINTRANGEOPTS
33#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
34
35#define GEN_PASS_DEF_ARITHINTRANGENARROWING
36#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
37} // namespace mlir::arith
38
39using namespace mlir;
40using namespace mlir::arith;
41using namespace mlir::dataflow;
42
/// Return the single integer value that the integer range analysis inferred
/// for `value`, if the inferred range pins it to one value (as reported by
/// ConstantIntRanges::getConstantValue). Returns std::nullopt when no range
/// state exists for `value` or the state is still uninitialized.
/// NOTE(review): the initializer of `maybeInferredRange` (listing line 46) is
/// missing from this extract; presumably it is a
/// `solver.lookupState<IntegerValueRangeLattice>(value)` lookup — confirm.
43static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
44 Value value) {
45 auto *maybeInferredRange =
47 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
48 return std::nullopt;
49 const ConstantIntRanges &inferredRange =
50 maybeInferredRange->getValue().getValue();
51 return inferredRange.getConstantValue();
52}
53
/// Transfer the integer-range state computed for `oldVal` onto `newVal`, so
/// that values created by a rewrite keep the range information of the value
/// they replace. Does nothing when `oldVal` has no range state.
/// NOTE(review): listing line 59 is missing from this extract; judging by the
/// remaining `*oldState);` it passes `oldState` to an update of `newVal`'s
/// state (likely a join via getOrCreateState) — confirm against the source.
54static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
55 Value newVal) {
56 auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
57 if (!oldState)
58 return;
60 *oldState);
61}
62
63namespace mlir::dataflow {
64/// Patterned after SCCP
// Replace all uses of `value` with a constant when the integer range analysis
// proved it holds a single value. The constant is materialized through the
// dialect of the defining op (or of the parent op for block arguments, judging
// by the visible `value.getParentBlock()` fallback — its line is missing here),
// falling back to the arith dialect. Fails without modifying the IR if `value`
// is unused, is not known to be constant, or no constant can be materialized.
// NOTE(review): the function's name line (listing line 65) and lines 78/98 are
// missing from this extract.
66 RewriterBase &rewriter, Value value) {
67 if (value.use_empty())
68 return failure();
69 std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
70 if (!maybeConstValue.has_value())
71 return failure();
72
73 Type type = value.getType();
74 Location loc = value.getLoc();
75 Operation *maybeDefiningOp = value.getDefiningOp();
// NOTE(review): Operation::getDialect() may return nullptr when the op's
// dialect is not loaded; `valueDialect` is dereferenced below without a check.
76 Dialect *valueDialect =
77 maybeDefiningOp ? maybeDefiningOp->getDialect()
79
// Shaped values (e.g. vectors) get a splat attribute; scalars an IntegerAttr.
80 Attribute constAttr;
81 if (auto shaped = dyn_cast<ShapedType>(type)) {
82 constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
83 } else {
84 constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
85 }
86 Operation *constOp =
87 valueDialect->materializeConstant(rewriter, constAttr, type, loc);
88 // Fall back to arith.constant if the dialect materializer doesn't know what
89 // to do with an integer constant.
90 if (!constOp)
91 constOp = rewriter.getContext()
92 ->getLoadedDialect<ArithDialect>()
93 ->materializeConstant(rewriter, constAttr, type, loc);
94 if (!constOp)
95 return failure();
96
97 OpResult res = constOp->getResult(0);
// Drop any state already keyed on the freshly created result (presumably left
// behind by an erased value that occupied the same storage) before copying
// the range of `value` onto it.
99 solver.eraseState(res);
100 copyIntegerRange(solver, value, res);
101 rewriter.replaceAllUsesWith(value, res);
102 return success();
103}
104} // namespace mlir::dataflow
105
106namespace {
107class DataFlowListener : public RewriterBase::Listener {
108public:
109 DataFlowListener(DataFlowSolver &s) : s(s) {}
110
111protected:
112 void notifyOperationErased(Operation *op) override {
113 s.eraseState(s.getProgramPointAfter(op));
114 for (Value res : op->getResults())
115 s.eraseState(res);
116 }
117
118 DataFlowSolver &s;
119};
120
121/// Rewrite any results of `op` that were inferred to be constant integers to
122/// and replace their uses with that constant. Return success() if all results
123/// where thus replaced and the operation is erased. Also replace any block
124/// arguments with their constant values.
125struct MaterializeKnownConstantValues : public RewritePattern {
126 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
127 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
128 /*benefit=*/1, context),
129 solver(s) {}
130
131 LogicalResult matchAndRewrite(Operation *op,
132 PatternRewriter &rewriter) const override {
133 if (matchPattern(op, m_Constant()))
134 return failure();
135
136 auto needsReplacing = [&](Value v) {
137 return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
138 };
139 bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
140 if (op->getNumRegions() == 0)
141 if (!hasConstantResults)
142 return failure();
143 bool hasConstantRegionArgs = false;
144 for (Region &region : op->getRegions()) {
145 for (Block &block : region.getBlocks()) {
146 hasConstantRegionArgs |=
147 llvm::any_of(block.getArguments(), needsReplacing);
148 }
149 }
150 if (!hasConstantResults && !hasConstantRegionArgs)
151 return failure();
152
153 bool replacedAll = (op->getNumResults() != 0);
154 for (Value v : op->getResults())
155 replacedAll &=
156 (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
157 v.use_empty());
158 if (replacedAll && isOpTriviallyDead(op)) {
159 rewriter.eraseOp(op);
160 return success();
161 }
162
163 PatternRewriter::InsertionGuard guard(rewriter);
164 for (Region &region : op->getRegions()) {
165 for (Block &block : region.getBlocks()) {
166 rewriter.setInsertionPointToStart(&block);
167 for (BlockArgument &arg : block.getArguments()) {
168 (void)maybeReplaceWithConstant(solver, rewriter, arg);
169 }
170 }
171 }
172
173 return success();
174 }
175
176private:
177 DataFlowSolver &solver;
178};
179
180template <typename RemOp>
181struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
182 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
183 : OpRewritePattern<RemOp>(context), solver(s) {}
184
185 LogicalResult matchAndRewrite(RemOp op,
186 PatternRewriter &rewriter) const override {
187 Value lhs = op.getOperand(0);
188 Value rhs = op.getOperand(1);
189 auto maybeModulus = getConstantIntValue(rhs);
190 if (!maybeModulus.has_value())
191 return failure();
192 int64_t modulus = *maybeModulus;
193 if (modulus <= 0)
194 return failure();
195 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
196 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
197 return failure();
198 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
199 const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
200 const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
201 // The minima and maxima here are given as closed ranges, we must be
202 // strictly less than the modulus.
203 if (min.isNegative() || min.uge(modulus))
204 return failure();
205 if (max.isNegative() || max.uge(modulus))
206 return failure();
207 if (!min.ule(max))
208 return failure();
209
210 // With all those conditions out of the way, we know thas this invocation of
211 // a remainder is a noop because the input is strictly within the range
212 // [0, modulus), so get rid of it.
213 rewriter.replaceOp(op, ValueRange{lhs});
214 return success();
215 }
216
217private:
218 DataFlowSolver &solver;
219};
220
221/// Gather ranges for all the values in `values`. Appends to the existing
222/// vector.
/// Fails as soon as one value has no range state or an uninitialized one.
/// Note that on failure `ranges` may already hold the ranges of the values
/// that preceded the failing one; callers here discard the vector then.
/// NOTE(review): the parameter line declaring `ranges` (listing line 224) and
/// the lookup initializer (line 227) are missing from this extract.
223static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
225 for (Value val : values) {
226 auto *maybeInferredRange =
228 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
229 return failure();
230
231 const ConstantIntRanges &inferredRange =
232 maybeInferredRange->getValue().getValue();
233 ranges.push_back(inferredRange);
234 }
235 return success();
236}
237
238/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
239/// return shaped type as well.
240static Type getTargetType(Type srcType, unsigned targetBitwidth) {
241 auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
242 if (auto shaped = dyn_cast<ShapedType>(srcType))
243 return shaped.clone(dstType);
244
245 assert(srcType.isIntOrIndex() && "Invalid src type");
246 return dstType;
247}
248
namespace {
/// Which casts may be used to narrow a value (or an operation) without
/// changing its meaning.
enum class CastKind : uint8_t {
  /// Narrowing is not allowed.
  None,
  /// Only sign-preserving casts (extsi / index_cast) are valid.
  Signed,
  /// Only zero-extending casts (extui / index_castui) are valid.
  Unsigned,
  /// Either kind of cast is valid.
  Both
};
} // namespace
254
255/// If the values within `range` can be represented using only `width` bits,
256/// return the kind of truncation needed to preserve that property.
257///
258/// This check relies on the fact that the signed and unsigned ranges are both
259/// always correct, but that one might be an approximation of the other,
260/// so we want to use the correct truncation operation.
261static CastKind checkTruncatability(const ConstantIntRanges &range,
262 unsigned targetWidth) {
263 unsigned srcWidth = range.smin().getBitWidth();
264 if (srcWidth <= targetWidth)
265 return CastKind::None;
266 unsigned removedWidth = srcWidth - targetWidth;
267 // The sign bits need to extend into the sign bit of the target width. For
268 // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
269 // bits.
270 bool canTruncateSigned =
271 range.smin().getNumSignBits() >= (removedWidth + 1) &&
272 range.smax().getNumSignBits() >= (removedWidth + 1);
273 bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
274 range.umax().countLeadingZeros() >= removedWidth;
275 if (canTruncateSigned && canTruncateUnsigned)
276 return CastKind::Both;
277 if (canTruncateSigned)
278 return CastKind::Signed;
279 if (canTruncateUnsigned)
280 return CastKind::Unsigned;
281 return CastKind::None;
282}
283
284static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
285 if (lhs == CastKind::None || rhs == CastKind::None)
286 return CastKind::None;
287 if (lhs == CastKind::Both)
288 return rhs;
289 if (rhs == CastKind::Both)
290 return lhs;
291 if (lhs == rhs)
292 return lhs;
293 return CastKind::None;
294}
295
296static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
297 CastKind castKind) {
298 Type srcType = src.getType();
299 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
300 "Mixing vector and non-vector types");
301 assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
302 Type srcElemType = getElementTypeOrSelf(srcType);
303 Type dstElemType = getElementTypeOrSelf(dstType);
304 assert(srcElemType.isIntOrIndex() && "Invalid src type");
305 assert(dstElemType.isIntOrIndex() && "Invalid dst type");
306 if (srcType == dstType)
307 return src;
308
309 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
310 if (castKind == CastKind::Signed)
311 return arith::IndexCastOp::create(builder, loc, dstType, src);
312 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
313 }
314
315 auto srcInt = cast<IntegerType>(srcElemType);
316 auto dstInt = cast<IntegerType>(dstElemType);
317 if (dstInt.getWidth() < srcInt.getWidth())
318 return arith::TruncIOp::create(builder, loc, dstType, src);
319
320 if (castKind == CastKind::Signed)
321 return arith::ExtSIOp::create(builder, loc, dstType, src);
322 return arith::ExtUIOp::create(builder, loc, dstType, src);
323}
324
325struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
326 NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
327 ArrayRef<unsigned> target)
328 : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
329
331 LogicalResult matchAndRewrite(Operation *op,
332 PatternRewriter &rewriter) const override {
333 if (op->getNumResults() == 0)
334 return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
335
336 // Inline size chosen empirically based on compilation profiling.
337 // Profiled: 2.6M calls, avg=1.7+-1.3. N=4 covers >95% of cases inline.
338 SmallVector<ConstantIntRanges, 4> ranges;
339 if (failed(collectRanges(solver, op->getOperands(), ranges)))
340 return rewriter.notifyMatchFailure(op, "input without specified range");
341 if (failed(collectRanges(solver, op->getResults(), ranges)))
342 return rewriter.notifyMatchFailure(op, "output without specified range");
343
344 Type srcType = op->getResult(0).getType();
345 if (!llvm::all_equal(op->getResultTypes()))
346 return rewriter.notifyMatchFailure(op, "mismatched result types");
347 if (op->getNumOperands() == 0 ||
348 !llvm::all_of(op->getOperandTypes(),
349 [=](Type t) { return t == srcType; }))
350 return rewriter.notifyMatchFailure(
351 op, "no operands or operand types don't match result type");
352
353 for (unsigned targetBitwidth : targetBitwidths) {
354 CastKind castKind = CastKind::Both;
355 for (const ConstantIntRanges &range : ranges) {
356 castKind = mergeCastKinds(castKind,
357 checkTruncatability(range, targetBitwidth));
358 if (castKind == CastKind::None)
359 break;
360 }
361 // For operations that explicitly treat the values as signed, we should
362 // only do signed casts, if those are deemed possible as such based on the
363 // value range.
364 auto castKindForOp =
365 llvm::TypeSwitch<Operation *, CastKind>(op)
366 .Case<arith::DivSIOp, arith::CeilDivSIOp, arith::FloorDivSIOp,
367 arith::RemSIOp, arith::MaxSIOp, arith::MinSIOp,
368 arith::ShRSIOp>([](auto) { return CastKind::Signed; })
369 .Default(CastKind::Both);
370 castKind = mergeCastKinds(castKind, castKindForOp);
371 if (castKind == CastKind::None)
372 continue;
373 Type targetType = getTargetType(srcType, targetBitwidth);
374 if (targetType == srcType)
375 continue;
376
377 Location loc = op->getLoc();
378 IRMapping mapping;
379 for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
380 CastKind argCastKind = castKind;
381 // When dealing with `index` values, preserve non-negativity in the
382 // index_casts since we can't recover this in unsigned when equivalent.
383 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
384 argCastKind = CastKind::Both;
385 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
386 mapping.map(arg, newArg);
387 }
388
389 Operation *newOp = rewriter.clone(*op, mapping);
390 rewriter.modifyOpInPlace(newOp, [&]() {
391 for (OpResult res : newOp->getResults()) {
392 res.setType(targetType);
393 }
394 });
395 SmallVector<Value> newResults;
396 for (auto [newRes, oldRes] :
397 llvm::zip_equal(newOp->getResults(), op->getResults())) {
398 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
399 copyIntegerRange(solver, oldRes, castBack);
400 newResults.push_back(castBack);
401 }
402
403 rewriter.replaceOp(op, newResults);
404 return success();
405 }
406 return failure();
407 }
408
409private:
410 DataFlowSolver &solver;
411 SmallVector<unsigned, 4> targetBitwidths;
412};
413
414struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
415 NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
416 : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
417
418 LogicalResult matchAndRewrite(arith::CmpIOp op,
419 PatternRewriter &rewriter) const override {
420 Value lhs = op.getLhs();
421 Value rhs = op.getRhs();
422
423 SmallVector<ConstantIntRanges> ranges;
424 if (failed(collectRanges(solver, op.getOperands(), ranges)))
425 return failure();
426 const ConstantIntRanges &lhsRange = ranges[0];
427 const ConstantIntRanges &rhsRange = ranges[1];
428
429 auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) -> bool {
430 return pred == arith::CmpIPredicate::sge ||
431 pred == arith::CmpIPredicate::sgt ||
432 pred == arith::CmpIPredicate::sle ||
433 pred == arith::CmpIPredicate::slt;
434 };
435 // If we're to narrow the input values via a cast, we should preserve the
436 // sign.
437 CastKind predicateBasedCastRestriction =
438 isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed
439 : CastKind::Both;
440
441 Type srcType = lhs.getType();
442 for (unsigned targetBitwidth : targetBitwidths) {
443 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
444 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
445 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
446 castKind = mergeCastKinds(castKind, predicateBasedCastRestriction);
447 // Note: this includes target width > src width, as well as the unsigned
448 // truncatability & signed predicate scenario.
449 if (castKind == CastKind::None)
450 continue;
451
452 Type targetType = getTargetType(srcType, targetBitwidth);
453 if (targetType == srcType)
454 continue;
455
456 Location loc = op->getLoc();
457 IRMapping mapping;
458 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
459 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
460 mapping.map(lhs, lhsCast);
461 mapping.map(rhs, rhsCast);
462
463 Operation *newOp = rewriter.clone(*op, mapping);
464 copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
465 rewriter.replaceOp(op, newOp->getResults());
466 return success();
467 }
468 return failure();
469 }
470
471private:
472 DataFlowSolver &solver;
473 SmallVector<unsigned, 4> targetBitwidths;
474};
475
476/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
477/// This pattern assumes all passed `targetBitwidths` are not wider than index
478/// type.
479template <typename CastOp>
480struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
481 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
482 : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
483
484 LogicalResult matchAndRewrite(CastOp op,
485 PatternRewriter &rewriter) const override {
486 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
487 if (!srcOp)
488 return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
489
490 Value src = srcOp.getIn();
491 if (src.getType() != op.getType())
492 return rewriter.notifyMatchFailure(op, "outer types don't match");
493
494 if (!srcOp.getType().isIndex())
495 return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
496
497 auto intType = dyn_cast<IntegerType>(op.getType());
498 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
499 return failure();
500
501 rewriter.replaceOp(op, src);
502 return success();
503 }
504
505private:
506 SmallVector<unsigned, 4> targetBitwidths;
507};
508
/// Narrow the bounds, steps and induction variables of a LoopLikeOpInterface
/// op to one of `targetBitwidths` when the integer range analysis shows that
/// the lower bound, upper bound, step, induction variable and `iv + step` all
/// fit in the narrow type.
///
/// Each (iv, lb, ub, step) tuple is considered on its own, and only when the
/// lb, ub and step are all SSA values (attribute bounds are skipped). Casts for
/// the narrowed bounds are created first. The loop is then updated in place
/// through the interface setters. Finally, uses of each narrowed induction
/// variable are redirected to a cast back to its original type, inserted at
/// the start of the block owning the induction variable. If the interface
/// refuses the new bounds, the op is tagged with
/// `arith.bounds_narrowing_failed` so that the pattern does not retry it.
509struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
510 NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
511 ArrayRef<unsigned> target)
512 : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
513 targetBitwidths(target),
514 boundsNarrowingFailedAttr(
515 StringAttr::get(context, "arith.bounds_narrowing_failed")) {}
516
517 LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
518 PatternRewriter &rewriter) const override {
519 // Skip ops where bounds narrowing previously failed.
520 if (loopLike->hasAttr(boundsNarrowingFailedAttr))
521 return rewriter.notifyMatchFailure(loopLike,
522 "bounds narrowing previously failed");
523
524 std::optional<SmallVector<Value>> inductionVars =
525 loopLike.getLoopInductionVars();
526 if (!inductionVars.has_value() || inductionVars->empty())
527 return rewriter.notifyMatchFailure(loopLike, "no induction variables");
528
529 std::optional<SmallVector<OpFoldResult>> lowerBounds =
530 loopLike.getLoopLowerBounds();
531 std::optional<SmallVector<OpFoldResult>> upperBounds =
532 loopLike.getLoopUpperBounds();
533 std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps();
534
535 if (!lowerBounds.has_value() || !upperBounds.has_value() ||
536 !steps.has_value())
537 return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps");
538
539 if (lowerBounds->size() != inductionVars->size() ||
540 upperBounds->size() != inductionVars->size() ||
541 steps->size() != inductionVars->size())
542 return rewriter.notifyMatchFailure(loopLike,
543 "mismatched bounds/steps count");
544
545 Location loc = loopLike->getLoc();
// Start from the current bounds; only narrowed dimensions are overwritten.
546 SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
547 SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
548 SmallVector<OpFoldResult> newSteps(*steps);
// (dimension index, narrow type, cast kind) for every narrowed dimension.
549 SmallVector<std::tuple<size_t, Type, CastKind>> narrowings;
550
551 // Check each (indVar, lb, ub, step) tuple.
552 for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
553 llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
554
555 // Only process value operands, skip attributes.
556 auto maybeLb = dyn_cast<Value>(lbOFR);
557 auto maybeUb = dyn_cast<Value>(ubOFR);
558 auto maybeStep = dyn_cast<Value>(stepOFR);
559
560 if (!maybeLb || !maybeUb || !maybeStep)
561 continue;
562
563 // Collect ranges for (lb, ub, step, indVar).
// The bound values created by a previous application of this pattern get no
// range copied onto them in this block, so presumably this lookup fails for
// an already-narrowed loop, which keeps the pattern from re-firing.
564 SmallVector<ConstantIntRanges> ranges;
565 if (failed(collectRanges(
566 solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
567 continue;
568
569 const ConstantIntRanges &stepRange = ranges[2];
570 const ConstantIntRanges &indVarRange = ranges[3];
571
// The source type is taken from the lower bound; the induction variable's
// own type is recorded separately (origTypes) before it is changed.
572 Type srcType = maybeLb.getType();
573
574 // Try each target bitwidth.
575 for (unsigned targetBitwidth : targetBitwidths) {
576 Type targetType = getTargetType(srcType, targetBitwidth);
577 if (targetType == srcType)
578 continue;
579
580 // Check if the target type is valid for this loop's induction
581 // variables.
582 if (!loopLike.isValidInductionVarType(targetType))
583 continue;
584
585 // Check if all values in this tuple can be truncated.
586 CastKind castKind = CastKind::Both;
587 for (const ConstantIntRanges &range : ranges) {
588 castKind = mergeCastKinds(castKind,
589 checkTruncatability(range, targetBitwidth));
590 if (castKind == CastKind::None)
591 break;
592 }
593
594 if (castKind == CastKind::None)
595 continue;
596
597 // Check if indVar + step fits in the narrowed type.
598 // This is critical for loop correctness: the loop computes
599 // iv_next = iv_current + step in the narrowed type, then compares
600 // iv_next < ub. If iv_current + step overflows, the comparison may
601 // produce incorrect results and break loop termination.
602 // Both signed and unsigned interpretations must fit because loop
603 // semantics are unknown (integer types are signless).
// Saturating adds clamp an overflowing sum to the source type's limits,
// which can never pass the truncatability check below.
604 ConstantIntRanges indVarPlusStepRange(
605 indVarRange.smin().sadd_sat(stepRange.smin()),
606 indVarRange.smax().sadd_sat(stepRange.smax()),
607 indVarRange.umin().uadd_sat(stepRange.umin()),
608 indVarRange.umax().uadd_sat(stepRange.umax()));
609
610 if (checkTruncatability(indVarPlusStepRange, targetBitwidth) !=
611 CastKind::Both)
612 continue;
613
614 // Narrow the bounds and step values.
615 Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
616 Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
617 Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
618
619 newLowerBounds[idx] = newLb;
620 newUpperBounds[idx] = newUb;
621 newSteps[idx] = newStep;
622 narrowings.push_back({idx, targetType, castKind});
// Use the first (in targetBitwidths order) width that works.
623 break;
624 }
625 }
626
627 if (narrowings.empty())
628 return rewriter.notifyMatchFailure(loopLike, "no narrowings found");
629
630 // Save original types before modifying.
631 SmallVector<Type> origTypes;
632 for (auto [idx, targetType, castKind] : narrowings) {
633 Value indVar = (*inductionVars)[idx];
634 origTypes.push_back(indVar.getType());
635 }
636
637 // Attempt to update bounds and induction variable types.
638 // If this fails, mark the op so we don't try again.
// NOTE(review): if setLoopLowerBounds succeeds but a later setter fails, the
// loop keeps the partially narrowed bounds while its induction variables keep
// their original type; whether an interface implementation can fail this way
// is not visible here — confirm. The marker attribute is also never removed
// in this file, so it remains in the output IR.
639 bool updateFailed = false;
640 rewriter.modifyOpInPlace(loopLike, [&]() {
641 // Update the loop bounds and steps.
642 if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
643 failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
644 failed(loopLike.setLoopSteps(newSteps))) {
645 // Mark op to prevent future attempts. IR was modified (attribute
646 // added), so we must return success() from the pattern.
647 loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr());
648 updateFailed = true;
649 return;
650 }
651
652 // Update induction variable types.
653 for (auto [idx, targetType, castKind] : narrowings) {
654 Value indVar = (*inductionVars)[idx];
655 auto blockArg = cast<BlockArgument>(indVar);
656
657 // Change the block argument type.
658 blockArg.setType(targetType);
659 }
660 });
661
662 if (updateFailed)
663 return success();
664
665 // Insert casts back to original type for uses.
666 for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
667 auto [idx, targetType, castKind] = narrowingInfo;
668 Value indVar = (*inductionVars)[idx];
669 auto blockArg = cast<BlockArgument>(indVar);
670 Type origType = origTypes[narrowingIdx];
671
672 OpBuilder::InsertionGuard guard(rewriter);
673 rewriter.setInsertionPointToStart(blockArg.getOwner());
674 Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
// The block argument's stored state was computed for the original (wide)
// type, which matches the type of `casted`.
675 copyIntegerRange(solver, blockArg, casted);
676
677 // Replace all uses of the narrowed indVar with the casted value.
678 rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp());
679 }
680
681 return success();
682 }
683
684private:
685 DataFlowSolver &solver;
686 SmallVector<unsigned, 4> targetBitwidths;
687 StringAttr boundsNarrowingFailedAttr;
688};
689
/// Pass driver for ArithIntRangeOpts: runs the integer range analysis once,
/// then applies the range-based rewrites greedily while a DataFlowListener
/// keeps the solver state in sync with erased IR.
690struct IntRangeOptimizationsPass final
691 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
692
693 void runOnOperation() override {
694 Operation *op = getOperation();
695 MLIRContext *ctx = op->getContext();
696 DataFlowSolver solver;
697 loadBaselineAnalyses(solver);
698 solver.load<IntegerRangeAnalysis>();
699 if (failed(solver.initializeAndRun(op)))
700 return signalPassFailure();
701
702 DataFlowListener listener(solver);
703
704 RewritePatternSet patterns(ctx);
// NOTE(review): listing line 705 is missing from this extract; presumably it
// populates `patterns` (e.g. via populateIntRangeOptimizationsPatterns) —
// confirm.
706
707 // Disable folding and region simplification to avoid breaking the solver
708 // state. Both can remove block arguments (folding via control-flow
709 // simplification, region simplification via dead-arg elimination), which
710 // frees their underlying storage. A subsequent allocation may reuse the
711 // same address for a different block argument, causing stale solver state
712 // to be associated with the new argument and producing incorrect constants.
713 if (failed(
714 applyPatternsGreedily(op, std::move(patterns),
715 GreedyRewriteConfig()
716 .enableFolding(false)
717 .setRegionSimplificationLevel(
718 GreedySimplifyRegionLevel::Disabled)
719 .setListener(&listener))))
720 signalPassFailure();
721 }
722};
723
/// Pass driver for ArithIntRangeNarrowing: runs the integer range analysis
/// once, then applies the narrowing patterns for `bitwidthsSupported` with a
/// bottom-up greedy traversal, keeping the solver in sync via a
/// DataFlowListener.
724struct IntRangeNarrowingPass final
725 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
726 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
727
728 void runOnOperation() override {
729 Operation *op = getOperation();
730 MLIRContext *ctx = op->getContext();
731 DataFlowSolver solver;
732 loadBaselineAnalyses(solver);
733 solver.load<IntegerRangeAnalysis>();
734 if (failed(solver.initializeAndRun(op)))
735 return signalPassFailure();
736
737 DataFlowListener listener(solver);
738
739 RewritePatternSet patterns(ctx);
// NOTE(review): listing line 741 is missing from this extract; the dangling
// `bitwidthsSupported);` below suggests a second populate call (presumably
// populateControlFlowValuesNarrowingPatterns) — confirm.
740 populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
742 bitwidthsSupported);
743
// NOTE(review): unlike ArithIntRangeOpts, this config leaves folding and
// region simplification enabled; the stale block-argument state hazard
// described for that pass may apply here as well — confirm. Listing line 746
// (presumably the start of the applyPatternsGreedily call) is missing.
744 // We specifically need bottom-up traversal as cmpi pattern needs range
745 // data, attached to its original argument values.
747 op, std::move(patterns),
748 GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
749 &listener))))
750 signalPassFailure();
751 }
752};
753} // namespace
754
756 RewritePatternSet &patterns, DataFlowSolver &solver) {
// Registers the range-based simplifications that share `solver`: constant
// materialization and removal of trivial signed/unsigned remainders.
// NOTE(review): the line naming this function (listing line 755) is missing
// from this extract.
757 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
758 DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
759}
760
762 RewritePatternSet &patterns, DataFlowSolver &solver,
763 ArrayRef<unsigned> bitwidthsSupported) {
// Registers the narrowing rewrites for elementwise ops and cmpi (which consult
// `solver`) plus the index_cast/index_castui round-trip folds, all limited to
// `bitwidthsSupported`.
// NOTE(review): the line naming this function (listing line 761) is missing
// from this extract.
764 patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
765 bitwidthsSupported);
766 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
767 FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
768 bitwidthsSupported);
769}
770
772 RewritePatternSet &patterns, DataFlowSolver &solver,
773 ArrayRef<unsigned> bitwidthsSupported) {
// Registers the loop bound/induction-variable narrowing pattern for
// `bitwidthsSupported`.
// NOTE(review): the line naming this function (listing line 771) is missing
// from this extract.
774 patterns.add<NarrowLoopBounds>(patterns.getContext(), solver,
775 bitwidthsSupported);
776}
777
// Factory returning a fresh IntRangeOptimizationsPass instance.
// NOTE(review): the function signature (listing line 778) is missing from
// this extract.
779 return std::make_unique<IntRangeOptimizationsPass>();
780}
return success()
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
lhs
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
MLIRContext * getContext() const
Definition Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This is a value defined by a result of an operation.
Definition Value.h:454
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
operand_type_range getOperandTypes()
Definition Operation.h:423
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which do optimizations based on integer range analysis.
void populateControlFlowValuesNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for narrowing control flow values (loop bounds, steps, etc.) based on int range analysis...
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...