MLIR  20.0.0git
CommonFolders.h
Go to the documentation of this file.
1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include <optional>
23 
24 namespace mlir {
namespace ub {
// Only a forward declaration: the folders below name `ub::PoisonAttr` as a
// default template argument, and each of them static_asserts that the type
// is complete (i.e. the UB dialect definition is available) before use.
class PoisonAttr;
} // namespace ub
28 /// Performs constant folding `calculate` with element-wise behavior on the two
29 /// attributes in `operands` and returns the result if possible.
30 /// Uses `resultType` for the type of the returned attribute.
31 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
32 /// which will be directly propagated to result.
33 template <class AttrElementT,
34  class ElementValueT = typename AttrElementT::ValueType,
35  class PoisonAttr = ub::PoisonAttr,
36  class CalculationT = function_ref<
37  std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
39  Type resultType,
40  CalculationT &&calculate) {
41  assert(operands.size() == 2 && "binary op takes two operands");
42  static_assert(
43  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45  "void as template argument to opt-out from poison semantics.");
46  if constexpr (!std::is_void_v<PoisonAttr>) {
47  if (isa_and_nonnull<PoisonAttr>(operands[0]))
48  return operands[0];
49 
50  if (isa_and_nonnull<PoisonAttr>(operands[1]))
51  return operands[1];
52  }
53 
54  if (!resultType || !operands[0] || !operands[1])
55  return {};
56 
57  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58  auto lhs = cast<AttrElementT>(operands[0]);
59  auto rhs = cast<AttrElementT>(operands[1]);
60  if (lhs.getType() != rhs.getType())
61  return {};
62 
63  auto calRes = calculate(lhs.getValue(), rhs.getValue());
64 
65  if (!calRes)
66  return {};
67 
68  return AttrElementT::get(resultType, *calRes);
69  }
70 
71  if (isa<SplatElementsAttr>(operands[0]) &&
72  isa<SplatElementsAttr>(operands[1])) {
73  // Both operands are splats so we can avoid expanding the values out and
74  // just fold based on the splat value.
75  auto lhs = cast<SplatElementsAttr>(operands[0]);
76  auto rhs = cast<SplatElementsAttr>(operands[1]);
77  if (lhs.getType() != rhs.getType())
78  return {};
79 
80  auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81  rhs.getSplatValue<ElementValueT>());
82  if (!elementResult)
83  return {};
84 
85  return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
86  }
87 
88  if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
89  // Operands are ElementsAttr-derived; perform an element-wise fold by
90  // expanding the values.
91  auto lhs = cast<ElementsAttr>(operands[0]);
92  auto rhs = cast<ElementsAttr>(operands[1]);
93  if (lhs.getType() != rhs.getType())
94  return {};
95 
96  auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97  auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98  if (!maybeLhsIt || !maybeRhsIt)
99  return {};
100  auto lhsIt = *maybeLhsIt;
101  auto rhsIt = *maybeRhsIt;
102  SmallVector<ElementValueT, 4> elementResults;
103  elementResults.reserve(lhs.getNumElements());
104  for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105  auto elementResult = calculate(*lhsIt, *rhsIt);
106  if (!elementResult)
107  return {};
108  elementResults.push_back(*elementResult);
109  }
110 
111  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
112  }
113  return {};
114 }
115 
116 /// Performs constant folding `calculate` with element-wise behavior on the two
117 /// attributes in `operands` and returns the result if possible.
118 /// Uses the operand element type for the element type of the returned
119 /// attribute.
120 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
121 /// which will be directly propagated to result.
122 template <class AttrElementT,
123  class ElementValueT = typename AttrElementT::ValueType,
124  class PoisonAttr = ub::PoisonAttr,
125  class CalculationT = function_ref<
126  std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
128  CalculationT &&calculate) {
129  assert(operands.size() == 2 && "binary op takes two operands");
130  static_assert(
131  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133  "void as template argument to opt-out from poison semantics.");
134  if constexpr (!std::is_void_v<PoisonAttr>) {
135  if (isa_and_nonnull<PoisonAttr>(operands[0]))
136  return operands[0];
137 
138  if (isa_and_nonnull<PoisonAttr>(operands[1]))
139  return operands[1];
140  }
141 
142  auto getResultType = [](Attribute attr) -> Type {
143  if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144  return typed.getType();
145  return {};
146  };
147 
148  Type lhsType = getResultType(operands[0]);
149  Type rhsType = getResultType(operands[1]);
150  if (!lhsType || !rhsType)
151  return {};
152  if (lhsType != rhsType)
153  return {};
154 
155  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
156  CalculationT>(
157  operands, lhsType, std::forward<CalculationT>(calculate));
158 }
159 
160 template <class AttrElementT,
161  class ElementValueT = typename AttrElementT::ValueType,
162  class PoisonAttr = void,
163  class CalculationT =
164  function_ref<ElementValueT(ElementValueT, ElementValueT)>>
166  CalculationT &&calculate) {
167  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168  operands, resultType,
169  [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170  return calculate(a, b);
171  });
172 }
173 
174 template <class AttrElementT,
175  class ElementValueT = typename AttrElementT::ValueType,
176  class PoisonAttr = ub::PoisonAttr,
177  class CalculationT =
178  function_ref<ElementValueT(ElementValueT, ElementValueT)>>
180  CalculationT &&calculate) {
181  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
182  operands,
183  [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184  return calculate(a, b);
185  });
186 }
187 
188 /// Performs constant folding `calculate` with element-wise behavior on the one
189 /// attributes in `operands` and returns the result if possible.
190 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
191 /// which will be directly propagated to result.
192 template <class AttrElementT,
193  class ElementValueT = typename AttrElementT::ValueType,
194  class PoisonAttr = ub::PoisonAttr,
195  class CalculationT =
196  function_ref<std::optional<ElementValueT>(ElementValueT)>>
198  CalculationT &&calculate) {
199  assert(operands.size() == 1 && "unary op takes one operands");
200  if (!operands[0])
201  return {};
202 
203  static_assert(
204  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
205  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
206  "void as template argument to opt-out from poison semantics.");
207  if constexpr (!std::is_void_v<PoisonAttr>) {
208  if (isa<PoisonAttr>(operands[0]))
209  return operands[0];
210  }
211 
212  if (isa<AttrElementT>(operands[0])) {
213  auto op = cast<AttrElementT>(operands[0]);
214 
215  auto res = calculate(op.getValue());
216  if (!res)
217  return {};
218  return AttrElementT::get(op.getType(), *res);
219  }
220  if (isa<SplatElementsAttr>(operands[0])) {
221  // Both operands are splats so we can avoid expanding the values out and
222  // just fold based on the splat value.
223  auto op = cast<SplatElementsAttr>(operands[0]);
224 
225  auto elementResult = calculate(op.getSplatValue<ElementValueT>());
226  if (!elementResult)
227  return {};
228  return DenseElementsAttr::get(op.getType(), *elementResult);
229  } else if (isa<ElementsAttr>(operands[0])) {
230  // Operands are ElementsAttr-derived; perform an element-wise fold by
231  // expanding the values.
232  auto op = cast<ElementsAttr>(operands[0]);
233 
234  auto maybeOpIt = op.try_value_begin<ElementValueT>();
235  if (!maybeOpIt)
236  return {};
237  auto opIt = *maybeOpIt;
238  SmallVector<ElementValueT> elementResults;
239  elementResults.reserve(op.getNumElements());
240  for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
241  auto elementResult = calculate(*opIt);
242  if (!elementResult)
243  return {};
244  elementResults.push_back(*elementResult);
245  }
246  return DenseElementsAttr::get(op.getShapedType(), elementResults);
247  }
248  return {};
249 }
250 
251 template <class AttrElementT,
252  class ElementValueT = typename AttrElementT::ValueType,
253  class PoisonAttr = ub::PoisonAttr,
254  class CalculationT = function_ref<ElementValueT(ElementValueT)>>
256  CalculationT &&calculate) {
257  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
258  operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
259  return calculate(a);
260  });
261 }
262 
263 template <
264  class AttrElementT, class TargetAttrElementT,
265  class ElementValueT = typename AttrElementT::ValueType,
266  class TargetElementValueT = typename TargetAttrElementT::ValueType,
267  class PoisonAttr = ub::PoisonAttr,
268  class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
270  CalculationT &&calculate) {
271  assert(operands.size() == 1 && "Cast op takes one operand");
272  if (!operands[0])
273  return {};
274 
275  static_assert(
276  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
277  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
278  "void as template argument to opt-out from poison semantics.");
279  if constexpr (!std::is_void_v<PoisonAttr>) {
280  if (isa<PoisonAttr>(operands[0]))
281  return operands[0];
282  }
283 
284  if (isa<AttrElementT>(operands[0])) {
285  auto op = cast<AttrElementT>(operands[0]);
286  bool castStatus = true;
287  auto res = calculate(op.getValue(), castStatus);
288  if (!castStatus)
289  return {};
290  return TargetAttrElementT::get(resType, res);
291  }
292  if (isa<SplatElementsAttr>(operands[0])) {
293  // The operand is a splat so we can avoid expanding the values out and
294  // just fold based on the splat value.
295  auto op = cast<SplatElementsAttr>(operands[0]);
296  bool castStatus = true;
297  auto elementResult =
298  calculate(op.getSplatValue<ElementValueT>(), castStatus);
299  if (!castStatus)
300  return {};
301  auto shapedResType = cast<ShapedType>(resType);
302  if (!shapedResType.hasStaticShape())
303  return {};
304  return DenseElementsAttr::get(shapedResType, elementResult);
305  }
306  if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
307  // Operand is ElementsAttr-derived; perform an element-wise fold by
308  // expanding the value.
309  bool castStatus = true;
310  auto maybeOpIt = op.try_value_begin<ElementValueT>();
311  if (!maybeOpIt)
312  return {};
313  auto opIt = *maybeOpIt;
314  SmallVector<TargetElementValueT> elementResults;
315  elementResults.reserve(op.getNumElements());
316  for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
317  auto elt = calculate(*opIt, castStatus);
318  if (!castStatus)
319  return {};
320  elementResults.push_back(elt);
321  }
322 
323  return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
324  }
325  return {};
326 }
327 } // namespace mlir
328 
329 #endif // MLIR_DIALECT_COMMONFOLDERS_H
Attributes are known-constant values of operations.
Definition: Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, CalculationT &&calculate)
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, CalculationT &&calculate)
Performs constant folding `calculate` with element-wise behavior on the single attribute in `operands` and returns the result if possible.
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding `calculate` with element-wise behavior on the two attributes in `operands` and returns the result if possible.
Definition: CommonFolders.h:38
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)