MLIR 22.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
14
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
26
27namespace mlir::arith {
28#define GEN_PASS_DEF_ARITHINTRANGEOPTS
29#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30
31#define GEN_PASS_DEF_ARITHINTRANGENARROWING
32#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
33} // namespace mlir::arith
34
35using namespace mlir;
36using namespace mlir::arith;
37using namespace mlir::dataflow;
38
39static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
40 Value value) {
41 auto *maybeInferredRange =
43 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
44 return std::nullopt;
45 const ConstantIntRanges &inferredRange =
46 maybeInferredRange->getValue().getValue();
47 return inferredRange.getConstantValue();
48}
49
50static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
51 Value newVal) {
52 assert(oldVal.getType() == newVal.getType() &&
53 "Can't copy integer ranges between different types");
54 auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
55 if (!oldState)
56 return;
58 *oldState);
59}
60
61namespace mlir::dataflow {
62/// Patterned after SCCP
64 RewriterBase &rewriter, Value value) {
65 if (value.use_empty())
66 return failure();
67 std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
68 if (!maybeConstValue.has_value())
69 return failure();
70
71 Type type = value.getType();
72 Location loc = value.getLoc();
73 Operation *maybeDefiningOp = value.getDefiningOp();
74 Dialect *valueDialect =
75 maybeDefiningOp ? maybeDefiningOp->getDialect()
77
78 Attribute constAttr;
79 if (auto shaped = dyn_cast<ShapedType>(type)) {
80 constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
81 } else {
82 constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
83 }
84 Operation *constOp =
85 valueDialect->materializeConstant(rewriter, constAttr, type, loc);
86 // Fall back to arith.constant if the dialect materializer doesn't know what
87 // to do with an integer constant.
88 if (!constOp)
89 constOp = rewriter.getContext()
90 ->getLoadedDialect<ArithDialect>()
91 ->materializeConstant(rewriter, constAttr, type, loc);
92 if (!constOp)
93 return failure();
94
95 OpResult res = constOp->getResult(0);
97 solver.eraseState(res);
98 copyIntegerRange(solver, value, res);
99 rewriter.replaceAllUsesWith(value, res);
100 return success();
101}
102} // namespace mlir::dataflow
103
104namespace {
105class DataFlowListener : public RewriterBase::Listener {
106public:
107 DataFlowListener(DataFlowSolver &s) : s(s) {}
108
109protected:
110 void notifyOperationErased(Operation *op) override {
111 s.eraseState(s.getProgramPointAfter(op));
112 for (Value res : op->getResults())
113 s.eraseState(res);
114 }
115
116 DataFlowSolver &s;
117};
118
119/// Rewrite any results of `op` that were inferred to be constant integers to
120/// and replace their uses with that constant. Return success() if all results
121/// where thus replaced and the operation is erased. Also replace any block
122/// arguments with their constant values.
123struct MaterializeKnownConstantValues : public RewritePattern {
124 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
125 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
126 /*benefit=*/1, context),
127 solver(s) {}
128
129 LogicalResult matchAndRewrite(Operation *op,
130 PatternRewriter &rewriter) const override {
131 if (matchPattern(op, m_Constant()))
132 return failure();
133
134 auto needsReplacing = [&](Value v) {
135 return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
136 };
137 bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
138 if (op->getNumRegions() == 0)
139 if (!hasConstantResults)
140 return failure();
141 bool hasConstantRegionArgs = false;
142 for (Region &region : op->getRegions()) {
143 for (Block &block : region.getBlocks()) {
144 hasConstantRegionArgs |=
145 llvm::any_of(block.getArguments(), needsReplacing);
146 }
147 }
148 if (!hasConstantResults && !hasConstantRegionArgs)
149 return failure();
150
151 bool replacedAll = (op->getNumResults() != 0);
152 for (Value v : op->getResults())
153 replacedAll &=
154 (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
155 v.use_empty());
156 if (replacedAll && isOpTriviallyDead(op)) {
157 rewriter.eraseOp(op);
158 return success();
159 }
160
161 PatternRewriter::InsertionGuard guard(rewriter);
162 for (Region &region : op->getRegions()) {
163 for (Block &block : region.getBlocks()) {
164 rewriter.setInsertionPointToStart(&block);
165 for (BlockArgument &arg : block.getArguments()) {
166 (void)maybeReplaceWithConstant(solver, rewriter, arg);
167 }
168 }
169 }
170
171 return success();
172 }
173
174private:
175 DataFlowSolver &solver;
176};
177
178template <typename RemOp>
179struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
180 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
181 : OpRewritePattern<RemOp>(context), solver(s) {}
182
183 LogicalResult matchAndRewrite(RemOp op,
184 PatternRewriter &rewriter) const override {
185 Value lhs = op.getOperand(0);
186 Value rhs = op.getOperand(1);
187 auto maybeModulus = getConstantIntValue(rhs);
188 if (!maybeModulus.has_value())
189 return failure();
190 int64_t modulus = *maybeModulus;
191 if (modulus <= 0)
192 return failure();
193 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
194 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
195 return failure();
196 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
197 const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
198 const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
199 // The minima and maxima here are given as closed ranges, we must be
200 // strictly less than the modulus.
201 if (min.isNegative() || min.uge(modulus))
202 return failure();
203 if (max.isNegative() || max.uge(modulus))
204 return failure();
205 if (!min.ule(max))
206 return failure();
207
208 // With all those conditions out of the way, we know thas this invocation of
209 // a remainder is a noop because the input is strictly within the range
210 // [0, modulus), so get rid of it.
211 rewriter.replaceOp(op, ValueRange{lhs});
212 return success();
213 }
214
215private:
216 DataFlowSolver &solver;
217};
218
219/// Gather ranges for all the values in `values`. Appends to the existing
220/// vector.
221static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
223 for (Value val : values) {
224 auto *maybeInferredRange =
226 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
227 return failure();
228
229 const ConstantIntRanges &inferredRange =
230 maybeInferredRange->getValue().getValue();
231 ranges.push_back(inferredRange);
232 }
233 return success();
234}
235
236/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
237/// return shaped type as well.
238static Type getTargetType(Type srcType, unsigned targetBitwidth) {
239 auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
240 if (auto shaped = dyn_cast<ShapedType>(srcType))
241 return shaped.clone(dstType);
242
243 assert(srcType.isIntOrIndex() && "Invalid src type");
244 return dstType;
245}
246
namespace {
/// Which kind of truncation (and matching extension) can be used to narrow a
/// value, or a whole operation, without changing its meaning.
enum class CastKind : uint8_t {
  /// Narrowing is not possible.
  None,
  /// Only the sign-aware casts preserve the values.
  Signed,
  /// Only the zero-extending casts preserve the values.
  Unsigned,
  /// Either kind of cast preserves the values.
  Both
};
} // namespace
252
253/// If the values within `range` can be represented using only `width` bits,
254/// return the kind of truncation needed to preserve that property.
255///
256/// This check relies on the fact that the signed and unsigned ranges are both
257/// always correct, but that one might be an approximation of the other,
258/// so we want to use the correct truncation operation.
259static CastKind checkTruncatability(const ConstantIntRanges &range,
260 unsigned targetWidth) {
261 unsigned srcWidth = range.smin().getBitWidth();
262 if (srcWidth <= targetWidth)
263 return CastKind::None;
264 unsigned removedWidth = srcWidth - targetWidth;
265 // The sign bits need to extend into the sign bit of the target width. For
266 // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
267 // bits.
268 bool canTruncateSigned =
269 range.smin().getNumSignBits() >= (removedWidth + 1) &&
270 range.smax().getNumSignBits() >= (removedWidth + 1);
271 bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
272 range.umax().countLeadingZeros() >= removedWidth;
273 if (canTruncateSigned && canTruncateUnsigned)
274 return CastKind::Both;
275 if (canTruncateSigned)
276 return CastKind::Signed;
277 if (canTruncateUnsigned)
278 return CastKind::Unsigned;
279 return CastKind::None;
280}
281
282static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
283 if (lhs == CastKind::None || rhs == CastKind::None)
284 return CastKind::None;
285 if (lhs == CastKind::Both)
286 return rhs;
287 if (rhs == CastKind::Both)
288 return lhs;
289 if (lhs == rhs)
290 return lhs;
291 return CastKind::None;
292}
293
294static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
295 CastKind castKind) {
296 Type srcType = src.getType();
297 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
298 "Mixing vector and non-vector types");
299 assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
300 Type srcElemType = getElementTypeOrSelf(srcType);
301 Type dstElemType = getElementTypeOrSelf(dstType);
302 assert(srcElemType.isIntOrIndex() && "Invalid src type");
303 assert(dstElemType.isIntOrIndex() && "Invalid dst type");
304 if (srcType == dstType)
305 return src;
306
307 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
308 if (castKind == CastKind::Signed)
309 return arith::IndexCastOp::create(builder, loc, dstType, src);
310 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
311 }
312
313 auto srcInt = cast<IntegerType>(srcElemType);
314 auto dstInt = cast<IntegerType>(dstElemType);
315 if (dstInt.getWidth() < srcInt.getWidth())
316 return arith::TruncIOp::create(builder, loc, dstType, src);
317
318 if (castKind == CastKind::Signed)
319 return arith::ExtSIOp::create(builder, loc, dstType, src);
320 return arith::ExtUIOp::create(builder, loc, dstType, src);
321}
322
323struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
324 NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
325 ArrayRef<unsigned> target)
326 : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
327
329 LogicalResult matchAndRewrite(Operation *op,
330 PatternRewriter &rewriter) const override {
331 if (op->getNumResults() == 0)
332 return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
333
335 if (failed(collectRanges(solver, op->getOperands(), ranges)))
336 return rewriter.notifyMatchFailure(op, "input without specified range");
337 if (failed(collectRanges(solver, op->getResults(), ranges)))
338 return rewriter.notifyMatchFailure(op, "output without specified range");
340 Type srcType = op->getResult(0).getType();
341 if (!llvm::all_equal(op->getResultTypes()))
342 return rewriter.notifyMatchFailure(op, "mismatched result types");
343 if (op->getNumOperands() == 0 ||
344 !llvm::all_of(op->getOperandTypes(),
345 [=](Type t) { return t == srcType; }))
346 return rewriter.notifyMatchFailure(
347 op, "no operands or operand types don't match result type");
348
349 for (unsigned targetBitwidth : targetBitwidths) {
350 CastKind castKind = CastKind::Both;
351 for (const ConstantIntRanges &range : ranges) {
352 castKind = mergeCastKinds(castKind,
353 checkTruncatability(range, targetBitwidth));
354 if (castKind == CastKind::None)
355 break;
356 }
357 if (castKind == CastKind::None)
358 continue;
359 Type targetType = getTargetType(srcType, targetBitwidth);
360 if (targetType == srcType)
361 continue;
362
363 Location loc = op->getLoc();
364 IRMapping mapping;
365 for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
366 CastKind argCastKind = castKind;
367 // When dealing with `index` values, preserve non-negativity in the
368 // index_casts since we can't recover this in unsigned when equivalent.
369 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
370 argCastKind = CastKind::Both;
371 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
372 mapping.map(arg, newArg);
373 }
374
375 Operation *newOp = rewriter.clone(*op, mapping);
376 rewriter.modifyOpInPlace(newOp, [&]() {
377 for (OpResult res : newOp->getResults()) {
378 res.setType(targetType);
379 }
380 });
382 for (auto [newRes, oldRes] :
383 llvm::zip_equal(newOp->getResults(), op->getResults())) {
384 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
385 copyIntegerRange(solver, oldRes, castBack);
386 newResults.push_back(castBack);
387 }
389 rewriter.replaceOp(op, newResults);
390 return success();
391 }
392 return failure();
393 }
395private:
396 DataFlowSolver &solver;
397 SmallVector<unsigned, 4> targetBitwidths;
399
400struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
401 NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
402 : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
403
404 LogicalResult matchAndRewrite(arith::CmpIOp op,
405 PatternRewriter &rewriter) const override {
406 Value lhs = op.getLhs();
407 Value rhs = op.getRhs();
408
410 if (failed(collectRanges(solver, op.getOperands(), ranges)))
411 return failure();
412 const ConstantIntRanges &lhsRange = ranges[0];
413 const ConstantIntRanges &rhsRange = ranges[1];
414
415 Type srcType = lhs.getType();
416 for (unsigned targetBitwidth : targetBitwidths) {
417 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
418 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
419 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
420 // Note: this includes target width > src width.
421 if (castKind == CastKind::None)
422 continue;
424 Type targetType = getTargetType(srcType, targetBitwidth);
425 if (targetType == srcType)
426 continue;
428 Location loc = op->getLoc();
429 IRMapping mapping;
430 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
431 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
432 mapping.map(lhs, lhsCast);
433 mapping.map(rhs, rhsCast);
435 Operation *newOp = rewriter.clone(*op, mapping);
436 copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
437 rewriter.replaceOp(op, newOp->getResults());
438 return success();
440 return failure();
441 }
443private:
444 DataFlowSolver &solver;
446};
447
448/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
449/// This pattern assumes all passed `targetBitwidths` are not wider than index
450/// type.
451template <typename CastOp>
452struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
453 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
454 : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
456 LogicalResult matchAndRewrite(CastOp op,
457 PatternRewriter &rewriter) const override {
458 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
459 if (!srcOp)
460 return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
461
462 Value src = srcOp.getIn();
463 if (src.getType() != op.getType())
464 return rewriter.notifyMatchFailure(op, "outer types don't match");
465
466 if (!srcOp.getType().isIndex())
467 return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
468
469 auto intType = dyn_cast<IntegerType>(op.getType());
470 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
471 return failure();
472
473 rewriter.replaceOp(op, src);
474 return success();
475 }
476
477private:
478 SmallVector<unsigned, 4> targetBitwidths;
479};
480
481struct IntRangeOptimizationsPass final
482 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
483
484 void runOnOperation() override {
485 Operation *op = getOperation();
486 MLIRContext *ctx = op->getContext();
487 DataFlowSolver solver;
488 solver.load<DeadCodeAnalysis>();
490 solver.load<IntegerRangeAnalysis>();
491 if (failed(solver.initializeAndRun(op)))
492 return signalPassFailure();
493
494 DataFlowListener listener(solver);
495
498
499 if (failed(applyPatternsGreedily(
500 op, std::move(patterns),
501 GreedyRewriteConfig().setListener(&listener))))
502 signalPassFailure();
503 }
504};
505
506struct IntRangeNarrowingPass final
507 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
508 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
509
510 void runOnOperation() override {
511 Operation *op = getOperation();
512 MLIRContext *ctx = op->getContext();
513 DataFlowSolver solver;
514 solver.load<DeadCodeAnalysis>();
515 solver.load<IntegerRangeAnalysis>();
516 if (failed(solver.initializeAndRun(op)))
517 return signalPassFailure();
518
519 DataFlowListener listener(solver);
520
521 RewritePatternSet patterns(ctx);
522 populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
523
524 // We specifically need bottom-up traversal as cmpi pattern needs range
525 // data, attached to its original argument values.
527 op, std::move(patterns),
528 GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
529 &listener))))
530 signalPassFailure();
531 }
532};
533} // namespace
534
537 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538 DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
539}
540
543 ArrayRef<unsigned> bitwidthsSupported) {
544 patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
545 bitwidthsSupported);
546 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547 FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
548 bitwidthsSupported);
549}
550
552 return std::make_unique<IntRangeOptimizationsPass>();
553}
return success()
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
lhs
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
MLIRContext * getContext() const
Definition Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This is a value defined by a result of an operation.
Definition Value.h:457
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which does optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...