1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various operation folding utilities. These
10 // utilities are intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Dialect.h"
22 namespace mlir {
23 class Operation;
24 class Value;
26 //===--------------------------------------------------------------------===//
27 // OperationFolder
28 //===--------------------------------------------------------------------===//
30 /// A utility class for folding operations, and unifying duplicated constants
31 /// generated along the way.
33 public:
34  OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
36  /// Tries to perform folding on the given `op`, including unifying
37  /// deduplicated constants. If successful, replaces `op`'s uses with
38  /// folded results, and returns success. `preReplaceAction` is invoked on `op`
39  /// before it is replaced. 'processGeneratedConstants' is invoked for any new
40  /// operations generated when folding. If the op was completely folded it is
41  /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
43  tryToFold(Operation *op,
44  function_ref<void(Operation *)> processGeneratedConstants = nullptr,
45  function_ref<void(Operation *)> preReplaceAction = nullptr,
46  bool *inPlaceUpdate = nullptr);
48  /// Notifies that the given constant `op` should be remove from this
49  /// OperationFolder's internal bookkeeping.
50  ///
51  /// Note: this method must be called if a constant op is to be deleted
52  /// externally to this OperationFolder. `op` must be a constant op.
53  void notifyRemoval(Operation *op);
55  /// Create an operation of specific op type with the given builder,
56  /// and immediately try to fold it. This function populates 'results' with
57  /// the results after folding the operation.
58  template <typename OpTy, typename... Args>
59  void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
60  Location location, Args &&... args) {
61  // The op needs to be inserted only if the fold (below) fails, or the number
62  // of results produced by the successful folding is zero (which is treated
63  // as an in-place fold). Using create methods of the builder will insert the
64  // op, so not using it here.
65  OperationState state(location, OpTy::getOperationName());
66  OpTy::build(builder, state, std::forward<Args>(args)...);
67  Operation *op = Operation::create(state);
69  if (failed(tryToFold(builder, op, results)) || results.empty()) {
70  builder.insert(op);
71  results.assign(op->result_begin(), op->result_end());
72  return;
73  }
74  op->destroy();
75  }
77  /// Overload to create or fold a single result operation.
78  template <typename OpTy, typename... Args>
79  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
80  Value>::type
81  create(OpBuilder &builder, Location location, Args &&... args) {
82  SmallVector<Value, 1> results;
83  create<OpTy>(builder, results, location, std::forward<Args>(args)...);
84  return results.front();
85  }
87  /// Overload to create or fold a zero result operation.
88  template <typename OpTy, typename... Args>
89  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
90  OpTy>::type
91  create(OpBuilder &builder, Location location, Args &&... args) {
92  auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
93  SmallVector<Value, 0> unused;
94  (void)tryToFold(op.getOperation(), unused);
96  // Folding cannot remove a zero-result operation, so for convenience we
97  // continue to return it.
98  return op;
99  }
/// Drops every constant that this folder has cached.
void clear();
104  /// Get or create a constant using the given builder. On success this returns
105  /// the constant operation, nullptr otherwise.
106  Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
107  Attribute value, Type type, Location loc);
109 private:
110  /// This map keeps track of uniqued constants by dialect, attribute, and type.
111  /// A constant operation materializes an attribute with a type. Dialects may
112  /// generate different constants with the same input attribute and type, so we
113  /// also need to track per-dialect.
114  using ConstantMap =
117  /// Tries to perform folding on the given `op`. If successful, populates
118  /// `results` with the results of the folding.
120  OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
121  function_ref<void(Operation *)> processGeneratedConstants = nullptr);
123  /// Try to get or create a new constant entry. On success this returns the
124  /// constant operation, nullptr otherwise.
125  Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
126  Dialect *dialect, OpBuilder &builder,
127  Attribute value, Type type, Location loc);
129  /// A mapping between an insertion region and the constants that have been
130  /// created within it.
133  /// This map tracks all of the dialects that an operation is referenced by;
134  /// given that many dialects may generate the same constant.
137  /// A collection of dialect folder interfaces.
139 };
141 } // namespace mlir
