MLIR 23.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
15
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
28
29namespace mlir::arith {
30#define GEN_PASS_DEF_ARITHINTRANGEOPTS
31#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
32
33#define GEN_PASS_DEF_ARITHINTRANGENARROWING
34#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
35} // namespace mlir::arith
36
37using namespace mlir;
38using namespace mlir::arith;
39using namespace mlir::dataflow;
40
41static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
42 Value value) {
43 auto *maybeInferredRange =
45 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
46 return std::nullopt;
47 const ConstantIntRanges &inferredRange =
48 maybeInferredRange->getValue().getValue();
49 return inferredRange.getConstantValue();
50}
51
52static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
53 Value newVal) {
54 auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
55 if (!oldState)
56 return;
58 *oldState);
59}
60
61namespace mlir::dataflow {
62/// Patterned after SCCP
64 RewriterBase &rewriter, Value value) {
65 if (value.use_empty())
66 return failure();
67 std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
68 if (!maybeConstValue.has_value())
69 return failure();
70
71 Type type = value.getType();
72 Location loc = value.getLoc();
73 Operation *maybeDefiningOp = value.getDefiningOp();
74 Dialect *valueDialect =
75 maybeDefiningOp ? maybeDefiningOp->getDialect()
77
78 Attribute constAttr;
79 if (auto shaped = dyn_cast<ShapedType>(type)) {
80 constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
81 } else {
82 constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
83 }
84 Operation *constOp =
85 valueDialect->materializeConstant(rewriter, constAttr, type, loc);
86 // Fall back to arith.constant if the dialect materializer doesn't know what
87 // to do with an integer constant.
88 if (!constOp)
89 constOp = rewriter.getContext()
90 ->getLoadedDialect<ArithDialect>()
91 ->materializeConstant(rewriter, constAttr, type, loc);
92 if (!constOp)
93 return failure();
94
95 OpResult res = constOp->getResult(0);
97 solver.eraseState(res);
98 copyIntegerRange(solver, value, res);
99 rewriter.replaceAllUsesWith(value, res);
100 return success();
101}
102} // namespace mlir::dataflow
103
104namespace {
105class DataFlowListener : public RewriterBase::Listener {
106public:
107 DataFlowListener(DataFlowSolver &s) : s(s) {}
108
109protected:
110 void notifyOperationErased(Operation *op) override {
111 s.eraseState(s.getProgramPointAfter(op));
112 for (Value res : op->getResults())
113 s.eraseState(res);
114 }
115
116 DataFlowSolver &s;
117};
118
119/// Rewrite any results of `op` that were inferred to be constant integers to
120/// and replace their uses with that constant. Return success() if all results
121/// where thus replaced and the operation is erased. Also replace any block
122/// arguments with their constant values.
123struct MaterializeKnownConstantValues : public RewritePattern {
124 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
125 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
126 /*benefit=*/1, context),
127 solver(s) {}
128
129 LogicalResult matchAndRewrite(Operation *op,
130 PatternRewriter &rewriter) const override {
131 if (matchPattern(op, m_Constant()))
132 return failure();
133
134 auto needsReplacing = [&](Value v) {
135 return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
136 };
137 bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
138 if (op->getNumRegions() == 0)
139 if (!hasConstantResults)
140 return failure();
141 bool hasConstantRegionArgs = false;
142 for (Region &region : op->getRegions()) {
143 for (Block &block : region.getBlocks()) {
144 hasConstantRegionArgs |=
145 llvm::any_of(block.getArguments(), needsReplacing);
146 }
147 }
148 if (!hasConstantResults && !hasConstantRegionArgs)
149 return failure();
150
151 bool replacedAll = (op->getNumResults() != 0);
152 for (Value v : op->getResults())
153 replacedAll &=
154 (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
155 v.use_empty());
156 if (replacedAll && isOpTriviallyDead(op)) {
157 rewriter.eraseOp(op);
158 return success();
159 }
160
161 PatternRewriter::InsertionGuard guard(rewriter);
162 for (Region &region : op->getRegions()) {
163 for (Block &block : region.getBlocks()) {
164 rewriter.setInsertionPointToStart(&block);
165 for (BlockArgument &arg : block.getArguments()) {
166 (void)maybeReplaceWithConstant(solver, rewriter, arg);
167 }
168 }
169 }
170
171 return success();
172 }
173
174private:
175 DataFlowSolver &solver;
176};
177
178template <typename RemOp>
179struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
180 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
181 : OpRewritePattern<RemOp>(context), solver(s) {}
182
183 LogicalResult matchAndRewrite(RemOp op,
184 PatternRewriter &rewriter) const override {
185 Value lhs = op.getOperand(0);
186 Value rhs = op.getOperand(1);
187 auto maybeModulus = getConstantIntValue(rhs);
188 if (!maybeModulus.has_value())
189 return failure();
190 int64_t modulus = *maybeModulus;
191 if (modulus <= 0)
192 return failure();
193 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
194 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
195 return failure();
196 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
197 const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
198 const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
199 // The minima and maxima here are given as closed ranges, we must be
200 // strictly less than the modulus.
201 if (min.isNegative() || min.uge(modulus))
202 return failure();
203 if (max.isNegative() || max.uge(modulus))
204 return failure();
205 if (!min.ule(max))
206 return failure();
207
208 // With all those conditions out of the way, we know thas this invocation of
209 // a remainder is a noop because the input is strictly within the range
210 // [0, modulus), so get rid of it.
211 rewriter.replaceOp(op, ValueRange{lhs});
212 return success();
213 }
214
215private:
216 DataFlowSolver &solver;
217};
218
219/// Gather ranges for all the values in `values`. Appends to the existing
220/// vector.
221static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
223 for (Value val : values) {
224 auto *maybeInferredRange =
226 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
227 return failure();
228
229 const ConstantIntRanges &inferredRange =
230 maybeInferredRange->getValue().getValue();
231 ranges.push_back(inferredRange);
232 }
233 return success();
234}
235
236/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
237/// return shaped type as well.
238static Type getTargetType(Type srcType, unsigned targetBitwidth) {
239 auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
240 if (auto shaped = dyn_cast<ShapedType>(srcType))
241 return shaped.clone(dstType);
242
243 assert(srcType.isIntOrIndex() && "Invalid src type");
244 return dstType;
245}
246
namespace {
/// Which flavour of truncation, if any, may be used to narrow an operation.
enum class CastKind : uint8_t {
  None,     // Narrowing is not possible.
  Signed,   // Only the signed interpretation fits the narrow type.
  Unsigned, // Only the unsigned interpretation fits the narrow type.
  Both      // Both interpretations fit the narrow type.
};
} // namespace
252
253/// If the values within `range` can be represented using only `width` bits,
254/// return the kind of truncation needed to preserve that property.
255///
256/// This check relies on the fact that the signed and unsigned ranges are both
257/// always correct, but that one might be an approximation of the other,
258/// so we want to use the correct truncation operation.
259static CastKind checkTruncatability(const ConstantIntRanges &range,
260 unsigned targetWidth) {
261 unsigned srcWidth = range.smin().getBitWidth();
262 if (srcWidth <= targetWidth)
263 return CastKind::None;
264 unsigned removedWidth = srcWidth - targetWidth;
265 // The sign bits need to extend into the sign bit of the target width. For
266 // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
267 // bits.
268 bool canTruncateSigned =
269 range.smin().getNumSignBits() >= (removedWidth + 1) &&
270 range.smax().getNumSignBits() >= (removedWidth + 1);
271 bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
272 range.umax().countLeadingZeros() >= removedWidth;
273 if (canTruncateSigned && canTruncateUnsigned)
274 return CastKind::Both;
275 if (canTruncateSigned)
276 return CastKind::Signed;
277 if (canTruncateUnsigned)
278 return CastKind::Unsigned;
279 return CastKind::None;
280}
281
282static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
283 if (lhs == CastKind::None || rhs == CastKind::None)
284 return CastKind::None;
285 if (lhs == CastKind::Both)
286 return rhs;
287 if (rhs == CastKind::Both)
288 return lhs;
289 if (lhs == rhs)
290 return lhs;
291 return CastKind::None;
292}
293
294static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
295 CastKind castKind) {
296 Type srcType = src.getType();
297 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
298 "Mixing vector and non-vector types");
299 assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
300 Type srcElemType = getElementTypeOrSelf(srcType);
301 Type dstElemType = getElementTypeOrSelf(dstType);
302 assert(srcElemType.isIntOrIndex() && "Invalid src type");
303 assert(dstElemType.isIntOrIndex() && "Invalid dst type");
304 if (srcType == dstType)
305 return src;
306
307 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
308 if (castKind == CastKind::Signed)
309 return arith::IndexCastOp::create(builder, loc, dstType, src);
310 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
311 }
312
313 auto srcInt = cast<IntegerType>(srcElemType);
314 auto dstInt = cast<IntegerType>(dstElemType);
315 if (dstInt.getWidth() < srcInt.getWidth())
316 return arith::TruncIOp::create(builder, loc, dstType, src);
317
318 if (castKind == CastKind::Signed)
319 return arith::ExtSIOp::create(builder, loc, dstType, src);
320 return arith::ExtUIOp::create(builder, loc, dstType, src);
321}
322
323struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
324 NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
325 ArrayRef<unsigned> target)
326 : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
327
329 LogicalResult matchAndRewrite(Operation *op,
330 PatternRewriter &rewriter) const override {
331 if (op->getNumResults() == 0)
332 return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
333
335 if (failed(collectRanges(solver, op->getOperands(), ranges)))
336 return rewriter.notifyMatchFailure(op, "input without specified range");
337 if (failed(collectRanges(solver, op->getResults(), ranges)))
338 return rewriter.notifyMatchFailure(op, "output without specified range");
340 Type srcType = op->getResult(0).getType();
341 if (!llvm::all_equal(op->getResultTypes()))
342 return rewriter.notifyMatchFailure(op, "mismatched result types");
343 if (op->getNumOperands() == 0 ||
344 !llvm::all_of(op->getOperandTypes(),
345 [=](Type t) { return t == srcType; }))
346 return rewriter.notifyMatchFailure(
347 op, "no operands or operand types don't match result type");
348
349 for (unsigned targetBitwidth : targetBitwidths) {
350 CastKind castKind = CastKind::Both;
351 for (const ConstantIntRanges &range : ranges) {
352 castKind = mergeCastKinds(castKind,
353 checkTruncatability(range, targetBitwidth));
354 if (castKind == CastKind::None)
355 break;
356 }
357 if (castKind == CastKind::None)
358 continue;
359 Type targetType = getTargetType(srcType, targetBitwidth);
360 if (targetType == srcType)
361 continue;
362
363 Location loc = op->getLoc();
364 IRMapping mapping;
365 for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
366 CastKind argCastKind = castKind;
367 // When dealing with `index` values, preserve non-negativity in the
368 // index_casts since we can't recover this in unsigned when equivalent.
369 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
370 argCastKind = CastKind::Both;
371 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
372 mapping.map(arg, newArg);
373 }
374
375 Operation *newOp = rewriter.clone(*op, mapping);
376 rewriter.modifyOpInPlace(newOp, [&]() {
377 for (OpResult res : newOp->getResults()) {
378 res.setType(targetType);
379 }
380 });
382 for (auto [newRes, oldRes] :
383 llvm::zip_equal(newOp->getResults(), op->getResults())) {
384 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
385 copyIntegerRange(solver, oldRes, castBack);
386 newResults.push_back(castBack);
387 }
389 rewriter.replaceOp(op, newResults);
390 return success();
391 }
392 return failure();
393 }
395private:
396 DataFlowSolver &solver;
397 SmallVector<unsigned, 4> targetBitwidths;
399
400struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
401 NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
402 : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
403
404 LogicalResult matchAndRewrite(arith::CmpIOp op,
405 PatternRewriter &rewriter) const override {
406 Value lhs = op.getLhs();
407 Value rhs = op.getRhs();
408
410 if (failed(collectRanges(solver, op.getOperands(), ranges)))
411 return failure();
412 const ConstantIntRanges &lhsRange = ranges[0];
413 const ConstantIntRanges &rhsRange = ranges[1];
414
415 Type srcType = lhs.getType();
416 for (unsigned targetBitwidth : targetBitwidths) {
417 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
418 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
419 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
420 // Note: this includes target width > src width.
421 if (castKind == CastKind::None)
422 continue;
424 Type targetType = getTargetType(srcType, targetBitwidth);
425 if (targetType == srcType)
426 continue;
428 Location loc = op->getLoc();
429 IRMapping mapping;
430 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
431 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
432 mapping.map(lhs, lhsCast);
433 mapping.map(rhs, rhsCast);
435 Operation *newOp = rewriter.clone(*op, mapping);
436 copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
437 rewriter.replaceOp(op, newOp->getResults());
438 return success();
440 return failure();
441 }
443private:
444 DataFlowSolver &solver;
446};
447
448/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
449/// This pattern assumes all passed `targetBitwidths` are not wider than index
450/// type.
451template <typename CastOp>
452struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
453 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
454 : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
456 LogicalResult matchAndRewrite(CastOp op,
457 PatternRewriter &rewriter) const override {
458 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
459 if (!srcOp)
460 return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
461
462 Value src = srcOp.getIn();
463 if (src.getType() != op.getType())
464 return rewriter.notifyMatchFailure(op, "outer types don't match");
465
466 if (!srcOp.getType().isIndex())
467 return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
468
469 auto intType = dyn_cast<IntegerType>(op.getType());
470 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
471 return failure();
472
473 rewriter.replaceOp(op, src);
474 return success();
475 }
476
477private:
478 SmallVector<unsigned, 4> targetBitwidths;
479};
480
481struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
482 NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
484 : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
485 targetBitwidths(target),
486 boundsNarrowingFailedAttr(
487 StringAttr::get(context, "arith.bounds_narrowing_failed")) {}
488
489 LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
490 PatternRewriter &rewriter) const override {
491 // Skip ops where bounds narrowing previously failed.
492 if (loopLike->hasAttr(boundsNarrowingFailedAttr))
493 return rewriter.notifyMatchFailure(loopLike,
494 "bounds narrowing previously failed");
495
496 std::optional<SmallVector<Value>> inductionVars =
497 loopLike.getLoopInductionVars();
498 if (!inductionVars.has_value() || inductionVars->empty())
499 return rewriter.notifyMatchFailure(loopLike, "no induction variables");
500
501 std::optional<SmallVector<OpFoldResult>> lowerBounds =
502 loopLike.getLoopLowerBounds();
503 std::optional<SmallVector<OpFoldResult>> upperBounds =
504 loopLike.getLoopUpperBounds();
505 std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps();
506
507 if (!lowerBounds.has_value() || !upperBounds.has_value() ||
508 !steps.has_value())
509 return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps");
510
511 if (lowerBounds->size() != inductionVars->size() ||
512 upperBounds->size() != inductionVars->size() ||
513 steps->size() != inductionVars->size())
514 return rewriter.notifyMatchFailure(loopLike,
515 "mismatched bounds/steps count");
516
517 Location loc = loopLike->getLoc();
518 SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
519 SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
520 SmallVector<OpFoldResult> newSteps(*steps);
521 SmallVector<std::tuple<size_t, Type, CastKind>> narrowings;
522
523 // Check each (indVar, lb, ub, step) tuple.
524 for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
525 llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
526
527 // Only process value operands, skip attributes.
528 auto maybeLb = dyn_cast<Value>(lbOFR);
529 auto maybeUb = dyn_cast<Value>(ubOFR);
530 auto maybeStep = dyn_cast<Value>(stepOFR);
531
532 if (!maybeLb || !maybeUb || !maybeStep)
533 continue;
534
535 // Collect ranges for (lb, ub, step, indVar).
536 SmallVector<ConstantIntRanges> ranges;
537 if (failed(collectRanges(
538 solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
539 continue;
540
541 const ConstantIntRanges &stepRange = ranges[2];
542 const ConstantIntRanges &indVarRange = ranges[3];
543
544 Type srcType = maybeLb.getType();
545
546 // Try each target bitwidth.
547 for (unsigned targetBitwidth : targetBitwidths) {
548 Type targetType = getTargetType(srcType, targetBitwidth);
549 if (targetType == srcType)
550 continue;
551
552 // Check if the target type is valid for this loop's induction
553 // variables.
554 if (!loopLike.isValidInductionVarType(targetType))
555 continue;
556
557 // Check if all values in this tuple can be truncated.
558 CastKind castKind = CastKind::Both;
559 for (const ConstantIntRanges &range : ranges) {
560 castKind = mergeCastKinds(castKind,
561 checkTruncatability(range, targetBitwidth));
562 if (castKind == CastKind::None)
563 break;
564 }
565
566 if (castKind == CastKind::None)
567 continue;
568
569 // Check if indVar + step fits in the narrowed type.
570 // This is critical for loop correctness: the loop computes
571 // iv_next = iv_current + step in the narrowed type, then compares
572 // iv_next < ub. If iv_current + step overflows, the comparison may
573 // produce incorrect results and break loop termination.
574 // Both signed and unsigned interpretations must fit because loop
575 // semantics are unknown (integer types are signless).
576 ConstantIntRanges indVarPlusStepRange(
577 indVarRange.smin().sadd_sat(stepRange.smin()),
578 indVarRange.smax().sadd_sat(stepRange.smax()),
579 indVarRange.umin().uadd_sat(stepRange.umin()),
580 indVarRange.umax().uadd_sat(stepRange.umax()));
581
582 if (checkTruncatability(indVarPlusStepRange, targetBitwidth) !=
583 CastKind::Both)
584 continue;
585
586 // Narrow the bounds and step values.
587 Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
588 Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
589 Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
590
591 newLowerBounds[idx] = newLb;
592 newUpperBounds[idx] = newUb;
593 newSteps[idx] = newStep;
594 narrowings.push_back({idx, targetType, castKind});
595 break;
596 }
597 }
598
599 if (narrowings.empty())
600 return rewriter.notifyMatchFailure(loopLike, "no narrowings found");
601
602 // Save original types before modifying.
603 SmallVector<Type> origTypes;
604 for (auto [idx, targetType, castKind] : narrowings) {
605 Value indVar = (*inductionVars)[idx];
606 origTypes.push_back(indVar.getType());
607 }
608
609 // Attempt to update bounds and induction variable types.
610 // If this fails, mark the op so we don't try again.
611 bool updateFailed = false;
612 rewriter.modifyOpInPlace(loopLike, [&]() {
613 // Update the loop bounds and steps.
614 if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
615 failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
616 failed(loopLike.setLoopSteps(newSteps))) {
617 // Mark op to prevent future attempts. IR was modified (attribute
618 // added), so we must return success() from the pattern.
619 loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr());
620 updateFailed = true;
621 return;
622 }
623
624 // Update induction variable types.
625 for (auto [idx, targetType, castKind] : narrowings) {
626 Value indVar = (*inductionVars)[idx];
627 auto blockArg = cast<BlockArgument>(indVar);
628
629 // Change the block argument type.
630 blockArg.setType(targetType);
631 }
632 });
633
634 if (updateFailed)
635 return success();
636
637 // Insert casts back to original type for uses.
638 for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
639 auto [idx, targetType, castKind] = narrowingInfo;
640 Value indVar = (*inductionVars)[idx];
641 auto blockArg = cast<BlockArgument>(indVar);
642 Type origType = origTypes[narrowingIdx];
643
644 OpBuilder::InsertionGuard guard(rewriter);
645 rewriter.setInsertionPointToStart(blockArg.getOwner());
646 Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
647 copyIntegerRange(solver, blockArg, casted);
648
649 // Replace all uses of the narrowed indVar with the casted value.
650 rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp());
651 }
652
653 return success();
654 }
655
656private:
657 DataFlowSolver &solver;
658 SmallVector<unsigned, 4> targetBitwidths;
659 StringAttr boundsNarrowingFailedAttr;
660};
661
662struct IntRangeOptimizationsPass final
663 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
664
665 void runOnOperation() override {
666 Operation *op = getOperation();
667 MLIRContext *ctx = op->getContext();
668 DataFlowSolver solver;
669 loadBaselineAnalyses(solver);
670 solver.load<IntegerRangeAnalysis>();
671 if (failed(solver.initializeAndRun(op)))
672 return signalPassFailure();
673
674 DataFlowListener listener(solver);
675
676 RewritePatternSet patterns(ctx);
678
680 op, std::move(patterns),
681 GreedyRewriteConfig().setListener(&listener))))
682 signalPassFailure();
683 }
684};
685
686struct IntRangeNarrowingPass final
687 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
688 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
689
690 void runOnOperation() override {
691 Operation *op = getOperation();
692 MLIRContext *ctx = op->getContext();
693 DataFlowSolver solver;
694 loadBaselineAnalyses(solver);
695 solver.load<IntegerRangeAnalysis>();
696 if (failed(solver.initializeAndRun(op)))
697 return signalPassFailure();
698
699 DataFlowListener listener(solver);
700
701 RewritePatternSet patterns(ctx);
702 populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
704 bitwidthsSupported);
705
706 // We specifically need bottom-up traversal as cmpi pattern needs range
707 // data, attached to its original argument values.
709 op, std::move(patterns),
710 GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
711 &listener))))
712 signalPassFailure();
713 }
714};
715} // namespace
716
719 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
720 DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
721}
722
725 ArrayRef<unsigned> bitwidthsSupported) {
726 patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
727 bitwidthsSupported);
728 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
729 FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
730 bitwidthsSupported);
731}
732
735 ArrayRef<unsigned> bitwidthsSupported) {
736 patterns.add<NarrowLoopBounds>(patterns.getContext(), solver,
737 bitwidthsSupported);
738}
739
741 return std::make_unique<IntRangeOptimizationsPass>();
742}
return success()
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
lhs
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
UnitAttr getUnitAttr()
Definition Builders.cpp:98
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
MLIRContext * getContext() const
Definition Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This is a value defined by a result of an operation.
Definition Value.h:457
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which performs optimizations based on integer range analysis.
void populateControlFlowValuesNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for narrowing control flow values (loop bounds, steps, etc.) based on int range analysis...
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...