MLIR  21.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1 //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
13 
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 
30 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
31 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
32 } // namespace mlir::arith
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 using namespace mlir::dataflow;
37 
38 static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
39  Value value) {
40  auto *maybeInferredRange =
42  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
43  return std::nullopt;
44  const ConstantIntRanges &inferredRange =
45  maybeInferredRange->getValue().getValue();
46  return inferredRange.getConstantValue();
47 }
48 
49 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
50  Value newVal) {
51  assert(oldVal.getType() == newVal.getType() &&
52  "Can't copy integer ranges between different types");
53  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
54  if (!oldState)
55  return;
56  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
57  *oldState);
58 }
59 
60 namespace mlir::dataflow {
61 /// Patterned after SCCP
63  RewriterBase &rewriter, Value value) {
64  if (value.use_empty())
65  return failure();
66  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
67  if (!maybeConstValue.has_value())
68  return failure();
69 
70  Type type = value.getType();
71  Location loc = value.getLoc();
72  Operation *maybeDefiningOp = value.getDefiningOp();
73  Dialect *valueDialect =
74  maybeDefiningOp ? maybeDefiningOp->getDialect()
75  : value.getParentRegion()->getParentOp()->getDialect();
76 
77  Attribute constAttr;
78  if (auto shaped = dyn_cast<ShapedType>(type)) {
79  constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
80  } else {
81  constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
82  }
83  Operation *constOp =
84  valueDialect->materializeConstant(rewriter, constAttr, type, loc);
85  // Fall back to arith.constant if the dialect materializer doesn't know what
86  // to do with an integer constant.
87  if (!constOp)
88  constOp = rewriter.getContext()
89  ->getLoadedDialect<ArithDialect>()
90  ->materializeConstant(rewriter, constAttr, type, loc);
91  if (!constOp)
92  return failure();
93 
94  OpResult res = constOp->getResult(0);
96  solver.eraseState(res);
97  copyIntegerRange(solver, value, res);
98  rewriter.replaceAllUsesWith(value, res);
99  return success();
100 }
101 } // namespace mlir::dataflow
102 
103 namespace {
104 class DataFlowListener : public RewriterBase::Listener {
105 public:
106  DataFlowListener(DataFlowSolver &s) : s(s) {}
107 
108 protected:
109  void notifyOperationErased(Operation *op) override {
110  s.eraseState(s.getProgramPointAfter(op));
111  for (Value res : op->getResults())
112  s.eraseState(res);
113  }
114 
115  DataFlowSolver &s;
116 };
117 
118 /// Rewrite any results of `op` that were inferred to be constant integers to
119 /// and replace their uses with that constant. Return success() if all results
120 /// where thus replaced and the operation is erased. Also replace any block
121 /// arguments with their constant values.
122 struct MaterializeKnownConstantValues : public RewritePattern {
123  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
124  : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
125  /*benefit=*/1, context),
126  solver(s) {}
127 
128  LogicalResult matchAndRewrite(Operation *op,
129  PatternRewriter &rewriter) const override {
130  if (matchPattern(op, m_Constant()))
131  return failure();
132 
133  auto needsReplacing = [&](Value v) {
134  return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
135  };
136  bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
137  if (op->getNumRegions() == 0)
138  if (!hasConstantResults)
139  return failure();
140  bool hasConstantRegionArgs = false;
141  for (Region &region : op->getRegions()) {
142  for (Block &block : region.getBlocks()) {
143  hasConstantRegionArgs |=
144  llvm::any_of(block.getArguments(), needsReplacing);
145  }
146  }
147  if (!hasConstantResults && !hasConstantRegionArgs)
148  return failure();
149 
150  bool replacedAll = (op->getNumResults() != 0);
151  for (Value v : op->getResults())
152  replacedAll &=
153  (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
154  v.use_empty());
155  if (replacedAll && isOpTriviallyDead(op)) {
156  rewriter.eraseOp(op);
157  return success();
158  }
159 
160  PatternRewriter::InsertionGuard guard(rewriter);
161  for (Region &region : op->getRegions()) {
162  for (Block &block : region.getBlocks()) {
163  rewriter.setInsertionPointToStart(&block);
164  for (BlockArgument &arg : block.getArguments()) {
165  (void)maybeReplaceWithConstant(solver, rewriter, arg);
166  }
167  }
168  }
169 
170  return success();
171  }
172 
173 private:
174  DataFlowSolver &solver;
175 };
176 
177 template <typename RemOp>
178 struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
179  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
180  : OpRewritePattern<RemOp>(context), solver(s) {}
181 
182  LogicalResult matchAndRewrite(RemOp op,
183  PatternRewriter &rewriter) const override {
184  Value lhs = op.getOperand(0);
185  Value rhs = op.getOperand(1);
186  auto maybeModulus = getConstantIntValue(rhs);
187  if (!maybeModulus.has_value())
188  return failure();
189  int64_t modulus = *maybeModulus;
190  if (modulus <= 0)
191  return failure();
192  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
193  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
194  return failure();
195  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
196  const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
197  const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
198  // The minima and maxima here are given as closed ranges, we must be
199  // strictly less than the modulus.
200  if (min.isNegative() || min.uge(modulus))
201  return failure();
202  if (max.isNegative() || max.uge(modulus))
203  return failure();
204  if (!min.ule(max))
205  return failure();
206 
207  // With all those conditions out of the way, we know thas this invocation of
208  // a remainder is a noop because the input is strictly within the range
209  // [0, modulus), so get rid of it.
210  rewriter.replaceOp(op, ValueRange{lhs});
211  return success();
212  }
213 
214 private:
215  DataFlowSolver &solver;
216 };
217 
218 /// Gather ranges for all the values in `values`. Appends to the existing
219 /// vector.
220 static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
222  for (Value val : values) {
223  auto *maybeInferredRange =
225  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
226  return failure();
227 
228  const ConstantIntRanges &inferredRange =
229  maybeInferredRange->getValue().getValue();
230  ranges.push_back(inferredRange);
231  }
232  return success();
233 }
234 
235 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
236 /// return shaped type as well.
237 static Type getTargetType(Type srcType, unsigned targetBitwidth) {
238  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
239  if (auto shaped = dyn_cast<ShapedType>(srcType))
240  return shaped.clone(dstType);
241 
242  assert(srcType.isIntOrIndex() && "Invalid src type");
243  return dstType;
244 }
245 
namespace {
/// Which truncation to a narrower integer type preserves a value range, as
/// computed by checkTruncatability and combined by mergeCastKinds.
enum class CastKind : uint8_t {
  /// No lossless truncation exists.
  None,
  /// Truncation is lossless under the signed interpretation (the dropped
  /// bits all copy the sign bit).
  Signed,
  /// Truncation is lossless under the unsigned interpretation (the dropped
  /// bits are all zero).
  Unsigned,
  /// Both interpretations survive truncation.
  Both
};
} // namespace
251 
252 /// If the values within `range` can be represented using only `width` bits,
253 /// return the kind of truncation needed to preserve that property.
254 ///
255 /// This check relies on the fact that the signed and unsigned ranges are both
256 /// always correct, but that one might be an approximation of the other,
257 /// so we want to use the correct truncation operation.
258 static CastKind checkTruncatability(const ConstantIntRanges &range,
259  unsigned targetWidth) {
260  unsigned srcWidth = range.smin().getBitWidth();
261  if (srcWidth <= targetWidth)
262  return CastKind::None;
263  unsigned removedWidth = srcWidth - targetWidth;
264  // The sign bits need to extend into the sign bit of the target width. For
265  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
266  // bits.
267  bool canTruncateSigned =
268  range.smin().getNumSignBits() >= (removedWidth + 1) &&
269  range.smax().getNumSignBits() >= (removedWidth + 1);
270  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
271  range.umax().countLeadingZeros() >= removedWidth;
272  if (canTruncateSigned && canTruncateUnsigned)
273  return CastKind::Both;
274  if (canTruncateSigned)
275  return CastKind::Signed;
276  if (canTruncateUnsigned)
277  return CastKind::Unsigned;
278  return CastKind::None;
279 }
280 
281 static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
282  if (lhs == CastKind::None || rhs == CastKind::None)
283  return CastKind::None;
284  if (lhs == CastKind::Both)
285  return rhs;
286  if (rhs == CastKind::Both)
287  return lhs;
288  if (lhs == rhs)
289  return lhs;
290  return CastKind::None;
291 }
292 
293 static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
294  CastKind castKind) {
295  Type srcType = src.getType();
296  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
297  "Mixing vector and non-vector types");
298  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
299  Type srcElemType = getElementTypeOrSelf(srcType);
300  Type dstElemType = getElementTypeOrSelf(dstType);
301  assert(srcElemType.isIntOrIndex() && "Invalid src type");
302  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
303  if (srcType == dstType)
304  return src;
305 
306  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
307  if (castKind == CastKind::Signed)
308  return builder.create<arith::IndexCastOp>(loc, dstType, src);
309  return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
310  }
311 
312  auto srcInt = cast<IntegerType>(srcElemType);
313  auto dstInt = cast<IntegerType>(dstElemType);
314  if (dstInt.getWidth() < srcInt.getWidth())
315  return builder.create<arith::TruncIOp>(loc, dstType, src);
316 
317  if (castKind == CastKind::Signed)
318  return builder.create<arith::ExtSIOp>(loc, dstType, src);
319  return builder.create<arith::ExtUIOp>(loc, dstType, src);
320 }
321 
322 struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
323  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
324  ArrayRef<unsigned> target)
325  : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
326 
328  LogicalResult matchAndRewrite(Operation *op,
329  PatternRewriter &rewriter) const override {
330  if (op->getNumResults() == 0)
331  return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
332 
334  if (failed(collectRanges(solver, op->getOperands(), ranges)))
335  return rewriter.notifyMatchFailure(op, "input without specified range");
336  if (failed(collectRanges(solver, op->getResults(), ranges)))
337  return rewriter.notifyMatchFailure(op, "output without specified range");
338 
339  Type srcType = op->getResult(0).getType();
340  if (!llvm::all_equal(op->getResultTypes()))
341  return rewriter.notifyMatchFailure(op, "mismatched result types");
342  if (op->getNumOperands() == 0 ||
343  !llvm::all_of(op->getOperandTypes(),
344  [=](Type t) { return t == srcType; }))
345  return rewriter.notifyMatchFailure(
346  op, "no operands or operand types don't match result type");
347 
348  for (unsigned targetBitwidth : targetBitwidths) {
349  CastKind castKind = CastKind::Both;
350  for (const ConstantIntRanges &range : ranges) {
351  castKind = mergeCastKinds(castKind,
352  checkTruncatability(range, targetBitwidth));
353  if (castKind == CastKind::None)
354  break;
355  }
356  if (castKind == CastKind::None)
357  continue;
358  Type targetType = getTargetType(srcType, targetBitwidth);
359  if (targetType == srcType)
360  continue;
361 
362  Location loc = op->getLoc();
363  IRMapping mapping;
364  for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
365  CastKind argCastKind = castKind;
366  // When dealing with `index` values, preserve non-negativity in the
367  // index_casts since we can't recover this in unsigned when equivalent.
368  if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
369  argCastKind = CastKind::Both;
370  Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
371  mapping.map(arg, newArg);
372  }
373 
374  Operation *newOp = rewriter.clone(*op, mapping);
375  rewriter.modifyOpInPlace(newOp, [&]() {
376  for (OpResult res : newOp->getResults()) {
377  res.setType(targetType);
378  }
379  });
380  SmallVector<Value> newResults;
381  for (auto [newRes, oldRes] :
382  llvm::zip_equal(newOp->getResults(), op->getResults())) {
383  Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
384  copyIntegerRange(solver, oldRes, castBack);
385  newResults.push_back(castBack);
386  }
387 
388  rewriter.replaceOp(op, newResults);
389  return success();
390  }
391  return failure();
392  }
393 
394 private:
395  DataFlowSolver &solver;
396  SmallVector<unsigned, 4> targetBitwidths;
397 };
398 
399 struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
400  NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
401  : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
402 
403  LogicalResult matchAndRewrite(arith::CmpIOp op,
404  PatternRewriter &rewriter) const override {
405  Value lhs = op.getLhs();
406  Value rhs = op.getRhs();
407 
409  if (failed(collectRanges(solver, op.getOperands(), ranges)))
410  return failure();
411  const ConstantIntRanges &lhsRange = ranges[0];
412  const ConstantIntRanges &rhsRange = ranges[1];
413 
414  Type srcType = lhs.getType();
415  for (unsigned targetBitwidth : targetBitwidths) {
416  CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
417  CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
418  CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
419  // Note: this includes target width > src width.
420  if (castKind == CastKind::None)
421  continue;
422 
423  Type targetType = getTargetType(srcType, targetBitwidth);
424  if (targetType == srcType)
425  continue;
426 
427  Location loc = op->getLoc();
428  IRMapping mapping;
429  Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
430  Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
431  mapping.map(lhs, lhsCast);
432  mapping.map(rhs, rhsCast);
433 
434  Operation *newOp = rewriter.clone(*op, mapping);
435  copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
436  rewriter.replaceOp(op, newOp->getResults());
437  return success();
438  }
439  return failure();
440  }
441 
442 private:
443  DataFlowSolver &solver;
444  SmallVector<unsigned, 4> targetBitwidths;
445 };
446 
447 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
448 /// This pattern assumes all passed `targetBitwidths` are not wider than index
449 /// type.
450 template <typename CastOp>
451 struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
452  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
453  : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
454 
455  LogicalResult matchAndRewrite(CastOp op,
456  PatternRewriter &rewriter) const override {
457  auto srcOp = op.getIn().template getDefiningOp<CastOp>();
458  if (!srcOp)
459  return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
460 
461  Value src = srcOp.getIn();
462  if (src.getType() != op.getType())
463  return rewriter.notifyMatchFailure(op, "outer types don't match");
464 
465  if (!srcOp.getType().isIndex())
466  return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
467 
468  auto intType = dyn_cast<IntegerType>(op.getType());
469  if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
470  return failure();
471 
472  rewriter.replaceOp(op, src);
473  return success();
474  }
475 
476 private:
477  SmallVector<unsigned, 4> targetBitwidths;
478 };
479 
480 struct IntRangeOptimizationsPass final
481  : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
482 
483  void runOnOperation() override {
484  Operation *op = getOperation();
485  MLIRContext *ctx = op->getContext();
486  DataFlowSolver solver;
487  solver.load<DeadCodeAnalysis>();
488  solver.load<IntegerRangeAnalysis>();
489  if (failed(solver.initializeAndRun(op)))
490  return signalPassFailure();
491 
492  DataFlowListener listener(solver);
493 
496 
498  config.listener = &listener;
499 
500  if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
501  signalPassFailure();
502  }
503 };
504 
505 struct IntRangeNarrowingPass final
506  : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
507  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
508 
509  void runOnOperation() override {
510  Operation *op = getOperation();
511  MLIRContext *ctx = op->getContext();
512  DataFlowSolver solver;
513  solver.load<DeadCodeAnalysis>();
514  solver.load<IntegerRangeAnalysis>();
515  if (failed(solver.initializeAndRun(op)))
516  return signalPassFailure();
517 
518  DataFlowListener listener(solver);
519 
521  populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
522 
524  // We specifically need bottom-up traversal as cmpi pattern needs range
525  // data, attached to its original argument values.
526  config.useTopDownTraversal = false;
527  config.listener = &listener;
528 
529  if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
530  signalPassFailure();
531  }
532 };
533 } // namespace
534 
537  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538  DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
539 }
540 
543  ArrayRef<unsigned> bitwidthsSupported) {
544  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
545  bitwidthsSupported);
546  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547  FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
548  bitwidthsSupported);
549 }
550 
552  return std::make_unique<IntRangeOptimizationsPass>();
553 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
MLIRContext * getContext() const
Definition: Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This is a value defined by a result of an operation.
Definition: Value.h:433
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:346
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:348
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:362
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:606
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:194
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass that performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318