MLIR  22.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1 //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
14 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
26 
27 namespace mlir::arith {
28 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
29 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
30 
31 #define GEN_PASS_DEF_ARITHINTRANGENARROWING
32 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
33 } // namespace mlir::arith
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 using namespace mlir::dataflow;
38 
// Return the single constant that integer range analysis inferred for
// `value`, or std::nullopt when the solver has no (or an uninitialized) range
// for it, or when ConstantIntRanges::getConstantValue() reports no constant.
// NOTE(review): original line 42 (the initializer of `maybeInferredRange`) is
// missing from this listing; presumably a
// `solver.lookupState<IntegerValueRangeLattice>(value)` call -- confirm.
39 static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
40  Value value) {
41  auto *maybeInferredRange =
43  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
44  return std::nullopt;
45  const ConstantIntRanges &inferredRange =
46  maybeInferredRange->getValue().getValue();
47  return inferredRange.getConstantValue();
48 }
49 
50 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
51  Value newVal) {
52  assert(oldVal.getType() == newVal.getType() &&
53  "Can't copy integer ranges between different types");
54  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
55  if (!oldState)
56  return;
57  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
58  *oldState);
59 }
60 
61 namespace mlir::dataflow {
62 /// Patterned after SCCP
/// If integer range analysis proves `value` (which must have uses) is a single
/// constant, materialize that constant and replace all uses of `value` with
/// it. The constant is materialized by the dialect of the defining op (or, for
/// a block argument, of the op owning its region), falling back to arith.
/// Returns failure() when nothing was replaced.
/// NOTE(review): original line 63 (start of the signature) is missing from this
/// listing; per the page's symbol index the declaration is
/// `LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
/// RewriterBase &rewriter, Value value)`.
64  RewriterBase &rewriter, Value value) {
65  if (value.use_empty())
66  return failure();
67  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
68  if (!maybeConstValue.has_value())
69  return failure();
70 
71  Type type = value.getType();
72  Location loc = value.getLoc();
73  Operation *maybeDefiningOp = value.getDefiningOp();
// Block arguments have no defining op; use the dialect of the op that owns
// the region the argument lives in.
// NOTE(review): Operation::getDialect() returns nullptr when the dialect is not
// loaded (e.g. unregistered ops), and `valueDialect` is dereferenced below
// without a null check -- confirm that cannot happen, or guard it.
74  Dialect *valueDialect =
75  maybeDefiningOp ? maybeDefiningOp->getDialect()
76  : value.getParentRegion()->getParentOp()->getDialect();
77 
// Shaped values (vectors/tensors) get a splat dense attribute; scalars get an
// IntegerAttr.
78  Attribute constAttr;
79  if (auto shaped = dyn_cast<ShapedType>(type)) {
80  constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
81  } else {
82  constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
83  }
84  Operation *constOp =
85  valueDialect->materializeConstant(rewriter, constAttr, type, loc);
86  // Fall back to arith.constant if the dialect materializer doesn't know what
87  // to do with an integer constant.
88  if (!constOp)
89  constOp = rewriter.getContext()
90  ->getLoadedDialect<ArithDialect>()
91  ->materializeConstant(rewriter, constAttr, type, loc);
92  if (!constOp)
93  return failure();
94 
95  OpResult res = constOp->getResult(0);
// Drop any solver state already keyed on `res` before copying `value`'s range
// onto it (presumably stale state -- TODO confirm why it can exist).
// NOTE(review): original line 96 is missing from this listing; it may guard
// the eraseState call below -- confirm against the real file.
97  solver.eraseState(res);
98  copyIntegerRange(solver, value, res);
99  rewriter.replaceAllUsesWith(value, res);
100  return success();
101 }
102 } // namespace mlir::dataflow
103 
104 namespace {
105 class DataFlowListener : public RewriterBase::Listener {
106 public:
107  DataFlowListener(DataFlowSolver &s) : s(s) {}
108 
109 protected:
110  void notifyOperationErased(Operation *op) override {
111  s.eraseState(s.getProgramPointAfter(op));
112  for (Value res : op->getResults())
113  s.eraseState(res);
114  }
115 
116  DataFlowSolver &s;
117 };
118 
119 /// Rewrite any results of `op` that were inferred to be constant integers to
120 /// and replace their uses with that constant. Return success() if all results
121 /// where thus replaced and the operation is erased. Also replace any block
122 /// arguments with their constant values.
123 struct MaterializeKnownConstantValues : public RewritePattern {
124  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
125  : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
126  /*benefit=*/1, context),
127  solver(s) {}
128 
129  LogicalResult matchAndRewrite(Operation *op,
130  PatternRewriter &rewriter) const override {
131  if (matchPattern(op, m_Constant()))
132  return failure();
133 
134  auto needsReplacing = [&](Value v) {
135  return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
136  };
137  bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
138  if (op->getNumRegions() == 0)
139  if (!hasConstantResults)
140  return failure();
141  bool hasConstantRegionArgs = false;
142  for (Region &region : op->getRegions()) {
143  for (Block &block : region.getBlocks()) {
144  hasConstantRegionArgs |=
145  llvm::any_of(block.getArguments(), needsReplacing);
146  }
147  }
148  if (!hasConstantResults && !hasConstantRegionArgs)
149  return failure();
150 
151  bool replacedAll = (op->getNumResults() != 0);
152  for (Value v : op->getResults())
153  replacedAll &=
154  (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
155  v.use_empty());
156  if (replacedAll && isOpTriviallyDead(op)) {
157  rewriter.eraseOp(op);
158  return success();
159  }
160 
161  PatternRewriter::InsertionGuard guard(rewriter);
162  for (Region &region : op->getRegions()) {
163  for (Block &block : region.getBlocks()) {
164  rewriter.setInsertionPointToStart(&block);
165  for (BlockArgument &arg : block.getArguments()) {
166  (void)maybeReplaceWithConstant(solver, rewriter, arg);
167  }
168  }
169  }
170 
171  return success();
172  }
173 
174 private:
175  DataFlowSolver &solver;
176 };
177 
178 template <typename RemOp>
179 struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
180  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
181  : OpRewritePattern<RemOp>(context), solver(s) {}
182 
183  LogicalResult matchAndRewrite(RemOp op,
184  PatternRewriter &rewriter) const override {
185  Value lhs = op.getOperand(0);
186  Value rhs = op.getOperand(1);
187  auto maybeModulus = getConstantIntValue(rhs);
188  if (!maybeModulus.has_value())
189  return failure();
190  int64_t modulus = *maybeModulus;
191  if (modulus <= 0)
192  return failure();
193  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
194  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
195  return failure();
196  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
197  const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
198  const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
199  // The minima and maxima here are given as closed ranges, we must be
200  // strictly less than the modulus.
201  if (min.isNegative() || min.uge(modulus))
202  return failure();
203  if (max.isNegative() || max.uge(modulus))
204  return failure();
205  if (!min.ule(max))
206  return failure();
207 
208  // With all those conditions out of the way, we know thas this invocation of
209  // a remainder is a noop because the input is strictly within the range
210  // [0, modulus), so get rid of it.
211  rewriter.replaceOp(op, ValueRange{lhs});
212  return success();
213  }
214 
215 private:
216  DataFlowSolver &solver;
217 };
218 
219 /// Gather ranges for all the values in `values`. Appends to the existing
220 /// vector.
/// Returns failure() at the first value that has no (or an uninitialized)
/// integer range; ranges gathered before that point remain appended.
/// NOTE(review): original lines 222 (the `ranges` output parameter) and 225
/// (the lattice lookup for `val`) are missing from this listing.
221 static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
223  for (Value val : values) {
224  auto *maybeInferredRange =
226  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
227  return failure();
228 
229  const ConstantIntRanges &inferredRange =
230  maybeInferredRange->getValue().getValue();
231  ranges.push_back(inferredRange);
232  }
233  return success();
234 }
235 
236 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
237 /// return shaped type as well.
238 static Type getTargetType(Type srcType, unsigned targetBitwidth) {
239  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
240  if (auto shaped = dyn_cast<ShapedType>(srcType))
241  return shaped.clone(dstType);
242 
243  assert(srcType.isIntOrIndex() && "Invalid src type");
244  return dstType;
245 }
246 
namespace {
/// Which cast flavour preserves a value (or a set of values) when an operation
/// is narrowed, if any.
enum class CastKind : uint8_t {
  /// Narrowing is not possible.
  None,
  /// Values survive truncation followed by sign extension.
  Signed,
  /// Values survive truncation followed by zero extension.
  Unsigned,
  /// Either round trip preserves the values.
  Both
};
} // namespace
252 
253 /// If the values within `range` can be represented using only `width` bits,
254 /// return the kind of truncation needed to preserve that property.
255 ///
256 /// This check relies on the fact that the signed and unsigned ranges are both
257 /// always correct, but that one might be an approximation of the other,
258 /// so we want to use the correct truncation operation.
259 static CastKind checkTruncatability(const ConstantIntRanges &range,
260  unsigned targetWidth) {
261  unsigned srcWidth = range.smin().getBitWidth();
262  if (srcWidth <= targetWidth)
263  return CastKind::None;
264  unsigned removedWidth = srcWidth - targetWidth;
265  // The sign bits need to extend into the sign bit of the target width. For
266  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
267  // bits.
268  bool canTruncateSigned =
269  range.smin().getNumSignBits() >= (removedWidth + 1) &&
270  range.smax().getNumSignBits() >= (removedWidth + 1);
271  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
272  range.umax().countLeadingZeros() >= removedWidth;
273  if (canTruncateSigned && canTruncateUnsigned)
274  return CastKind::Both;
275  if (canTruncateSigned)
276  return CastKind::Signed;
277  if (canTruncateUnsigned)
278  return CastKind::Unsigned;
279  return CastKind::None;
280 }
281 
282 static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
283  if (lhs == CastKind::None || rhs == CastKind::None)
284  return CastKind::None;
285  if (lhs == CastKind::Both)
286  return rhs;
287  if (rhs == CastKind::Both)
288  return lhs;
289  if (lhs == rhs)
290  return lhs;
291  return CastKind::None;
292 }
293 
294 static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
295  CastKind castKind) {
296  Type srcType = src.getType();
297  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
298  "Mixing vector and non-vector types");
299  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
300  Type srcElemType = getElementTypeOrSelf(srcType);
301  Type dstElemType = getElementTypeOrSelf(dstType);
302  assert(srcElemType.isIntOrIndex() && "Invalid src type");
303  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
304  if (srcType == dstType)
305  return src;
306 
307  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
308  if (castKind == CastKind::Signed)
309  return arith::IndexCastOp::create(builder, loc, dstType, src);
310  return arith::IndexCastUIOp::create(builder, loc, dstType, src);
311  }
312 
313  auto srcInt = cast<IntegerType>(srcElemType);
314  auto dstInt = cast<IntegerType>(dstElemType);
315  if (dstInt.getWidth() < srcInt.getWidth())
316  return arith::TruncIOp::create(builder, loc, dstType, src);
317 
318  if (castKind == CastKind::Signed)
319  return arith::ExtSIOp::create(builder, loc, dstType, src);
320  return arith::ExtUIOp::create(builder, loc, dstType, src);
321 }
322 
/// Narrow an elementwise op whose operands and results all share one type.
/// If integer range analysis shows that every operand and result fits in one
/// of `targetBitwidths` (tried in order), the operands are cast down, a clone
/// of the op computes at the narrow width, and its results are cast back.
/// NOTE(review): the cast kind comes only from the value ranges, not from the
/// op's semantics. For ops that treat operands as signed (e.g. arith.divsi), an
/// Unsigned-only narrowing looks unsound. Take %a in [0, 255] and %b = 2:
/// divsi at i64 gives 100 for %a = 200, but the i8 clone computes
/// -56 / 2 = -28, which zero-extends back to 228. A similar concern may apply
/// to unsigned-semantics ops under a Signed-only narrowing -- confirm and
/// restrict.
323 struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
324  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
325  ArrayRef<unsigned> target)
326  : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
327 
// NOTE(review): original line 328 is missing from this listing.
329  LogicalResult matchAndRewrite(Operation *op,
330  PatternRewriter &rewriter) const override {
331  if (op->getNumResults() == 0)
332  return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
333 
// NOTE(review): original line 334, which presumably declares `ranges`, is
// missing from this listing.
// `ranges` ends up holding the operand ranges first, then the result ranges.
335  if (failed(collectRanges(solver, op->getOperands(), ranges)))
336  return rewriter.notifyMatchFailure(op, "input without specified range");
337  if (failed(collectRanges(solver, op->getResults(), ranges)))
338  return rewriter.notifyMatchFailure(op, "output without specified range");
339 
340  Type srcType = op->getResult(0).getType();
341  if (!llvm::all_equal(op->getResultTypes()))
342  return rewriter.notifyMatchFailure(op, "mismatched result types");
343  if (op->getNumOperands() == 0 ||
344  !llvm::all_of(op->getOperandTypes(),
345  [=](Type t) { return t == srcType; }))
346  return rewriter.notifyMatchFailure(
347  op, "no operands or operand types don't match result type");
348 
349  for (unsigned targetBitwidth : targetBitwidths) {
// Intersect the cast kinds allowed by every operand and result range.
350  CastKind castKind = CastKind::Both;
351  for (const ConstantIntRanges &range : ranges) {
352  castKind = mergeCastKinds(castKind,
353  checkTruncatability(range, targetBitwidth));
354  if (castKind == CastKind::None)
355  break;
356  }
357  if (castKind == CastKind::None)
358  continue;
359  Type targetType = getTargetType(srcType, targetBitwidth);
360  if (targetType == srcType)
361  continue;
362 
363  Location loc = op->getLoc();
364  IRMapping mapping;
// zip_first stops after the operands, which come first in `ranges`.
365  for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
366  CastKind argCastKind = castKind;
367  // When dealing with `index` values, preserve non-negativity in the
368  // index_casts since we can't recover this in unsigned when equivalent.
369  if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
370  argCastKind = CastKind::Both;
371  Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
372  mapping.map(arg, newArg);
373  }
374 
// Clone the op onto the narrowed operands, then retype its results in place.
375  Operation *newOp = rewriter.clone(*op, mapping);
376  rewriter.modifyOpInPlace(newOp, [&]() {
377  for (OpResult res : newOp->getResults()) {
378  res.setType(targetType);
379  }
380  });
// Cast each narrow result back to the original type and carry over the
// original result's range to the replacement value.
381  SmallVector<Value> newResults;
382  for (auto [newRes, oldRes] :
383  llvm::zip_equal(newOp->getResults(), op->getResults())) {
384  Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
385  copyIntegerRange(solver, oldRes, castBack);
386  newResults.push_back(castBack);
387  }
388 
389  rewriter.replaceOp(op, newResults);
390  return success();
391  }
392  return failure();
393  }
394 
395 private:
396  DataFlowSolver &solver;
397  SmallVector<unsigned, 4> targetBitwidths;
398 };
399 
400 struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
401  NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
402  : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
403 
404  LogicalResult matchAndRewrite(arith::CmpIOp op,
405  PatternRewriter &rewriter) const override {
406  Value lhs = op.getLhs();
407  Value rhs = op.getRhs();
408 
410  if (failed(collectRanges(solver, op.getOperands(), ranges)))
411  return failure();
412  const ConstantIntRanges &lhsRange = ranges[0];
413  const ConstantIntRanges &rhsRange = ranges[1];
414 
415  Type srcType = lhs.getType();
416  for (unsigned targetBitwidth : targetBitwidths) {
417  CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
418  CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
419  CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
420  // Note: this includes target width > src width.
421  if (castKind == CastKind::None)
422  continue;
423 
424  Type targetType = getTargetType(srcType, targetBitwidth);
425  if (targetType == srcType)
426  continue;
427 
428  Location loc = op->getLoc();
429  IRMapping mapping;
430  Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
431  Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
432  mapping.map(lhs, lhsCast);
433  mapping.map(rhs, rhsCast);
434 
435  Operation *newOp = rewriter.clone(*op, mapping);
436  copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
437  rewriter.replaceOp(op, newOp->getResults());
438  return success();
439  }
440  return failure();
441  }
442 
443 private:
444  DataFlowSolver &solver;
445  SmallVector<unsigned, 4> targetBitwidths;
446 };
447 
448 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
449 /// This pattern assumes all passed `targetBitwidths` are not wider than index
450 /// type.
451 template <typename CastOp>
452 struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
453  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
454  : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
455 
456  LogicalResult matchAndRewrite(CastOp op,
457  PatternRewriter &rewriter) const override {
458  auto srcOp = op.getIn().template getDefiningOp<CastOp>();
459  if (!srcOp)
460  return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
461 
462  Value src = srcOp.getIn();
463  if (src.getType() != op.getType())
464  return rewriter.notifyMatchFailure(op, "outer types don't match");
465 
466  if (!srcOp.getType().isIndex())
467  return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
468 
469  auto intType = dyn_cast<IntegerType>(op.getType());
470  if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
471  return failure();
472 
473  rewriter.replaceOp(op, src);
474  return success();
475  }
476 
477 private:
478  SmallVector<unsigned, 4> targetBitwidths;
479 };
480 
/// Pass driver for the range-based optimizations. It runs dead-code analysis
/// and integer range analysis on the root op. It then applies patterns
/// greedily, with a DataFlowListener that drops solver state for erased ops.
/// The pass fails if the analysis or the rewrite fails.
/// NOTE(review): original lines 489, 496, 497 and 499 are missing from this
/// listing. Line 489 presumably loads another analysis (this page's symbol
/// index describes a sparse constant propagation analysis). Lines 496-499
/// presumably build `patterns` (using `ctx`) and call applyPatternsGreedily.
/// Confirm against the real file.
481 struct IntRangeOptimizationsPass final
482  : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
483 
484  void runOnOperation() override {
485  Operation *op = getOperation();
486  MLIRContext *ctx = op->getContext();
487  DataFlowSolver solver;
488  solver.load<DeadCodeAnalysis>();
490  solver.load<IntegerRangeAnalysis>();
491  if (failed(solver.initializeAndRun(op)))
492  return signalPassFailure();
493 
// Keeps the solver consistent while patterns erase operations.
494  DataFlowListener listener(solver);
495 
498 
500  op, std::move(patterns),
501  GreedyRewriteConfig().setListener(&listener))))
502  signalPassFailure();
503  }
504 };
505 
/// Pass driver for range-based narrowing. It runs dead-code analysis and
/// integer range analysis on the root op. It then greedily applies the
/// narrowing patterns for `bitwidthsSupported` bottom-up (top-down traversal
/// disabled), with a DataFlowListener that drops solver state for erased ops.
/// NOTE(review): `bitwidthsSupported` is presumably a pass option declared in
/// the generated base class. Original lines 521 and 526 (presumably the
/// `patterns` declaration and the applyPatternsGreedily call) are missing from
/// this listing.
506 struct IntRangeNarrowingPass final
507  : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
508  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
509 
510  void runOnOperation() override {
511  Operation *op = getOperation();
512  MLIRContext *ctx = op->getContext();
513  DataFlowSolver solver;
514  solver.load<DeadCodeAnalysis>();
515  solver.load<IntegerRangeAnalysis>();
516  if (failed(solver.initializeAndRun(op)))
517  return signalPassFailure();
518 
// Keeps the solver consistent while patterns erase operations.
519  DataFlowListener listener(solver);
520 
522  populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
523 
524  // We specifically need bottom-up traversal as cmpi pattern needs range
525  // data, attached to its original argument values.
527  op, std::move(patterns),
528  GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
529  &listener))))
530  signalPassFailure();
531  }
532 };
533 } // namespace
534 
// NOTE(review): the signature (original lines 535-536) is missing from this
// listing. Per the page's symbol index this is
// populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
// DataFlowSolver &solver).
// Registers constant materialization and trivial-remainder deletion for both
// signed and unsigned remainders; all patterns share `solver`.
537  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
538  DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
539 }
540 
543  ArrayRef<unsigned> bitwidthsSupported) {
// NOTE(review): the start of the signature (original lines 541-542) is missing
// from this listing. Per the page's symbol index this is
// populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
// DataFlowSolver &solver, ArrayRef<unsigned> bitwidthsSupported).
// Range-driven narrowing of elementwise ops and cmpi, for the given widths.
544  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
545  bitwidthsSupported);
// Cleanup of index_cast round trips; these patterns need no solver.
546  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
547  FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
548  bitwidthsSupported);
549 }
550 
// NOTE(review): the signature (original line 551) is missing from this
// listing. Per the page's symbol index this is
// std::unique_ptr<Pass> createIntRangeOptimizationsPass().
// Returns a fresh instance of the range-based optimization pass.
552  return std::make_unique<IntRangeOptimizationsPass>();
553 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal)
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
MLIRContext * getContext() const
Definition: Builders.h:56
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
Definition: Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
This is a value defined by a result of an operation.
Definition: Value.h:447
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:348
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:354
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass which do optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef< unsigned > bitwidthsSupported)
Add patterns for int range based narrowing.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value)
Patterned after SCCP.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314