15#ifndef MLIR_DIALECT_COMMONFOLDERS_H
16#define MLIR_DIALECT_COMMONFOLDERS_H
23#include "llvm/ADT/ArrayRef.h"
24#include "llvm/ADT/STLExtras.h"
// Element-wise constant folder for binary ops (overload that takes the result
// type explicitly). Applies `calculate` to the constant values held by the two
// operand attributes and wraps the produced value(s) in a result attribute.
//
// Template parameters:
//   LAttrElementT / RAttrElementT   - scalar attribute kinds accepted for the
//                                     left / right operand.
//   LElementValueT / RElementValueT - value types handed to `calculate`
//                                     (default: the attribute's ValueType).
//   PoisonAttr                      - attribute kind treated as poison; pass
//                                     `void` to opt out of poison handling.
//   ResultAttrElementT              - scalar attribute kind built for a scalar
//                                     result.
//   ResultElementValueT             - value type produced by `calculate`.
//
// `calculate` returns an optional-like value (it is dereferenced below).
// Presumably an empty result means "do not fold". Confirm against the
// `CalculationT` default, whose line is not visible here.
//
// NOTE(review): this listing is an excerpt. The `CalculationT` parameter line,
// the function-name line and several `return` statements are missing. The
// comments below describe only what the visible lines show.
39template <
class LAttrElementT,
class RAttrElementT = LAttrElementT,
40 class LElementValueT =
typename LAttrElementT::ValueType,
41 class RElementValueT =
typename RAttrElementT::ValueType,
42 class PoisonAttr = ub::PoisonAttr,
43 class ResultAttrElementT = LAttrElementT,
44 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
46 LElementValueT, RElementValueT)>>
49 CalculationT &&calculate) {
// Exactly two operand attributes are expected.
50 assert(operands.size() == 2 &&
"binary op takes two operands");
// Presumably the arguments of a compile-time assertion whose opening line is
// not visible: PoisonAttr must be `void` or a complete type.
52 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
53 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
54 "void as template argument to opt-out from poison semantics.");
// Poison handling, compiled in only when PoisonAttr is not `void`. What each
// branch does is on lines not visible here (presumably it propagates the
// poison operand). TODO confirm.
55 if constexpr (!std::is_void_v<PoisonAttr>) {
56 if (isa_and_nonnull<PoisonAttr>(operands[0]))
59 if (isa_and_nonnull<PoisonAttr>(operands[1]))
// A missing result type or a null (non-constant) operand presumably aborts
// folding; the return itself is not visible.
63 if (!resultType || !operands[0] || !operands[1])
// Case 1: both operands are scalar attributes of the expected kinds.
66 if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
67 auto lhs = cast<LAttrElementT>(operands[0]);
68 auto rhs = cast<RAttrElementT>(operands[1]);
// The operand-type equality check is only compiled in when both sides share
// a value type.
69 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
70 if (
lhs.getType() !=
rhs.getType())
73 auto calRes = calculate(
lhs.getValue(),
rhs.getValue());
78 return ResultAttrElementT::get(resultType, *calRes);
// Case 2: both operands are splats; `calculate` is evaluated once on the two
// splat values.
81 if (isa<SplatElementsAttr>(operands[0]) &&
82 isa<SplatElementsAttr>(operands[1])) {
85 auto lhs = cast<SplatElementsAttr>(operands[0]);
86 auto rhs = cast<SplatElementsAttr>(operands[1]);
87 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
88 if (
lhs.getType() !=
rhs.getType())
91 auto elementResult = calculate(
lhs.getSplatValue<LElementValueT>(),
92 rhs.getSplatValue<RElementValueT>());
// Case 3: general ElementsAttr operands; `calculate` is applied pairwise to
// every element.
99 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
102 auto lhs = cast<ElementsAttr>(operands[0]);
103 auto rhs = cast<ElementsAttr>(operands[1]);
104 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
105 if (
lhs.getType() !=
rhs.getType())
// Value iterators for the requested value types may be unavailable. In that
// case folding is presumably abandoned (the return is not visible).
108 auto maybeLhsIt =
lhs.try_value_begin<LElementValueT>();
109 auto maybeRhsIt =
rhs.try_value_begin<RElementValueT>();
110 if (!maybeLhsIt || !maybeRhsIt)
112 auto lhsIt = *maybeLhsIt;
113 auto rhsIt = *maybeRhsIt;
// `elementResults` is declared on a line not visible here.
115 elementResults.reserve(
lhs.getNumElements());
// Walk both operands in lockstep, bounded by lhs's element count.
// NOTE(review): when LElementValueT != RElementValueT the type-equality check
// above is compiled out. Nothing visible here ensures rhs has at least
// lhs.getNumElements() elements; confirm callers guarantee matching shapes.
116 for (
size_t i = 0, e =
lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
117 auto elementResult = calculate(*lhsIt, *rhsIt);
120 elementResults.push_back(*elementResult);
// Element-wise constant folder for binary ops (overload without an explicit
// result type). The result type is derived from the operand attributes, and
// the work is delegated to the overload that takes an explicit result type.
//
// Template parameters match the explicit-result-type overload. PoisonAttr
// defaults to ub::PoisonAttr; pass `void` to opt out of poison handling.
//
// NOTE(review): this listing is an excerpt. The `CalculationT` line, the
// function-name line, the opening of the type helper and several `return`
// statements are not visible here.
134template <
class LAttrElementT,
class RAttrElementT = LAttrElementT,
135 class LElementValueT =
typename LAttrElementT::ValueType,
136 class RElementValueT =
typename RAttrElementT::ValueType,
137 class PoisonAttr = ub::PoisonAttr,
138 class ResultAttrElementT = LAttrElementT,
139 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
141 LElementValueT, RElementValueT)>>
143 CalculationT &&calculate) {
// Exactly two operand attributes are expected.
144 assert(operands.size() == 2 &&
"binary op takes two operands");
// Presumably the arguments of a compile-time assertion whose opening line is
// not visible: PoisonAttr must be `void` or a complete type.
146 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
147 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
148 "void as template argument to opt-out from poison semantics.");
// Poison handling, compiled in only when PoisonAttr is not `void`. The action
// taken is on lines not visible here.
149 if constexpr (!std::is_void_v<PoisonAttr>) {
150 if (isa_and_nonnull<PoisonAttr>(operands[0]))
153 if (isa_and_nonnull<PoisonAttr>(operands[1]))
// Body of a helper (its opening line is not visible) that returns the type of
// a TypedAttr. `dyn_cast_or_null` also tolerates a null attribute. The
// fallback for non-typed attributes is not visible; presumably a null Type.
158 if (
auto typed = dyn_cast_or_null<TypedAttr>(attr))
159 return typed.getType();
// Both operand types must be known in order to fold.
163 Type lhsType = getAttrType(operands[0]);
164 Type rhsType = getAttrType(operands[1]);
165 if (!lhsType || !rhsType)
// The type-equality check is only compiled in when both sides share a value
// type.
167 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
168 if (lhsType != rhsType)
// Forward to the explicit-result-type overload, passing lhsType as the result
// type.
172 LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
173 ResultAttrElementT, ResultElementValueT, CalculationT>(
174 operands, lhsType, std::forward<CalculationT>(calculate));
// Element-wise constant folder for binary ops whose calculation cannot fail
// (overload that takes the result type explicitly). `calculate` returns a plain
// ResultElementValueT. It is wrapped in a lambda returning
// std::optional<ResultElementValueT> and forwarded to
// constFoldBinaryOpConditional.
//
// PoisonAttr defaults to ub::PoisonAttr, the same default used by every other
// folder in this header, so poison operands are handled uniformly whichever
// entry point is used. Pass `void` explicitly to opt out of poison semantics.
//
// NOTE(review): this listing is an excerpt. The `CalculationT` line, the
// function-name line and the start of the forwarding call are not visible.
177template <
class LAttrElementT,
class RAttrElementT = LAttrElementT,
178 class LElementValueT =
typename LAttrElementT::ValueType,
179 class RElementValueT =
typename RAttrElementT::ValueType,
180 class PoisonAttr =
ub::PoisonAttr,
181 class ResultAttrElementT = LAttrElementT,
182 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
184 function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
186 CalculationT &&calculate) {
188 LElementValueT, RElementValueT,
189 PoisonAttr, ResultAttrElementT>(
190 operands, resultType,
// Adapt the infallible `calculate` to the optional-returning signature that
// the conditional folder expects.
191 [&](LElementValueT a, RElementValueT
b)
192 -> std::optional<ResultElementValueT> {
return calculate(a,
b); });
// Element-wise constant folder for binary ops whose calculation cannot fail
// (overload without an explicit result type). `calculate` returns a plain
// ResultElementValueT. It is wrapped in a lambda returning
// std::optional<ResultElementValueT> and forwarded to
// constFoldBinaryOpConditional.
//
// NOTE(review): this listing is an excerpt. The `CalculationT` line, the
// function-name line and the start of the forwarding call (including the
// operand argument) are not visible here.
195template <
class LAttrElementT,
class RAttrElementT = LAttrElementT,
196 class LElementValueT =
typename LAttrElementT::ValueType,
197 class RElementValueT =
typename RAttrElementT::ValueType,
198 class PoisonAttr = ub::PoisonAttr,
199 class ResultAttrElementT = LAttrElementT,
200 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
202 function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
204 CalculationT &&calculate) {
206 LElementValueT, RElementValueT,
207 PoisonAttr, ResultAttrElementT>(
// Adapt the infallible `calculate` to the optional-returning signature that
// the conditional folder expects.
209 [&](LElementValueT a, RElementValueT
b)
210 -> std::optional<ResultElementValueT> {
return calculate(a,
b); });
// Element-wise constant folder for unary ops (overload that takes the result
// type explicitly). Applies `calculate` to the constant value(s) of the single
// operand attribute. `calculate` returns an optional-like value (dereferenced
// below); presumably an empty result means "do not fold".
//
// NOTE(review): this listing is an excerpt. The `CalculationT` line, the
// function-name line and most `return` statements are not visible here.
218template <
class AttrElementT,
219 class ElementValueT =
typename AttrElementT::ValueType,
220 class PoisonAttr = ub::PoisonAttr,
221 class ResultAttrElementT = AttrElementT,
222 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
227 CalculationT &&calculate) {
// A missing result type or a null operand presumably aborts folding.
// llvm::getSingleElement is expected to assert that exactly one operand is
// present. TODO confirm against STLExtras.
228 if (!resultType || !llvm::getSingleElement(operands))
// Presumably the arguments of a compile-time assertion whose opening line is
// not visible: PoisonAttr must be `void` or a complete type.
232 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
233 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
234 "void as template argument to opt-out from poison semantics.");
// Poison handling, compiled in only when PoisonAttr is not `void`. The operand
// is known non-null here, so plain `isa` is used.
235 if constexpr (!std::is_void_v<PoisonAttr>) {
236 if (isa<PoisonAttr>(operands[0]))
// Case 1: scalar attribute of the expected kind.
240 if (isa<AttrElementT>(operands[0])) {
241 auto op = cast<AttrElementT>(operands[0]);
243 auto res = calculate(op.getValue());
246 return ResultAttrElementT::get(resultType, *res);
// Case 2: splat; `calculate` is evaluated once on the splat value.
248 if (isa<SplatElementsAttr>(operands[0])) {
251 auto op = cast<SplatElementsAttr>(operands[0]);
253 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
257 }
else if (isa<ElementsAttr>(operands[0])) {
// Case 3: general ElementsAttr; `calculate` is applied to every element.
260 auto op = cast<ElementsAttr>(operands[0]);
// The missing lines after this presumably bail out when no iterator for
// ElementValueT is available.
262 auto maybeOpIt = op.try_value_begin<ElementValueT>();
265 auto opIt = *maybeOpIt;
// `elementResults` is declared on a line not visible here.
267 elementResults.reserve(op.getNumElements());
268 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
269 auto elementResult = calculate(*opIt);
272 elementResults.push_back(*elementResult);
// Element-wise constant folder for unary ops (overload without an explicit
// result type). The result type is taken from the operand attribute, and the
// work is delegated to the explicit-result-type overload.
//
// NOTE(review): this listing is an excerpt. The `CalculationT` line, the
// function-name line, the opening of the type helper and the start of the
// forwarding call are not visible here.
285template <
class AttrElementT,
286 class ElementValueT =
typename AttrElementT::ValueType,
287 class PoisonAttr = ub::PoisonAttr,
288 class ResultAttrElementT = AttrElementT,
289 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
293 CalculationT &&calculate) {
// A null operand presumably aborts folding. llvm::getSingleElement is
// expected to assert that exactly one operand is present.
294 if (!llvm::getSingleElement(operands))
// Presumably the arguments of a compile-time assertion whose opening line is
// not visible: PoisonAttr must be `void` or a complete type.
298 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
299 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
300 "void as template argument to opt-out from poison semantics.");
// Poison handling, compiled in only when PoisonAttr is not `void`.
301 if constexpr (!std::is_void_v<PoisonAttr>) {
302 if (isa<PoisonAttr>(operands[0]))
// Body of a helper (its opening line is not visible) that returns the type of
// a TypedAttr; the fallback for other attributes is not visible.
307 if (
auto typed = dyn_cast_or_null<TypedAttr>(attr))
308 return typed.getType();
// The operand's type becomes the result type passed to the forwarded call
// below. Any null check on it happens on lines not visible here.
312 Type operandType = getAttrType(operands[0]);
317 ResultAttrElementT, ResultElementValueT,
319 operands, operandType, std::forward<CalculationT>(calculate));
// Element-wise constant folder for unary ops whose calculation cannot fail
// (overload that takes the result type explicitly). `calculate` returns a plain
// ResultElementValueT and is adapted to return
// std::optional<ResultElementValueT> for the conditional folder.
//
// NOTE(review): this listing is an excerpt. The function-name line, the start
// of the forwarding call and the lambda body are not visible here.
322template <
class AttrElementT,
323 class ElementValueT =
typename AttrElementT::ValueType,
324 class PoisonAttr = ub::PoisonAttr,
325 class ResultAttrElementT = AttrElementT,
326 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
327 class CalculationT =
function_ref<ResultElementValueT(ElementValueT)>>
329 CalculationT &&calculate) {
332 operands, resultType,
// Adapter lambda with the optional-returning signature.
333 [&](ElementValueT a) -> std::optional<ResultElementValueT> {
// Element-wise constant folder for unary ops whose calculation cannot fail
// (overload without an explicit result type). `calculate` returns a plain
// ResultElementValueT and is adapted to return
// std::optional<ResultElementValueT> for the conditional folder.
//
// NOTE(review): this listing is an excerpt. The function-name line, the start
// of the forwarding call and the lambda body are not visible here.
338template <
class AttrElementT,
339 class ElementValueT =
typename AttrElementT::ValueType,
340 class PoisonAttr = ub::PoisonAttr,
341 class ResultAttrElementT = AttrElementT,
342 class ResultElementValueT =
typename ResultAttrElementT::ValueType,
343 class CalculationT =
function_ref<ResultElementValueT(ElementValueT)>>
345 CalculationT &&calculate) {
// Adapter lambda with the optional-returning signature.
348 operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
// constFoldCastOp template parameters (the opening `template <` line and the
// function-name line are not visible in this excerpt). The folder converts the
// constant value(s) of a single AttrElementT-based operand into
// TargetAttrElementT values of type `resType`. `calculate` receives each value
// together with a `castStatus` flag initialised to true. The check on that
// flag after each call is on lines not visible here.
// NOTE(review): the default CalculationT below takes `bool` by value. A
// callable of exactly that type could not report a failed cast through
// `castStatus`; presumably callers pass lambdas taking `bool &`. Confirm, and
// consider `bool &` in the default.
354 class AttrElementT,
class TargetAttrElementT,
355 class ElementValueT =
typename AttrElementT::ValueType,
356 class TargetElementValueT =
typename TargetAttrElementT::ValueType,
357 class PoisonAttr = ub::PoisonAttr,
358 class CalculationT =
function_ref<TargetElementValueT(ElementValueT,
bool)>>
360 CalculationT &&calculate) {
// A null operand presumably aborts folding. llvm::getSingleElement is
// expected to assert that exactly one operand is present.
361 if (!llvm::getSingleElement(operands))
// Presumably the arguments of a compile-time assertion whose opening line is
// not visible: PoisonAttr must be `void` or a complete type.
365 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
366 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
367 "void as template argument to opt-out from poison semantics.");
// Poison handling, compiled in only when PoisonAttr is not `void`.
368 if constexpr (!std::is_void_v<PoisonAttr>) {
369 if (isa<PoisonAttr>(operands[0]))
// Case 1: scalar attribute; the result is built directly from `res`.
373 if (isa<AttrElementT>(operands[0])) {
374 auto op = cast<AttrElementT>(operands[0]);
375 bool castStatus =
true;
376 auto res = calculate(op.getValue(), castStatus);
379 return TargetAttrElementT::get(resType, res);
// Case 2: splat; `calculate` is evaluated once on the splat value. The result
// type must be a statically shaped ShapedType.
381 if (isa<SplatElementsAttr>(operands[0])) {
384 auto op = cast<SplatElementsAttr>(operands[0]);
385 bool castStatus =
true;
387 calculate(op.getSplatValue<ElementValueT>(), castStatus);
390 auto shapedResType = cast<ShapedType>(resType);
391 if (!shapedResType.hasStaticShape())
// Case 3: general ElementsAttr; a single `castStatus` flag is shared across
// all per-element calls.
395 if (
auto op = dyn_cast<ElementsAttr>(operands[0])) {
398 bool castStatus =
true;
399 auto maybeOpIt = op.try_value_begin<ElementValueT>();
402 auto opIt = *maybeOpIt;
// `elementResults` is declared on a line not visible here.
404 elementResults.reserve(op.getNumElements());
405 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
406 auto elt = calculate(*opIt, castStatus);
409 elementResults.push_back(elt);
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Include the generated interface declarations.
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the single attribute in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
llvm::function_ref< Fn > function_ref