MLIR 22.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
15
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
27
28namespace mlir::arith {
29#define GEN_PASS_DEF_ARITHINTRANGEOPTS
30#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
31
32#define GEN_PASS_DEF_ARITHINTRANGENARROWING
33#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
34} // namespace mlir::arith
35
36using namespace mlir;
37using namespace mlir::arith;
38using namespace mlir::dataflow;
39
40static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
41 Value value) {
42 auto *maybeInferredRange =
44 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
45 return std::nullopt;
46 const ConstantIntRanges &inferredRange =
47 maybeInferredRange->getValue().getValue();
48 return inferredRange.getConstantValue();
49}
50
51static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
52 Value newVal) {
53 assert(oldVal.getType() == newVal.getType() &&
54 "Can't copy integer ranges between different types");
55 auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
56 if (!oldState)
57 return;
59 *oldState);
60}
61
62namespace mlir::dataflow {
63/// Patterned after SCCP
65 RewriterBase &rewriter, Value value) {
66 if (value.use_empty())
67 return failure();
68 std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
69 if (!maybeConstValue.has_value())
70 return failure();
71
72 Type type = value.getType();
73 Location loc = value.getLoc();
74 Operation *maybeDefiningOp = value.getDefiningOp();
75 Dialect *valueDialect =
76 maybeDefiningOp ? maybeDefiningOp->getDialect()
78
79 Attribute constAttr;
80 if (auto shaped = dyn_cast<ShapedType>(type)) {
81 constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
82 } else {
83 constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
84 }
85 Operation *constOp =
86 valueDialect->materializeConstant(rewriter, constAttr, type, loc);
87 // Fall back to arith.constant if the dialect materializer doesn't know what
88 // to do with an integer constant.
89 if (!constOp)
90 constOp = rewriter.getContext()
91 ->getLoadedDialect<ArithDialect>()
92 ->materializeConstant(rewriter, constAttr, type, loc);
93 if (!constOp)
94 return failure();
95
96 OpResult res = constOp->getResult(0);
98 solver.eraseState(res);
99 copyIntegerRange(solver, value, res);
100 rewriter.replaceAllUsesWith(value, res);
101 return success();
102}
103} // namespace mlir::dataflow
104
105namespace {
106class DataFlowListener : public RewriterBase::Listener {
107public:
108 DataFlowListener(DataFlowSolver &s) : s(s) {}
109
110protected:
111 void notifyOperationErased(Operation *op) override {
112 s.eraseState(s.getProgramPointAfter(op));
113 for (Value res : op->getResults())
114 s.eraseState(res);
115 }
116
117 DataFlowSolver &s;
118};
119
120/// Rewrite any results of `op` that were inferred to be constant integers to
121/// and replace their uses with that constant. Return success() if all results
122/// where thus replaced and the operation is erased. Also replace any block
123/// arguments with their constant values.
124struct MaterializeKnownConstantValues : public RewritePattern {
125 MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
126 : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
127 /*benefit=*/1, context),
128 solver(s) {}
129
130 LogicalResult matchAndRewrite(Operation *op,
131 PatternRewriter &rewriter) const override {
132 if (matchPattern(op, m_Constant()))
133 return failure();
134
135 auto needsReplacing = [&](Value v) {
136 return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
137 };
138 bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
139 if (op->getNumRegions() == 0)
140 if (!hasConstantResults)
141 return failure();
142 bool hasConstantRegionArgs = false;
143 for (Region &region : op->getRegions()) {
144 for (Block &block : region.getBlocks()) {
145 hasConstantRegionArgs |=
146 llvm::any_of(block.getArguments(), needsReplacing);
147 }
148 }
149 if (!hasConstantResults && !hasConstantRegionArgs)
150 return failure();
151
152 bool replacedAll = (op->getNumResults() != 0);
153 for (Value v : op->getResults())
154 replacedAll &=
155 (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
156 v.use_empty());
157 if (replacedAll && isOpTriviallyDead(op)) {
158 rewriter.eraseOp(op);
159 return success();
160 }
161
162 PatternRewriter::InsertionGuard guard(rewriter);
163 for (Region &region : op->getRegions()) {
164 for (Block &block : region.getBlocks()) {
165 rewriter.setInsertionPointToStart(&block);
166 for (BlockArgument &arg : block.getArguments()) {
167 (void)maybeReplaceWithConstant(solver, rewriter, arg);
168 }
169 }
170 }
171
172 return success();
173 }
174
175private:
176 DataFlowSolver &solver;
177};
178
179template <typename RemOp>
180struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
181 DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
182 : OpRewritePattern<RemOp>(context), solver(s) {}
183
184 LogicalResult matchAndRewrite(RemOp op,
185 PatternRewriter &rewriter) const override {
186 Value lhs = op.getOperand(0);
187 Value rhs = op.getOperand(1);
188 auto maybeModulus = getConstantIntValue(rhs);
189 if (!maybeModulus.has_value())
190 return failure();
191 int64_t modulus = *maybeModulus;
192 if (modulus <= 0)
193 return failure();
194 auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
195 if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
196 return failure();
197 const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
198 const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
199 const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
200 // The minima and maxima here are given as closed ranges, we must be
201 // strictly less than the modulus.
202 if (min.isNegative() || min.uge(modulus))
203 return failure();
204 if (max.isNegative() || max.uge(modulus))
205 return failure();
206 if (!min.ule(max))
207 return failure();
208
209 // With all those conditions out of the way, we know thas this invocation of
210 // a remainder is a noop because the input is strictly within the range
211 // [0, modulus), so get rid of it.
212 rewriter.replaceOp(op, ValueRange{lhs});
213 return success();
214 }
215
216private:
217 DataFlowSolver &solver;
218};
219
220/// Gather ranges for all the values in `values`. Appends to the existing
221/// vector.
222static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
224 for (Value val : values) {
225 auto *maybeInferredRange =
227 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
228 return failure();
229
230 const ConstantIntRanges &inferredRange =
231 maybeInferredRange->getValue().getValue();
232 ranges.push_back(inferredRange);
233 }
234 return success();
235}
236
237/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
238/// return shaped type as well.
239static Type getTargetType(Type srcType, unsigned targetBitwidth) {
240 auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
241 if (auto shaped = dyn_cast<ShapedType>(srcType))
242 return shaped.clone(dstType);
243
244 assert(srcType.isIntOrIndex() && "Invalid src type");
245 return dstType;
246}
247
namespace {
/// Tracks which flavour of truncation, if any, may be used to narrow an
/// operation.
enum class CastKind : uint8_t {
  /// Narrowing is not possible.
  None,
  /// Only a signed truncation is valid.
  Signed,
  /// Only an unsigned truncation is valid.
  Unsigned,
  /// Signed and unsigned truncation are both valid.
  Both
};
} // namespace
253
254/// If the values within `range` can be represented using only `width` bits,
255/// return the kind of truncation needed to preserve that property.
256///
257/// This check relies on the fact that the signed and unsigned ranges are both
258/// always correct, but that one might be an approximation of the other,
259/// so we want to use the correct truncation operation.
260static CastKind checkTruncatability(const ConstantIntRanges &range,
261 unsigned targetWidth) {
262 unsigned srcWidth = range.smin().getBitWidth();
263 if (srcWidth <= targetWidth)
264 return CastKind::None;
265 unsigned removedWidth = srcWidth - targetWidth;
266 // The sign bits need to extend into the sign bit of the target width. For
267 // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
268 // bits.
269 bool canTruncateSigned =
270 range.smin().getNumSignBits() >= (removedWidth + 1) &&
271 range.smax().getNumSignBits() >= (removedWidth + 1);
272 bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
273 range.umax().countLeadingZeros() >= removedWidth;
274 if (canTruncateSigned && canTruncateUnsigned)
275 return CastKind::Both;
276 if (canTruncateSigned)
277 return CastKind::Signed;
278 if (canTruncateUnsigned)
279 return CastKind::Unsigned;
280 return CastKind::None;
281}
282
283static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
284 if (lhs == CastKind::None || rhs == CastKind::None)
285 return CastKind::None;
286 if (lhs == CastKind::Both)
287 return rhs;
288 if (rhs == CastKind::Both)
289 return lhs;
290 if (lhs == rhs)
291 return lhs;
292 return CastKind::None;
293}
294
295static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
296 CastKind castKind) {
297 Type srcType = src.getType();
298 assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
299 "Mixing vector and non-vector types");
300 assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
301 Type srcElemType = getElementTypeOrSelf(srcType);
302 Type dstElemType = getElementTypeOrSelf(dstType);
303 assert(srcElemType.isIntOrIndex() && "Invalid src type");
304 assert(dstElemType.isIntOrIndex() && "Invalid dst type");
305 if (srcType == dstType)
306 return src;
307
308 if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
309 if (castKind == CastKind::Signed)
310 return arith::IndexCastOp::create(builder, loc, dstType, src);
311 return arith::IndexCastUIOp::create(builder, loc, dstType, src);
312 }
313
314 auto srcInt = cast<IntegerType>(srcElemType);
315 auto dstInt = cast<IntegerType>(dstElemType);
316 if (dstInt.getWidth() < srcInt.getWidth())
317 return arith::TruncIOp::create(builder, loc, dstType, src);
318
319 if (castKind == CastKind::Signed)
320 return arith::ExtSIOp::create(builder, loc, dstType, src);
321 return arith::ExtUIOp::create(builder, loc, dstType, src);
322}
323
324struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
325 NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
326 ArrayRef<unsigned> target)
327 : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
328
330 LogicalResult matchAndRewrite(Operation *op,
331 PatternRewriter &rewriter) const override {
332 if (op->getNumResults() == 0)
333 return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
334
335 SmallVector<ConstantIntRanges> ranges;
336 if (failed(collectRanges(solver, op->getOperands(), ranges)))
337 return rewriter.notifyMatchFailure(op, "input without specified range");
338 if (failed(collectRanges(solver, op->getResults(), ranges)))
339 return rewriter.notifyMatchFailure(op, "output without specified range");
340
341 Type srcType = op->getResult(0).getType();
342 if (!llvm::all_equal(op->getResultTypes()))
343 return rewriter.notifyMatchFailure(op, "mismatched result types");
344 if (op->getNumOperands() == 0 ||
345 !llvm::all_of(op->getOperandTypes(),
346 [=](Type t) { return t == srcType; }))
347 return rewriter.notifyMatchFailure(
348 op, "no operands or operand types don't match result type");
349
350 for (unsigned targetBitwidth : targetBitwidths) {
351 CastKind castKind = CastKind::Both;
352 for (const ConstantIntRanges &range : ranges) {
353 castKind = mergeCastKinds(castKind,
354 checkTruncatability(range, targetBitwidth));
355 if (castKind == CastKind::None)
356 break;
357 }
358 if (castKind == CastKind::None)
359 continue;
360 Type targetType = getTargetType(srcType, targetBitwidth);
361 if (targetType == srcType)
362 continue;
363
364 Location loc = op->getLoc();
365 IRMapping mapping;
366 for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
367 CastKind argCastKind = castKind;
368 // When dealing with `index` values, preserve non-negativity in the
369 // index_casts since we can't recover this in unsigned when equivalent.
370 if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
371 argCastKind = CastKind::Both;
372 Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
373 mapping.map(arg, newArg);
374 }
375
376 Operation *newOp = rewriter.clone(*op, mapping);
377 rewriter.modifyOpInPlace(newOp, [&]() {
378 for (OpResult res : newOp->getResults()) {
379 res.setType(targetType);
380 }
381 });
382 SmallVector<Value> newResults;
383 for (auto [newRes, oldRes] :
384 llvm::zip_equal(newOp->getResults(), op->getResults())) {
385 Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
386 copyIntegerRange(solver, oldRes, castBack);
387 newResults.push_back(castBack);
388 }
389
390 rewriter.replaceOp(op, newResults);
391 return success();
392 }
393 return failure();
394 }
395
396private:
397 DataFlowSolver &solver;
398 SmallVector<unsigned, 4> targetBitwidths;
399};
400
401struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
402 NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
403 : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
404
405 LogicalResult matchAndRewrite(arith::CmpIOp op,
406 PatternRewriter &rewriter) const override {
407 Value lhs = op.getLhs();
408 Value rhs = op.getRhs();
409
410 SmallVector<ConstantIntRanges> ranges;
411 if (failed(collectRanges(solver, op.getOperands(), ranges)))
412 return failure();
413 const ConstantIntRanges &lhsRange = ranges[0];
414 const ConstantIntRanges &rhsRange = ranges[1];
415
416 Type srcType = lhs.getType();
417 for (unsigned targetBitwidth : targetBitwidths) {
418 CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
419 CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
420 CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
421 // Note: this includes target width > src width.
422 if (castKind == CastKind::None)
423 continue;
424
425 Type targetType = getTargetType(srcType, targetBitwidth);
426 if (targetType == srcType)
427 continue;
428
429 Location loc = op->getLoc();
430 IRMapping mapping;
431 Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
432 Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
433 mapping.map(lhs, lhsCast);
434 mapping.map(rhs, rhsCast);
435
436 Operation *newOp = rewriter.clone(*op, mapping);
437 copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
438 rewriter.replaceOp(op, newOp->getResults());
439 return success();
440 }
441 return failure();
442 }
443
444private:
445 DataFlowSolver &solver;
446 SmallVector<unsigned, 4> targetBitwidths;
447};
448
449/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
450/// This pattern assumes all passed `targetBitwidths` are not wider than index
451/// type.
452template <typename CastOp>
453struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
454 FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
455 : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
456
457 LogicalResult matchAndRewrite(CastOp op,
458 PatternRewriter &rewriter) const override {
459 auto srcOp = op.getIn().template getDefiningOp<CastOp>();
460 if (!srcOp)
461 return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
462
463 Value src = srcOp.getIn();
464 if (src.getType() != op.getType())
465 return rewriter.notifyMatchFailure(op, "outer types don't match");
466
467 if (!srcOp.getType().isIndex())
468 return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
469
470 auto intType = dyn_cast<IntegerType>(op.getType());
471 if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
472 return failure();
473
474 rewriter.replaceOp(op, src);
475 return success();
476 }
477
478private:
479 SmallVector<unsigned, 4> targetBitwidths;
480};
481
482struct IntRangeOptimizationsPass final
483 : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
484
485 void runOnOperation() override {
486 Operation *op = getOperation();
487 MLIRContext *ctx = op->getContext();
488 DataFlowSolver solver;
489 loadBaselineAnalyses(solver);
490 solver.load<IntegerRangeAnalysis>();
491 if (failed(solver.initializeAndRun(op)))
492 return signalPassFailure();
493
494 DataFlowListener listener(solver);
495
496 RewritePatternSet patterns(ctx);
498
500 op, std::move(patterns),
501 GreedyRewriteConfig().setListener(&listener))))
502 signalPassFailure();
503 }
504};
505
506struct IntRangeNarrowingPass final
507 : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
508 using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
509
510 void runOnOperation() override {
511 Operation *op = getOperation();
512 MLIRContext *ctx = op->getContext();
513 DataFlowSolver solver;
514 loadBaselineAnalyses(solver);
515 solver.load<IntegerRangeAnalysis>();
516 if (failed(solver.initializeAndRun(op)))
517 return signalPassFailure();
518
519 DataFlowListener listener(solver);
520
521 RewritePatternSet patterns(ctx);
522 populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
523
524 // We specifically need bottom-up traversal as cmpi pattern needs range
525 // data, attached to its original argument values.
527 op, std::move(patterns),
528 GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
529 &listener))))
530 signalPassFailure();
531 }
532};
533} // namespace
534
537 patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538 DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
539}
540
543 ArrayRef<unsigned> bitwidthsSupported) {
544 patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
545 bitwidthsSupported);
546 patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547 FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
548 bitwidthsSupported);
549}
550
552 return std::make_unique<IntRangeOptimizationsPass>();
553}
return success()
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
lhs
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
MLIRContext * getContext() const
Definition Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This is a value defined by a result of an operation.
Definition Value.h:457
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...