MLIR  22.0.0git
CommonFolders.h
Go to the documentation of this file.
1 //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various common operation folders. These folders
10 // are intended to be used by dialects to support common folding behavior
11 // without requiring each dialect to provide its own implementation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
17 
18 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/Types.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/STLExtras.h"
25 
26 #include <cassert>
27 #include <cstddef>
28 #include <optional>
29 
30 namespace mlir {
31 namespace ub {
32 class PoisonAttr;
33 }
34 /// Performs constant folding `calculate` with element-wise behavior on the two
35 /// attributes in `operands` and returns the result if possible.
36 /// Uses `resultType` for the type of the returned attribute.
37 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
38 /// which will be directly propagated to result.
39 template <class AttrElementT, //
40  class ElementValueT = typename AttrElementT::ValueType,
41  class PoisonAttr = ub::PoisonAttr,
42  class ResultAttrElementT = AttrElementT,
43  class ResultElementValueT = typename ResultAttrElementT::ValueType,
44  class CalculationT = function_ref<
45  std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
47  Type resultType,
48  CalculationT &&calculate) {
49  assert(operands.size() == 2 && "binary op takes two operands");
50  static_assert(
51  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
52  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
53  "void as template argument to opt-out from poison semantics.");
54  if constexpr (!std::is_void_v<PoisonAttr>) {
55  if (isa_and_nonnull<PoisonAttr>(operands[0]))
56  return operands[0];
57 
58  if (isa_and_nonnull<PoisonAttr>(operands[1]))
59  return operands[1];
60  }
61 
62  if (!resultType || !operands[0] || !operands[1])
63  return {};
64 
65  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
66  auto lhs = cast<AttrElementT>(operands[0]);
67  auto rhs = cast<AttrElementT>(operands[1]);
68  if (lhs.getType() != rhs.getType())
69  return {};
70 
71  auto calRes = calculate(lhs.getValue(), rhs.getValue());
72 
73  if (!calRes)
74  return {};
75 
76  return ResultAttrElementT::get(resultType, *calRes);
77  }
78 
79  if (isa<SplatElementsAttr>(operands[0]) &&
80  isa<SplatElementsAttr>(operands[1])) {
81  // Both operands are splats so we can avoid expanding the values out and
82  // just fold based on the splat value.
83  auto lhs = cast<SplatElementsAttr>(operands[0]);
84  auto rhs = cast<SplatElementsAttr>(operands[1]);
85  if (lhs.getType() != rhs.getType())
86  return {};
87 
88  auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
89  rhs.getSplatValue<ElementValueT>());
90  if (!elementResult)
91  return {};
92 
93  return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
94  }
95 
96  if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
97  // Operands are ElementsAttr-derived; perform an element-wise fold by
98  // expanding the values.
99  auto lhs = cast<ElementsAttr>(operands[0]);
100  auto rhs = cast<ElementsAttr>(operands[1]);
101  if (lhs.getType() != rhs.getType())
102  return {};
103 
104  auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
105  auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
106  if (!maybeLhsIt || !maybeRhsIt)
107  return {};
108  auto lhsIt = *maybeLhsIt;
109  auto rhsIt = *maybeRhsIt;
111  elementResults.reserve(lhs.getNumElements());
112  for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
113  auto elementResult = calculate(*lhsIt, *rhsIt);
114  if (!elementResult)
115  return {};
116  elementResults.push_back(*elementResult);
117  }
118 
119  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
120  }
121  return {};
122 }
123 
124 /// Performs constant folding `calculate` with element-wise behavior on the two
125 /// attributes in `operands` and returns the result if possible.
126 /// Uses the operand element type for the element type of the returned
127 /// attribute.
128 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
129 /// which will be directly propagated to result.
130 template <class AttrElementT, //
131  class ElementValueT = typename AttrElementT::ValueType,
132  class PoisonAttr = ub::PoisonAttr,
133  class ResultAttrElementT = AttrElementT,
134  class ResultElementValueT = typename ResultAttrElementT::ValueType,
135  class CalculationT = function_ref<
136  std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
138  CalculationT &&calculate) {
139  assert(operands.size() == 2 && "binary op takes two operands");
140  static_assert(
141  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
142  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
143  "void as template argument to opt-out from poison semantics.");
144  if constexpr (!std::is_void_v<PoisonAttr>) {
145  if (isa_and_nonnull<PoisonAttr>(operands[0]))
146  return operands[0];
147 
148  if (isa_and_nonnull<PoisonAttr>(operands[1]))
149  return operands[1];
150  }
151 
152  auto getAttrType = [](Attribute attr) -> Type {
153  if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
154  return typed.getType();
155  return {};
156  };
157 
158  Type lhsType = getAttrType(operands[0]);
159  Type rhsType = getAttrType(operands[1]);
160  if (!lhsType || !rhsType)
161  return {};
162  if (lhsType != rhsType)
163  return {};
164 
165  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166  ResultAttrElementT, ResultElementValueT,
167  CalculationT>(
168  operands, lhsType, std::forward<CalculationT>(calculate));
169 }
170 
171 template <class AttrElementT,
172  class ElementValueT = typename AttrElementT::ValueType,
173  class PoisonAttr = void, //
174  class ResultAttrElementT = AttrElementT,
175  class ResultElementValueT = typename ResultAttrElementT::ValueType,
176  class CalculationT =
177  function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
179  CalculationT &&calculate) {
180  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181  ResultAttrElementT>(
182  operands, resultType,
183  [&](ElementValueT a, ElementValueT b)
184  -> std::optional<ResultElementValueT> { return calculate(a, b); });
185 }
186 
187 template <class AttrElementT, //
188  class ElementValueT = typename AttrElementT::ValueType,
189  class PoisonAttr = ub::PoisonAttr,
190  class ResultAttrElementT = AttrElementT,
191  class ResultElementValueT = typename ResultAttrElementT::ValueType,
192  class CalculationT =
193  function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
195  CalculationT &&calculate) {
196  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197  ResultAttrElementT>(
198  operands,
199  [&](ElementValueT a, ElementValueT b)
200  -> std::optional<ResultElementValueT> { return calculate(a, b); });
201 }
202 
203 /// Performs constant folding `calculate` with element-wise behavior on the one
204 /// attributes in `operands` and returns the result if possible.
205 /// Uses `resultType` for the type of the returned attribute.
206 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
207 /// which will be directly propagated to result.
208 template <class AttrElementT, //
209  class ElementValueT = typename AttrElementT::ValueType,
210  class PoisonAttr = ub::PoisonAttr,
211  class ResultAttrElementT = AttrElementT,
212  class ResultElementValueT = typename ResultAttrElementT::ValueType,
213  class CalculationT =
214  function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
216  Type resultType,
217  CalculationT &&calculate) {
218  if (!resultType || !llvm::getSingleElement(operands))
219  return {};
220 
221  static_assert(
222  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
223  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
224  "void as template argument to opt-out from poison semantics.");
225  if constexpr (!std::is_void_v<PoisonAttr>) {
226  if (isa<PoisonAttr>(operands[0]))
227  return operands[0];
228  }
229 
230  if (isa<AttrElementT>(operands[0])) {
231  auto op = cast<AttrElementT>(operands[0]);
232 
233  auto res = calculate(op.getValue());
234  if (!res)
235  return {};
236  return ResultAttrElementT::get(resultType, *res);
237  }
238  if (isa<SplatElementsAttr>(operands[0])) {
239  // Both operands are splats so we can avoid expanding the values out and
240  // just fold based on the splat value.
241  auto op = cast<SplatElementsAttr>(operands[0]);
242 
243  auto elementResult = calculate(op.getSplatValue<ElementValueT>());
244  if (!elementResult)
245  return {};
246  return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
247  } else if (isa<ElementsAttr>(operands[0])) {
248  // Operands are ElementsAttr-derived; perform an element-wise fold by
249  // expanding the values.
250  auto op = cast<ElementsAttr>(operands[0]);
251 
252  auto maybeOpIt = op.try_value_begin<ElementValueT>();
253  if (!maybeOpIt)
254  return {};
255  auto opIt = *maybeOpIt;
256  SmallVector<ResultElementValueT> elementResults;
257  elementResults.reserve(op.getNumElements());
258  for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
259  auto elementResult = calculate(*opIt);
260  if (!elementResult)
261  return {};
262  elementResults.push_back(*elementResult);
263  }
264  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
265  }
266  return {};
267 }
268 
269 /// Performs constant folding `calculate` with element-wise behavior on the one
270 /// attributes in `operands` and returns the result if possible.
271 /// Uses the operand element type for the element type of the returned
272 /// attribute.
273 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
274 /// which will be directly propagated to result.
275 template <class AttrElementT, //
276  class ElementValueT = typename AttrElementT::ValueType,
277  class PoisonAttr = ub::PoisonAttr,
278  class ResultAttrElementT = AttrElementT,
279  class ResultElementValueT = typename ResultAttrElementT::ValueType,
280  class CalculationT =
281  function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
283  CalculationT &&calculate) {
284  if (!llvm::getSingleElement(operands))
285  return {};
286 
287  static_assert(
288  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290  "void as template argument to opt-out from poison semantics.");
291  if constexpr (!std::is_void_v<PoisonAttr>) {
292  if (isa<PoisonAttr>(operands[0]))
293  return operands[0];
294  }
295 
296  auto getAttrType = [](Attribute attr) -> Type {
297  if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
298  return typed.getType();
299  return {};
300  };
301 
302  Type operandType = getAttrType(operands[0]);
303  if (!operandType)
304  return {};
305 
306  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307  ResultAttrElementT, ResultElementValueT,
308  CalculationT>(
309  operands, operandType, std::forward<CalculationT>(calculate));
310 }
311 
312 template <class AttrElementT, //
313  class ElementValueT = typename AttrElementT::ValueType,
314  class PoisonAttr = ub::PoisonAttr,
315  class ResultAttrElementT = AttrElementT,
316  class ResultElementValueT = typename ResultAttrElementT::ValueType,
317  class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
319  CalculationT &&calculate) {
320  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321  ResultAttrElementT>(
322  operands, resultType,
323  [&](ElementValueT a) -> std::optional<ResultElementValueT> {
324  return calculate(a);
325  });
326 }
327 
328 template <class AttrElementT, //
329  class ElementValueT = typename AttrElementT::ValueType,
330  class PoisonAttr = ub::PoisonAttr,
331  class ResultAttrElementT = AttrElementT,
332  class ResultElementValueT = typename ResultAttrElementT::ValueType,
333  class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
335  CalculationT &&calculate) {
336  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337  ResultAttrElementT>(
338  operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
339  return calculate(a);
340  });
341 }
342 
343 template <
344  class AttrElementT, class TargetAttrElementT,
345  class ElementValueT = typename AttrElementT::ValueType,
346  class TargetElementValueT = typename TargetAttrElementT::ValueType,
347  class PoisonAttr = ub::PoisonAttr,
348  class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
350  CalculationT &&calculate) {
351  if (!llvm::getSingleElement(operands))
352  return {};
353 
354  static_assert(
355  std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
356  "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
357  "void as template argument to opt-out from poison semantics.");
358  if constexpr (!std::is_void_v<PoisonAttr>) {
359  if (isa<PoisonAttr>(operands[0]))
360  return operands[0];
361  }
362 
363  if (isa<AttrElementT>(operands[0])) {
364  auto op = cast<AttrElementT>(operands[0]);
365  bool castStatus = true;
366  auto res = calculate(op.getValue(), castStatus);
367  if (!castStatus)
368  return {};
369  return TargetAttrElementT::get(resType, res);
370  }
371  if (isa<SplatElementsAttr>(operands[0])) {
372  // The operand is a splat so we can avoid expanding the values out and
373  // just fold based on the splat value.
374  auto op = cast<SplatElementsAttr>(operands[0]);
375  bool castStatus = true;
376  auto elementResult =
377  calculate(op.getSplatValue<ElementValueT>(), castStatus);
378  if (!castStatus)
379  return {};
380  auto shapedResType = cast<ShapedType>(resType);
381  if (!shapedResType.hasStaticShape())
382  return {};
383  return DenseElementsAttr::get(shapedResType, elementResult);
384  }
385  if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
386  // Operand is ElementsAttr-derived; perform an element-wise fold by
387  // expanding the value.
388  bool castStatus = true;
389  auto maybeOpIt = op.try_value_begin<ElementValueT>();
390  if (!maybeOpIt)
391  return {};
392  auto opIt = *maybeOpIt;
393  SmallVector<TargetElementValueT> elementResults;
394  elementResults.reserve(op.getNumElements());
395  for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
396  auto elt = calculate(*opIt, castStatus);
397  if (!castStatus)
398  return {};
399  elementResults.push_back(elt);
400  }
401 
402  return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
403  }
404  return {};
405 }
406 } // namespace mlir
407 
408 #endif // MLIR_DIALECT_COMMONFOLDERS_H
Attributes are known-constant values of operations.
Definition: Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Definition: CommonFolders.h:46
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)