15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/STLExtras.h"
// Element-wise constant folder for a binary operation whose per-element
// callback may decline to produce a value.
// NOTE(review): the declaration line carrying the function name and the
// `operands` / `resultType` parameters is missing from this extract;
// presumably this is the `constFoldBinaryOpConditional(operands, resultType,
// calculate)` overload -- confirm against the full header.
//
// Template parameters:
//   AttrElementT        - attribute class matched for scalar operands.
//   ElementValueT       - value type read from the operands.
//   PoisonAttr          - poison attribute class, or `void` to opt out of
//                         poison handling (see the diagnostic text below).
//   ResultAttrElementT  - attribute class for the result; defaults to
//                         AttrElementT.
//   ResultElementValueT - value type produced by `calculate`.
39 template <
class AttrElementT,
40 class ElementValueT =
typename AttrElementT::ValueType,
41 class PoisonAttr = ub::PoisonAttr,
42 class ResultAttrElementT = AttrElementT,
43 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
45 std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
48 CalculationT &&calculate) {
// Precondition: exactly two operand attributes are supplied.
49 assert(operands.size() == 2 &&
"binary op takes two operands");
// Compile-time requirement on PoisonAttr: it must be `void` or a complete
// type. (The opening line of this check is not visible in this extract.)
51 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
52 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
53 "void as template argument to opt-out from poison semantics.");
// Poison handling is compiled in only when PoisonAttr is not `void`. Each
// operand is tested for being a non-null PoisonAttr; the action taken on a
// match is on lines not visible here.
54 if constexpr (!std::is_void_v<PoisonAttr>) {
55 if (isa_and_nonnull<PoisonAttr>(operands[0]))
58 if (isa_and_nonnull<PoisonAttr>(operands[1]))
// Guard: a null result type or a null operand is detected up front.
62 if (!resultType || !operands[0] || !operands[1])
// Case 1: both operands are scalar attributes of type AttrElementT;
// `calculate` is applied once to the two stored values.
65 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
66 auto lhs = cast<AttrElementT>(operands[0]);
67 auto rhs = cast<AttrElementT>(operands[1]);
// Operand types are compared; the mismatch handling is not visible here.
68 if (lhs.getType() != rhs.getType())
71 auto calRes = calculate(lhs.getValue(), rhs.getValue());
// Case 2: both operands are splat element attributes; `calculate` is applied
// once to the two splat values.
79 if (isa<SplatElementsAttr>(operands[0]) &&
80 isa<SplatElementsAttr>(operands[1])) {
83 auto lhs = cast<SplatElementsAttr>(operands[0]);
84 auto rhs = cast<SplatElementsAttr>(operands[1]);
85 if (lhs.getType() != rhs.getType())
88 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
89 rhs.getSplatValue<ElementValueT>());
// Case 3: both operands are general ElementsAttr; the two value ranges are
// walked in lockstep.
96 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
99 auto lhs = cast<ElementsAttr>(operands[0]);
100 auto rhs = cast<ElementsAttr>(operands[1]);
101 if (lhs.getType() != rhs.getType())
// try_value_begin may fail to yield an iterator for ElementValueT; both
// results are tested before being dereferenced.
104 auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
105 auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
106 if (!maybeLhsIt || !maybeRhsIt)
108 auto lhsIt = *maybeLhsIt;
109 auto rhsIt = *maybeRhsIt;
// Storage is reserved for one result per element of the left operand.
111 elementResults.reserve(lhs.getNumElements());
112 for (
size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
113 auto elementResult = calculate(*lhsIt, *rhsIt);
// NOTE(review): elementResult is dereferenced here; presumably an emptiness
// check sits on the line(s) missing from this extract -- confirm.
116 elementResults.push_back(*elementResult);
// Binary element-wise constant folder that takes the result type from the
// operands instead of receiving it as an argument.
// NOTE(review): the declaration line with the function name is missing from
// this extract; presumably this is the `constFoldBinaryOpConditional`
// overload without a `resultType` parameter -- confirm.
// The template parameters mirror the sibling binary folders: operand
// attribute/value types, an optional PoisonAttr (`void` opts out), and result
// attribute/value types.
130 template <
class AttrElementT,
131 class ElementValueT =
typename AttrElementT::ValueType,
132 class PoisonAttr = ub::PoisonAttr,
133 class ResultAttrElementT = AttrElementT,
134 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
136 std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
138 CalculationT &&calculate) {
// Precondition: exactly two operand attributes are supplied.
139 assert(operands.size() == 2 &&
"binary op takes two operands");
// Compile-time requirement on PoisonAttr: it must be `void` or a complete
// type. (The opening line of this check is not visible in this extract.)
141 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
142 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
143 "void as template argument to opt-out from poison semantics.");
// Poison operands are tested for only when PoisonAttr is not `void`; the
// action taken on a match is on lines not visible here.
144 if constexpr (!std::is_void_v<PoisonAttr>) {
145 if (isa_and_nonnull<PoisonAttr>(operands[0]))
148 if (isa_and_nonnull<PoisonAttr>(operands[1]))
// Body of the `getAttrType` helper used below: yields the attribute's type
// when the attribute is a non-null TypedAttr. The fallback for other
// attributes is not visible in this extract.
153 if (
auto typed = dyn_cast_or_null<TypedAttr>(attr))
154 return typed.getType();
// Both operand types must be obtainable and must match; the handling of the
// failing cases is on lines not visible here.
158 Type lhsType = getAttrType(operands[0]);
159 Type rhsType = getAttrType(operands[1]);
160 if (!lhsType || !rhsType)
162 if (lhsType != rhsType)
// Forward to a callee (name not visible here) with the left operand's type
// passed in the result-type position.
166 ResultAttrElementT, ResultElementValueT,
168 operands, lhsType, std::forward<CalculationT>(calculate));
// Binary folder taking an explicit result type and a callback that always
// yields a value: the default CalculationT returns ResultElementValueT, not
// an optional.
// NOTE(review): the declaration line with the function name is missing from
// this extract; presumably this is `constFoldBinaryOp(operands, resultType,
// calculate)` -- confirm.
// NOTE(review): PoisonAttr defaults to `void` here, whereas the sibling
// templates in this file default to `ub::PoisonAttr` -- confirm this
// difference is intentional.
171 template <
class AttrElementT,
172 class ElementValueT =
typename AttrElementT::ValueType,
173 class PoisonAttr = void,
174 class ResultAttrElementT = AttrElementT,
175 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
177 function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
179 CalculationT &&calculate) {
// Forward to a callee (name not visible here), wrapping `calculate` in a
// lambda that returns its result as a std::optional.
182 operands, resultType,
183 [&](ElementValueT a, ElementValueT b)
184 -> std::optional<ResultElementValueT> {
return calculate(a, b); });
// Binary folder with a callback that always yields a value.
// NOTE(review): the declaration line with the function name is missing from
// this extract; presumably this is the `constFoldBinaryOp` overload without a
// `resultType` parameter -- confirm.
// PoisonAttr defaults to `ub::PoisonAttr` here.
187 template <
class AttrElementT,
188 class ElementValueT =
typename AttrElementT::ValueType,
189 class PoisonAttr = ub::PoisonAttr,
190 class ResultAttrElementT = AttrElementT,
191 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
193 function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
195 CalculationT &&calculate) {
// `calculate` is wrapped in a lambda that returns its result as a
// std::optional; the call receiving this lambda is on a line not visible
// here.
199 [&](ElementValueT a, ElementValueT b)
200 -> std::optional<ResultElementValueT> {
return calculate(a, b); });
// Element-wise constant folder for a unary operation whose per-element
// callback may decline to produce a value; the default CalculationT returns
// std::optional<ResultElementValueT>.
// NOTE(review): the declaration line with the function name and the
// `operands` / `resultType` parameters is missing from this extract;
// presumably this is `constFoldUnaryOpConditional(operands, resultType,
// calculate)` -- confirm.
208 template <
class AttrElementT,
209 class ElementValueT =
typename AttrElementT::ValueType,
210 class PoisonAttr = ub::PoisonAttr,
211 class ResultAttrElementT = AttrElementT,
212 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
214 function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
217 CalculationT &&calculate) {
// Guard: a null result type or a null operand is detected up front.
// llvm::getSingleElement is used to fetch the operand from `operands`.
218 if (!resultType || !llvm::getSingleElement(operands))
// Compile-time requirement on PoisonAttr: it must be `void` or a complete
// type. (The opening line of this check is not visible in this extract.)
222 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
223 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
224 "void as template argument to opt-out from poison semantics.");
// A poison operand is tested for only when PoisonAttr is not `void`; the
// action taken on a match is on lines not visible here.
225 if constexpr (!std::is_void_v<PoisonAttr>) {
226 if (isa<PoisonAttr>(operands[0]))
// Case 1: the operand is a scalar attribute of type AttrElementT.
230 if (isa<AttrElementT>(operands[0])) {
231 auto op = cast<AttrElementT>(operands[0]);
233 auto res = calculate(op.getValue());
// Case 2: the operand is a splat; `calculate` is applied once to the splat
// value.
238 if (isa<SplatElementsAttr>(operands[0])) {
241 auto op = cast<SplatElementsAttr>(operands[0]);
243 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
247 }
// Case 3: the operand is a general ElementsAttr; `calculate` is applied to
// each element in turn.
else if (isa<ElementsAttr>(operands[0])) {
250 auto op = cast<ElementsAttr>(operands[0]);
// NOTE(review): the check of maybeOpIt before it is dereferenced is
// presumably on a line missing from this extract -- confirm.
252 auto maybeOpIt = op.try_value_begin<ElementValueT>();
255 auto opIt = *maybeOpIt;
// Storage is reserved for one result per element of the operand.
257 elementResults.reserve(op.getNumElements());
258 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
259 auto elementResult = calculate(*opIt);
// NOTE(review): elementResult is dereferenced here; presumably an emptiness
// check sits on the line(s) missing from this extract -- confirm.
262 elementResults.push_back(*elementResult);
// Unary element-wise constant folder that takes the result type from the
// operand instead of receiving it as an argument.
// NOTE(review): the declaration line with the function name is missing from
// this extract; presumably this is the `constFoldUnaryOpConditional` overload
// without a `resultType` parameter -- confirm.
275 template <
class AttrElementT,
276 class ElementValueT =
typename AttrElementT::ValueType,
277 class PoisonAttr = ub::PoisonAttr,
278 class ResultAttrElementT = AttrElementT,
279 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
281 function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
283 CalculationT &&calculate) {
// Guard: a null operand is detected up front. llvm::getSingleElement is used
// to fetch the operand from `operands`.
284 if (!llvm::getSingleElement(operands))
// Compile-time requirement on PoisonAttr: it must be `void` or a complete
// type. (The opening line of this check is not visible in this extract.)
288 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290 "void as template argument to opt-out from poison semantics.");
// A poison operand is tested for only when PoisonAttr is not `void`; the
// action taken on a match is on lines not visible here.
291 if constexpr (!std::is_void_v<PoisonAttr>) {
292 if (isa<PoisonAttr>(operands[0]))
// Body of the `getAttrType` helper used below: yields the attribute's type
// when the attribute is a non-null TypedAttr. The fallback for other
// attributes is not visible in this extract.
297 if (
auto typed = dyn_cast_or_null<TypedAttr>(attr))
298 return typed.getType();
302 Type operandType = getAttrType(operands[0]);
// Forward to a callee (name not visible here) with the operand's type passed
// in the result-type position.
// NOTE(review): any null check on operandType would be on lines missing from
// this extract -- confirm.
307 ResultAttrElementT, ResultElementValueT,
309 operands, operandType, std::forward<CalculationT>(calculate));
// Unary folder taking an explicit result type and a callback that always
// yields a value: the default CalculationT returns ResultElementValueT, not
// an optional.
// NOTE(review): the declaration line with the function name is missing from
// this extract; presumably this is `constFoldUnaryOp(operands, resultType,
// calculate)` -- confirm.
312 template <
class AttrElementT,
313 class ElementValueT =
typename AttrElementT::ValueType,
314 class PoisonAttr = ub::PoisonAttr,
315 class ResultAttrElementT = AttrElementT,
316 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
317 class CalculationT =
function_ref<ResultElementValueT(ElementValueT)>>
319 CalculationT &&calculate) {
// Forward to a callee (name not visible here) with a lambda declared to
// return std::optional<ResultElementValueT>; the lambda body is on lines not
// visible in this extract.
322 operands, resultType,
323 [&](ElementValueT a) -> std::optional<ResultElementValueT> {
// Unary folder with a callback that always yields a value.
// NOTE(review): the declaration line with the function name is missing from
// this extract; presumably this is the `constFoldUnaryOp` overload without a
// `resultType` parameter -- confirm.
328 template <
class AttrElementT,
329 class ElementValueT =
typename AttrElementT::ValueType,
330 class PoisonAttr = ub::PoisonAttr,
331 class ResultAttrElementT = AttrElementT,
332 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
333 class CalculationT =
function_ref<ResultElementValueT(ElementValueT)>>
335 CalculationT &&calculate) {
// Forward to a callee (name not visible here) with a lambda declared to
// return std::optional<ResultElementValueT>; the lambda body is on lines not
// visible in this extract.
338 operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
// Element-wise constant folder for a cast-like operation converting values of
// AttrElementT into values of TargetAttrElementT.
// NOTE(review): the `template <` opener and the declaration line with the
// function name are missing from this extract; presumably this is
// `constFoldCastOp(operands, resType, calculate)` -- confirm.
// NOTE(review): the default CalculationT shown below takes the status flag as
// a by-value `bool`, while the body passes a local `castStatus`; presumably
// callers supply a callable taking `bool &` so a failed conversion can be
// reported -- confirm.
344 class AttrElementT,
class TargetAttrElementT,
345 class ElementValueT =
typename AttrElementT::ValueType,
346 class TargetElementValueT =
typename TargetAttrElementT::ValueType,
347 class PoisonAttr = ub::PoisonAttr,
348 class CalculationT =
function_ref<TargetElementValueT(ElementValueT,
bool)>>
350 CalculationT &&calculate) {
// Guard: a null operand is detected up front. llvm::getSingleElement is used
// to fetch the operand from `operands`.
351 if (!llvm::getSingleElement(operands))
// Compile-time requirement on PoisonAttr: it must be `void` or a complete
// type. (The opening line of this check is not visible in this extract.)
355 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
356 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
357 "void as template argument to opt-out from poison semantics.");
// A poison operand is tested for only when PoisonAttr is not `void`; the
// action taken on a match is on lines not visible here.
358 if constexpr (!std::is_void_v<PoisonAttr>) {
359 if (isa<PoisonAttr>(operands[0]))
// Case 1: the operand is a scalar attribute of type AttrElementT. The status
// flag starts out true and is handed to `calculate` with the value.
363 if (isa<AttrElementT>(operands[0])) {
364 auto op = cast<AttrElementT>(operands[0]);
365 bool castStatus =
true;
366 auto res = calculate(op.getValue(), castStatus);
// Case 2: the operand is a splat; `calculate` is applied once to the splat
// value.
371 if (isa<SplatElementsAttr>(operands[0])) {
374 auto op = cast<SplatElementsAttr>(operands[0]);
375 bool castStatus =
true;
377 calculate(op.getSplatValue<ElementValueT>(), castStatus);
// NOTE(review): resType is converted with the unchecked `cast`; this assumes
// the result type is a ShapedType whenever the operand is a splat -- confirm.
// The handling of a non-static shape is on lines not visible here.
380 auto shapedResType = cast<ShapedType>(resType);
381 if (!shapedResType.hasStaticShape())
// Case 3: the operand is a general ElementsAttr; `calculate` is applied to
// each element in turn, sharing a single status flag.
385 if (
auto op = dyn_cast<ElementsAttr>(operands[0])) {
388 bool castStatus =
true;
// NOTE(review): the check of maybeOpIt before it is dereferenced is
// presumably on a line missing from this extract -- confirm.
389 auto maybeOpIt = op.try_value_begin<ElementValueT>();
392 auto opIt = *maybeOpIt;
// Storage is reserved for one result per element of the operand.
394 elementResults.reserve(op.getNumElements());
395 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
396 auto elt = calculate(*opIt, castStatus);
// Unlike the conditional folders, the result is stored directly (it is not an
// optional). Any inspection of castStatus is on lines not visible here.
399 elementResults.push_back(elt);
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the single attribute in operands and ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)