MLIR  21.0.0git
CommonFolders.h
Go to the documentation of this file.
1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include <optional>
23 
24 namespace mlir {
25 namespace ub {
26 class PoisonAttr;
27 }
28 /// Performs constant folding `calculate` with element-wise behavior on the two
29 /// attributes in `operands` and returns the result if possible.
30 /// Uses `resultType` for the type of the returned attribute.
31 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
32 /// which will be directly propagated to result.
33 template <class AttrElementT,
34  class ElementValueT = typename AttrElementT::ValueType,
35  class PoisonAttr = ub::PoisonAttr,
36  class CalculationT = function_ref<
37  std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
39  Type resultType,
40  CalculationT &&calculate) {
41  assert(operands.size() == 2 && "binary op takes two operands");
42  static_assert(
43  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45  "void as template argument to opt-out from poison semantics.");
46  if constexpr (!std::is_void_v<PoisonAttr>) {
47  if (isa_and_nonnull<PoisonAttr>(operands[0]))
48  return operands[0];
49 
50  if (isa_and_nonnull<PoisonAttr>(operands[1]))
51  return operands[1];
52  }
53 
54  if (!resultType || !operands[0] || !operands[1])
55  return {};
56 
57  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58  auto lhs = cast<AttrElementT>(operands[0]);
59  auto rhs = cast<AttrElementT>(operands[1]);
60  if (lhs.getType() != rhs.getType())
61  return {};
62 
63  auto calRes = calculate(lhs.getValue(), rhs.getValue());
64 
65  if (!calRes)
66  return {};
67 
68  return AttrElementT::get(resultType, *calRes);
69  }
70 
71  if (isa<SplatElementsAttr>(operands[0]) &&
72  isa<SplatElementsAttr>(operands[1])) {
73  // Both operands are splats so we can avoid expanding the values out and
74  // just fold based on the splat value.
75  auto lhs = cast<SplatElementsAttr>(operands[0]);
76  auto rhs = cast<SplatElementsAttr>(operands[1]);
77  if (lhs.getType() != rhs.getType())
78  return {};
79 
80  auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81  rhs.getSplatValue<ElementValueT>());
82  if (!elementResult)
83  return {};
84 
85  return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
86  }
87 
88  if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
89  // Operands are ElementsAttr-derived; perform an element-wise fold by
90  // expanding the values.
91  auto lhs = cast<ElementsAttr>(operands[0]);
92  auto rhs = cast<ElementsAttr>(operands[1]);
93  if (lhs.getType() != rhs.getType())
94  return {};
95 
96  auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97  auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98  if (!maybeLhsIt || !maybeRhsIt)
99  return {};
100  auto lhsIt = *maybeLhsIt;
101  auto rhsIt = *maybeRhsIt;
102  SmallVector<ElementValueT, 4> elementResults;
103  elementResults.reserve(lhs.getNumElements());
104  for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105  auto elementResult = calculate(*lhsIt, *rhsIt);
106  if (!elementResult)
107  return {};
108  elementResults.push_back(*elementResult);
109  }
110 
111  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
112  }
113  return {};
114 }
115 
116 /// Performs constant folding `calculate` with element-wise behavior on the two
117 /// attributes in `operands` and returns the result if possible.
118 /// Uses the operand element type for the element type of the returned
119 /// attribute.
120 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
121 /// which will be directly propagated to result.
122 template <class AttrElementT,
123  class ElementValueT = typename AttrElementT::ValueType,
124  class PoisonAttr = ub::PoisonAttr,
125  class CalculationT = function_ref<
126  std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
128  CalculationT &&calculate) {
129  assert(operands.size() == 2 && "binary op takes two operands");
130  static_assert(
131  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133  "void as template argument to opt-out from poison semantics.");
134  if constexpr (!std::is_void_v<PoisonAttr>) {
135  if (isa_and_nonnull<PoisonAttr>(operands[0]))
136  return operands[0];
137 
138  if (isa_and_nonnull<PoisonAttr>(operands[1]))
139  return operands[1];
140  }
141 
142  auto getResultType = [](Attribute attr) -> Type {
143  if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144  return typed.getType();
145  return {};
146  };
147 
148  Type lhsType = getResultType(operands[0]);
149  Type rhsType = getResultType(operands[1]);
150  if (!lhsType || !rhsType)
151  return {};
152  if (lhsType != rhsType)
153  return {};
154 
155  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
156  CalculationT>(
157  operands, lhsType, std::forward<CalculationT>(calculate));
158 }
159 
160 template <class AttrElementT,
161  class ElementValueT = typename AttrElementT::ValueType,
162  class PoisonAttr = void,
163  class CalculationT =
164  function_ref<ElementValueT(ElementValueT, ElementValueT)>>
166  CalculationT &&calculate) {
167  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168  operands, resultType,
169  [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170  return calculate(a, b);
171  });
172 }
173 
174 template <class AttrElementT,
175  class ElementValueT = typename AttrElementT::ValueType,
176  class PoisonAttr = ub::PoisonAttr,
177  class CalculationT =
178  function_ref<ElementValueT(ElementValueT, ElementValueT)>>
180  CalculationT &&calculate) {
181  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
182  operands,
183  [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184  return calculate(a, b);
185  });
186 }
187 
188 /// Performs constant folding `calculate` with element-wise behavior on the one
189 /// attributes in `operands` and returns the result if possible.
190 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
191 /// which will be directly propagated to result.
192 template <class AttrElementT,
193  class ElementValueT = typename AttrElementT::ValueType,
194  class PoisonAttr = ub::PoisonAttr,
195  class CalculationT =
196  function_ref<std::optional<ElementValueT>(ElementValueT)>>
198  CalculationT &&calculate) {
199  if (!llvm::getSingleElement(operands))
200  return {};
201 
202  static_assert(
203  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
204  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
205  "void as template argument to opt-out from poison semantics.");
206  if constexpr (!std::is_void_v<PoisonAttr>) {
207  if (isa<PoisonAttr>(operands[0]))
208  return operands[0];
209  }
210 
211  if (isa<AttrElementT>(operands[0])) {
212  auto op = cast<AttrElementT>(operands[0]);
213 
214  auto res = calculate(op.getValue());
215  if (!res)
216  return {};
217  return AttrElementT::get(op.getType(), *res);
218  }
219  if (isa<SplatElementsAttr>(operands[0])) {
220  // Both operands are splats so we can avoid expanding the values out and
221  // just fold based on the splat value.
222  auto op = cast<SplatElementsAttr>(operands[0]);
223 
224  auto elementResult = calculate(op.getSplatValue<ElementValueT>());
225  if (!elementResult)
226  return {};
227  return DenseElementsAttr::get(op.getType(), *elementResult);
228  } else if (isa<ElementsAttr>(operands[0])) {
229  // Operands are ElementsAttr-derived; perform an element-wise fold by
230  // expanding the values.
231  auto op = cast<ElementsAttr>(operands[0]);
232 
233  auto maybeOpIt = op.try_value_begin<ElementValueT>();
234  if (!maybeOpIt)
235  return {};
236  auto opIt = *maybeOpIt;
237  SmallVector<ElementValueT> elementResults;
238  elementResults.reserve(op.getNumElements());
239  for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
240  auto elementResult = calculate(*opIt);
241  if (!elementResult)
242  return {};
243  elementResults.push_back(*elementResult);
244  }
245  return DenseElementsAttr::get(op.getShapedType(), elementResults);
246  }
247  return {};
248 }
249 
250 template <class AttrElementT,
251  class ElementValueT = typename AttrElementT::ValueType,
252  class PoisonAttr = ub::PoisonAttr,
253  class CalculationT = function_ref<ElementValueT(ElementValueT)>>
255  CalculationT &&calculate) {
256  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
257  operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
258  return calculate(a);
259  });
260 }
261 
262 template <
263  class AttrElementT, class TargetAttrElementT,
264  class ElementValueT = typename AttrElementT::ValueType,
265  class TargetElementValueT = typename TargetAttrElementT::ValueType,
266  class PoisonAttr = ub::PoisonAttr,
267  class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
269  CalculationT &&calculate) {
270  if (!llvm::getSingleElement(operands))
271  return {};
272 
273  static_assert(
274  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
275  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
276  "void as template argument to opt-out from poison semantics.");
277  if constexpr (!std::is_void_v<PoisonAttr>) {
278  if (isa<PoisonAttr>(operands[0]))
279  return operands[0];
280  }
281 
282  if (isa<AttrElementT>(operands[0])) {
283  auto op = cast<AttrElementT>(operands[0]);
284  bool castStatus = true;
285  auto res = calculate(op.getValue(), castStatus);
286  if (!castStatus)
287  return {};
288  return TargetAttrElementT::get(resType, res);
289  }
290  if (isa<SplatElementsAttr>(operands[0])) {
291  // The operand is a splat so we can avoid expanding the values out and
292  // just fold based on the splat value.
293  auto op = cast<SplatElementsAttr>(operands[0]);
294  bool castStatus = true;
295  auto elementResult =
296  calculate(op.getSplatValue<ElementValueT>(), castStatus);
297  if (!castStatus)
298  return {};
299  auto shapedResType = cast<ShapedType>(resType);
300  if (!shapedResType.hasStaticShape())
301  return {};
302  return DenseElementsAttr::get(shapedResType, elementResult);
303  }
304  if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
305  // Operand is ElementsAttr-derived; perform an element-wise fold by
306  // expanding the value.
307  bool castStatus = true;
308  auto maybeOpIt = op.try_value_begin<ElementValueT>();
309  if (!maybeOpIt)
310  return {};
311  auto opIt = *maybeOpIt;
312  SmallVector<TargetElementValueT> elementResults;
313  elementResults.reserve(op.getNumElements());
314  for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
315  auto elt = calculate(*opIt, castStatus);
316  if (!castStatus)
317  return {};
318  elementResults.push_back(elt);
319  }
320 
321  return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
322  }
323  return {};
324 }
325 } // namespace mlir
326 
327 #endif // MLIR_DIALECT_COMMONFOLDERS_H
Attributes are known-constant values of operations.
Definition: Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, CalculationT &&calculate)
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Definition: CommonFolders.h:38
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)