MLIR  20.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1 //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
13 
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 
30 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
31 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
32 } // namespace mlir::arith
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 using namespace mlir::dataflow;
37 
38 static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
39  Value value) {
40  auto *maybeInferredRange =
42  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
43  return std::nullopt;
44  const ConstantIntRanges &inferredRange =
45  maybeInferredRange->getValue().getValue();
46  return inferredRange.getConstantValue();
47 }
48 
49 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
50  Value newVal) {
51  assert(oldVal.getType() == newVal.getType() &&
52  "Can't copy integer ranges between different types");
53  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
54  if (!oldState)
55  return;
56  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
57  *oldState);
58 }
59 
60 /// Patterned after SCCP
61 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
62  PatternRewriter &rewriter,
63  Value value) {
64  if (value.use_empty())
65  return failure();
66  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
67  if (!maybeConstValue.has_value())
68  return failure();
69 
70  Type type = value.getType();
71  Location loc = value.getLoc();
72  Operation *maybeDefiningOp = value.getDefiningOp();
73  Dialect *valueDialect =
74  maybeDefiningOp ? maybeDefiningOp->getDialect()
75  : value.getParentRegion()->getParentOp()->getDialect();
76 
77  Attribute constAttr;
78  if (auto shaped = dyn_cast<ShapedType>(type)) {
79  constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
80  } else {
81  constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
82  }
83  Operation *constOp =
84  valueDialect->materializeConstant(rewriter, constAttr, type, loc);
85  // Fall back to arith.constant if the dialect materializer doesn't know what
86  // to do with an integer constant.
87  if (!constOp)
88  constOp = rewriter.getContext()
89  ->getLoadedDialect<ArithDialect>()
90  ->materializeConstant(rewriter, constAttr, type, loc);
91  if (!constOp)
92  return failure();
93 
94  copyIntegerRange(solver, value, constOp->getResult(0));
95  rewriter.replaceAllUsesWith(value, constOp->getResult(0));
96  return success();
97 }
98 
99 namespace {
100 class DataFlowListener : public RewriterBase::Listener {
101 public:
102  DataFlowListener(DataFlowSolver &s) : s(s) {}
103 
104 protected:
105  void notifyOperationErased(Operation *op) override {
106  s.eraseState(s.getProgramPointAfter(op));
107  for (Value res : op->getResults())
108  s.eraseState(res);
109  }
110 
111  DataFlowSolver &s;
112 };
113 
114 /// Rewrite any results of `op` that were inferred to be constant integers to
115 /// and replace their uses with that constant. Return success() if all results
116 /// where thus replaced and the operation is erased. Also replace any block
117 /// arguments with their constant values.
118 struct MaterializeKnownConstantValues : public RewritePattern {
119  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
120  : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
121  solver(s) {}
122 
123  LogicalResult match(Operation *op) const override {
124  if (matchPattern(op, m_Constant()))
125  return failure();
126 
127  auto needsReplacing = [&](Value v) {
128  return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
129  };
130  bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
131  if (op->getNumRegions() == 0)
132  return success(hasConstantResults);
133  bool hasConstantRegionArgs = false;
134  for (Region &region : op->getRegions()) {
135  for (Block &block : region.getBlocks()) {
136  hasConstantRegionArgs |=
137  llvm::any_of(block.getArguments(), needsReplacing);
138  }
139  }
140  return success(hasConstantResults || hasConstantRegionArgs);
141  }
142 
143  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
144  bool replacedAll = (op->getNumResults() != 0);
145  for (Value v : op->getResults())
146  replacedAll &=
147  (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
148  v.use_empty());
149  if (replacedAll && isOpTriviallyDead(op)) {
150  rewriter.eraseOp(op);
151  return;
152  }
153 
154  PatternRewriter::InsertionGuard guard(rewriter);
155  for (Region &region : op->getRegions()) {
156  for (Block &block : region.getBlocks()) {
157  rewriter.setInsertionPointToStart(&block);
158  for (BlockArgument &arg : block.getArguments()) {
159  (void)maybeReplaceWithConstant(solver, rewriter, arg);
160  }
161  }
162  }
163  }
164 
165 private:
166  DataFlowSolver &solver;
167 };
168 
169 template <typename RemOp>
170 struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
171  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
172  : OpRewritePattern<RemOp>(context), solver(s) {}
173 
174  LogicalResult matchAndRewrite(RemOp op,
175  PatternRewriter &rewriter) const override {
176  Value lhs = op.getOperand(0);
177  Value rhs = op.getOperand(1);
178  auto maybeModulus = getConstantIntValue(rhs);
179  if (!maybeModulus.has_value())
180  return failure();
181  int64_t modulus = *maybeModulus;
182  if (modulus <= 0)
183  return failure();
184  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
185  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
186  return failure();
187  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
188  const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
189  const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
190  // The minima and maxima here are given as closed ranges, we must be
191  // strictly less than the modulus.
192  if (min.isNegative() || min.uge(modulus))
193  return failure();
194  if (max.isNegative() || max.uge(modulus))
195  return failure();
196  if (!min.ule(max))
197  return failure();
198 
199  // With all those conditions out of the way, we know thas this invocation of
200  // a remainder is a noop because the input is strictly within the range
201  // [0, modulus), so get rid of it.
202  rewriter.replaceOp(op, ValueRange{lhs});
203  return success();
204  }
205 
206 private:
207  DataFlowSolver &solver;
208 };
209 
210 /// Gather ranges for all the values in `values`. Appends to the existing
211 /// vector.
212 static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
214  for (Value val : values) {
215  auto *maybeInferredRange =
217  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
218  return failure();
219 
220  const ConstantIntRanges &inferredRange =
221  maybeInferredRange->getValue().getValue();
222  ranges.push_back(inferredRange);
223  }
224  return success();
225 }
226 
227 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
228 /// return shaped type as well.
229 static Type getTargetType(Type srcType, unsigned targetBitwidth) {
230  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
231  if (auto shaped = dyn_cast<ShapedType>(srcType))
232  return shaped.clone(dstType);
233 
234  assert(srcType.isIntOrIndex() && "Invalid src type");
235  return dstType;
236 }
237 
namespace {
/// Which truncation can move a value into a narrower integer type, and which
/// extension brings it back, without changing the value.
enum class CastKind : uint8_t {
  /// No truncation preserves the value.
  None,
  /// The value survives truncation followed by sign extension.
  Signed,
  /// The value survives truncation followed by zero extension.
  Unsigned,
  /// Either round trip preserves the value.
  Both
};
} // namespace
243 
244 /// If the values within `range` can be represented using only `width` bits,
245 /// return the kind of truncation needed to preserve that property.
246 ///
247 /// This check relies on the fact that the signed and unsigned ranges are both
248 /// always correct, but that one might be an approximation of the other,
249 /// so we want to use the correct truncation operation.
250 static CastKind checkTruncatability(const ConstantIntRanges &range,
251  unsigned targetWidth) {
252  unsigned srcWidth = range.smin().getBitWidth();
253  if (srcWidth <= targetWidth)
254  return CastKind::None;
255  unsigned removedWidth = srcWidth - targetWidth;
256  // The sign bits need to extend into the sign bit of the target width. For
257  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
258  // bits.
259  bool canTruncateSigned =
260  range.smin().getNumSignBits() >= (removedWidth + 1) &&
261  range.smax().getNumSignBits() >= (removedWidth + 1);
262  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
263  range.umax().countLeadingZeros() >= removedWidth;
264  if (canTruncateSigned && canTruncateUnsigned)
265  return CastKind::Both;
266  if (canTruncateSigned)
267  return CastKind::Signed;
268  if (canTruncateUnsigned)
269  return CastKind::Unsigned;
270  return CastKind::None;
271 }
272 
273 static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
274  if (lhs == CastKind::None || rhs == CastKind::None)
275  return CastKind::None;
276  if (lhs == CastKind::Both)
277  return rhs;
278  if (rhs == CastKind::Both)
279  return lhs;
280  if (lhs == rhs)
281  return lhs;
282  return CastKind::None;
283 }
284 
285 static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
286  CastKind castKind) {
287  Type srcType = src.getType();
288  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
289  "Mixing vector and non-vector types");
290  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
291  Type srcElemType = getElementTypeOrSelf(srcType);
292  Type dstElemType = getElementTypeOrSelf(dstType);
293  assert(srcElemType.isIntOrIndex() && "Invalid src type");
294  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
295  if (srcType == dstType)
296  return src;
297 
298  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
299  if (castKind == CastKind::Signed)
300  return builder.create<arith::IndexCastOp>(loc, dstType, src);
301  return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
302  }
303 
304  auto srcInt = cast<IntegerType>(srcElemType);
305  auto dstInt = cast<IntegerType>(dstElemType);
306  if (dstInt.getWidth() < srcInt.getWidth())
307  return builder.create<arith::TruncIOp>(loc, dstType, src);
308 
309  if (castKind == CastKind::Signed)
310  return builder.create<arith::ExtSIOp>(loc, dstType, src);
311  return builder.create<arith::ExtUIOp>(loc, dstType, src);
312 }
313 
314 struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
315  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
316  ArrayRef<unsigned> target)
317  : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
318 
320  LogicalResult matchAndRewrite(Operation *op,
321  PatternRewriter &rewriter) const override {
322  if (op->getNumResults() == 0)
323  return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
324 
326  if (failed(collectRanges(solver, op->getOperands(), ranges)))
327  return rewriter.notifyMatchFailure(op, "input without specified range");
328  if (failed(collectRanges(solver, op->getResults(), ranges)))
329  return rewriter.notifyMatchFailure(op, "output without specified range");
330 
331  Type srcType = op->getResult(0).getType();
332  if (!llvm::all_equal(op->getResultTypes()))
333  return rewriter.notifyMatchFailure(op, "mismatched result types");
334  if (op->getNumOperands() == 0 ||
335  !llvm::all_of(op->getOperandTypes(),
336  [=](Type t) { return t == srcType; }))
337  return rewriter.notifyMatchFailure(
338  op, "no operands or operand types don't match result type");
339 
340  for (unsigned targetBitwidth : targetBitwidths) {
341  CastKind castKind = CastKind::Both;
342  for (const ConstantIntRanges &range : ranges) {
343  castKind = mergeCastKinds(castKind,
344  checkTruncatability(range, targetBitwidth));
345  if (castKind == CastKind::None)
346  break;
347  }
348  if (castKind == CastKind::None)
349  continue;
350  Type targetType = getTargetType(srcType, targetBitwidth);
351  if (targetType == srcType)
352  continue;
353 
354  Location loc = op->getLoc();
355  IRMapping mapping;
356  for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
357  CastKind argCastKind = castKind;
358  // When dealing with `index` values, preserve non-negativity in the
359  // index_casts since we can't recover this in unsigned when equivalent.
360  if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
361  argCastKind = CastKind::Both;
362  Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
363  mapping.map(arg, newArg);
364  }
365 
366  Operation *newOp = rewriter.clone(*op, mapping);
367  rewriter.modifyOpInPlace(newOp, [&]() {
368  for (OpResult res : newOp->getResults()) {
369  res.setType(targetType);
370  }
371  });
372  SmallVector<Value> newResults;
373  for (auto [newRes, oldRes] :
374  llvm::zip_equal(newOp->getResults(), op->getResults())) {
375  Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
376  copyIntegerRange(solver, oldRes, castBack);
377  newResults.push_back(castBack);
378  }
379 
380  rewriter.replaceOp(op, newResults);
381  return success();
382  }
383  return failure();
384  }
385 
386 private:
387  DataFlowSolver &solver;
388  SmallVector<unsigned, 4> targetBitwidths;
389 };
390 
391 struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
392  NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
393  : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
394 
395  LogicalResult matchAndRewrite(arith::CmpIOp op,
396  PatternRewriter &rewriter) const override {
397  Value lhs = op.getLhs();
398  Value rhs = op.getRhs();
399 
401  if (failed(collectRanges(solver, op.getOperands(), ranges)))
402  return failure();
403  const ConstantIntRanges &lhsRange = ranges[0];
404  const ConstantIntRanges &rhsRange = ranges[1];
405 
406  Type srcType = lhs.getType();
407  for (unsigned targetBitwidth : targetBitwidths) {
408  CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
409  CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
410  CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
411  // Note: this includes target width > src width.
412  if (castKind == CastKind::None)
413  continue;
414 
415  Type targetType = getTargetType(srcType, targetBitwidth);
416  if (targetType == srcType)
417  continue;
418 
419  Location loc = op->getLoc();
420  IRMapping mapping;
421  Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
422  Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
423  mapping.map(lhs, lhsCast);
424  mapping.map(rhs, rhsCast);
425 
426  Operation *newOp = rewriter.clone(*op, mapping);
427  copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
428  rewriter.replaceOp(op, newOp->getResults());
429  return success();
430  }
431  return failure();
432  }
433 
434 private:
435  DataFlowSolver &solver;
436  SmallVector<unsigned, 4> targetBitwidths;
437 };
438 
439 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
440 /// This pattern assumes all passed `targetBitwidths` are not wider than index
441 /// type.
442 template <typename CastOp>
443 struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
444  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
445  : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
446 
447  LogicalResult matchAndRewrite(CastOp op,
448  PatternRewriter &rewriter) const override {
449  auto srcOp = op.getIn().template getDefiningOp<CastOp>();
450  if (!srcOp)
451  return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
452 
453  Value src = srcOp.getIn();
454  if (src.getType() != op.getType())
455  return rewriter.notifyMatchFailure(op, "outer types don't match");
456 
457  if (!srcOp.getType().isIndex())
458  return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
459 
460  auto intType = dyn_cast<IntegerType>(op.getType());
461  if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
462  return failure();
463 
464  rewriter.replaceOp(op, src);
465  return success();
466  }
467 
468 private:
469  SmallVector<unsigned, 4> targetBitwidths;
470 };
471 
472 struct IntRangeOptimizationsPass final
473  : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
474 
475  void runOnOperation() override {
476  Operation *op = getOperation();
477  MLIRContext *ctx = op->getContext();
478  DataFlowSolver solver;
479  solver.load<DeadCodeAnalysis>();
480  solver.load<IntegerRangeAnalysis>();
481  if (failed(solver.initializeAndRun(op)))
482  return signalPassFailure();
483 
484  DataFlowListener listener(solver);
485 
488 
490  config.listener = &listener;
491 
492  if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
493  signalPassFailure();
494  }
495 };
496 
497 struct IntRangeNarrowingPass final
498  : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
499  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
500 
501  void runOnOperation() override {
502  Operation *op = getOperation();
503  MLIRContext *ctx = op->getContext();
504  DataFlowSolver solver;
505  solver.load<DeadCodeAnalysis>();
506  solver.load<IntegerRangeAnalysis>();
507  if (failed(solver.initializeAndRun(op)))
508  return signalPassFailure();
509 
510  DataFlowListener listener(solver);
511 
513  populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
514 
516  // We specifically need bottom-up traversal as cmpi pattern needs range
517  // data, attached to its original argument values.
518  config.useTopDownTraversal = false;
519  config.listener = &listener;
520 
521  if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
522  signalPassFailure();
523  }
524 };
525 } // namespace
526 
529  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
530  DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
531 }
532 
535  ArrayRef<unsigned> bitwidthsSupported) {
536  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
537  bitwidthsSupported);
538  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
539  FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
540  bitwidthsSupported);
541 }
542 
544  return std::make_unique<IntRangeOptimizationsPass>();
545 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, PatternRewriter &rewriter, Value value)
Patterned after SCCP.
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
MLIRContext * getContext() const
Definition: Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This is a value defined by a result of an operation.
Definition: Value.h:457
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:384
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:386
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass that performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358