15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
33 template <
class AttrElementT,
34 class ElementValueT =
typename AttrElementT::ValueType,
35 class PoisonAttr = ub::PoisonAttr,
37 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
40 CalculationT &&calculate) {
41 assert(operands.size() == 2 &&
"binary op takes two operands");
43 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45 "void as template argument to opt-out from poison semantics.");
46 if constexpr (!std::is_void_v<PoisonAttr>) {
47 if (isa_and_nonnull<PoisonAttr>(operands[0]))
50 if (isa_and_nonnull<PoisonAttr>(operands[1]))
54 if (!resultType || !operands[0] || !operands[1])
57 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58 auto lhs = cast<AttrElementT>(operands[0]);
59 auto rhs = cast<AttrElementT>(operands[1]);
60 if (lhs.getType() != rhs.getType())
63 auto calRes = calculate(lhs.getValue(), rhs.getValue());
71 if (isa<SplatElementsAttr>(operands[0]) &&
72 isa<SplatElementsAttr>(operands[1])) {
75 auto lhs = cast<SplatElementsAttr>(operands[0]);
76 auto rhs = cast<SplatElementsAttr>(operands[1]);
77 if (lhs.getType() != rhs.getType())
80 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81 rhs.getSplatValue<ElementValueT>());
88 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
91 auto lhs = cast<ElementsAttr>(operands[0]);
92 auto rhs = cast<ElementsAttr>(operands[1]);
93 if (lhs.getType() != rhs.getType())
96 auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97 auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98 if (!maybeLhsIt || !maybeRhsIt)
100 auto lhsIt = *maybeLhsIt;
101 auto rhsIt = *maybeRhsIt;
103 elementResults.reserve(lhs.getNumElements());
104 for (
size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105 auto elementResult = calculate(*lhsIt, *rhsIt);
108 elementResults.push_back(*elementResult);
122 template <
class AttrElementT,
123 class ElementValueT =
typename AttrElementT::ValueType,
124 class PoisonAttr = ub::PoisonAttr,
126 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
128 CalculationT &&calculate) {
129 assert(operands.size() == 2 &&
"binary op takes two operands");
131 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133 "void as template argument to opt-out from poison semantics.");
134 if constexpr (!std::is_void_v<PoisonAttr>) {
135 if (isa_and_nonnull<PoisonAttr>(operands[0]))
138 if (isa_and_nonnull<PoisonAttr>(operands[1]))
143 if (
auto typed = dyn_cast_or_null<TypedAttr>(attr))
144 return typed.getType();
148 Type lhsType = getResultType(operands[0]);
149 Type rhsType = getResultType(operands[1]);
150 if (!lhsType || !rhsType)
152 if (lhsType != rhsType)
157 operands, lhsType, std::forward<CalculationT>(calculate));
160 template <
class AttrElementT,
161 class ElementValueT =
typename AttrElementT::ValueType,
162 class PoisonAttr = void,
164 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
166 CalculationT &&calculate) {
167 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168 operands, resultType,
169 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170 return calculate(a, b);
174 template <
class AttrElementT,
175 class ElementValueT =
typename AttrElementT::ValueType,
176 class PoisonAttr = ub::PoisonAttr,
178 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
180 CalculationT &&calculate) {
181 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
183 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184 return calculate(a, b);
192 template <
class AttrElementT,
193 class ElementValueT =
typename AttrElementT::ValueType,
194 class PoisonAttr = ub::PoisonAttr,
196 function_ref<std::optional<ElementValueT>(ElementValueT)>>
198 CalculationT &&calculate) {
199 assert(operands.size() == 1 &&
"unary op takes one operands");
204 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
205 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
206 "void as template argument to opt-out from poison semantics.");
207 if constexpr (!std::is_void_v<PoisonAttr>) {
208 if (isa<PoisonAttr>(operands[0]))
212 if (isa<AttrElementT>(operands[0])) {
213 auto op = cast<AttrElementT>(operands[0]);
215 auto res = calculate(op.getValue());
220 if (isa<SplatElementsAttr>(operands[0])) {
223 auto op = cast<SplatElementsAttr>(operands[0]);
225 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
229 }
else if (isa<ElementsAttr>(operands[0])) {
232 auto op = cast<ElementsAttr>(operands[0]);
234 auto maybeOpIt = op.try_value_begin<ElementValueT>();
237 auto opIt = *maybeOpIt;
239 elementResults.reserve(op.getNumElements());
240 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
241 auto elementResult = calculate(*opIt);
244 elementResults.push_back(*elementResult);
251 template <
class AttrElementT,
252 class ElementValueT =
typename AttrElementT::ValueType,
253 class PoisonAttr = ub::PoisonAttr,
254 class CalculationT =
function_ref<ElementValueT(ElementValueT)>>
256 CalculationT &&calculate) {
257 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
258 operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
264 class AttrElementT,
class TargetAttrElementT,
265 class ElementValueT =
typename AttrElementT::ValueType,
266 class TargetElementValueT =
typename TargetAttrElementT::ValueType,
267 class PoisonAttr = ub::PoisonAttr,
268 class CalculationT =
function_ref<TargetElementValueT(ElementValueT,
bool)>>
270 CalculationT &&calculate) {
271 assert(operands.size() == 1 &&
"Cast op takes one operand");
276 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
277 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
278 "void as template argument to opt-out from poison semantics.");
279 if constexpr (!std::is_void_v<PoisonAttr>) {
280 if (isa<PoisonAttr>(operands[0]))
284 if (isa<AttrElementT>(operands[0])) {
285 auto op = cast<AttrElementT>(operands[0]);
286 bool castStatus =
true;
287 auto res = calculate(op.getValue(), castStatus);
292 if (isa<SplatElementsAttr>(operands[0])) {
295 auto op = cast<SplatElementsAttr>(operands[0]);
296 bool castStatus =
true;
298 calculate(op.getSplatValue<ElementValueT>(), castStatus);
301 auto shapedResType = cast<ShapedType>(resType);
302 if (!shapedResType.hasStaticShape())
306 if (
auto op = dyn_cast<ElementsAttr>(operands[0])) {
309 bool castStatus =
true;
310 auto maybeOpIt = op.try_value_begin<ElementValueT>();
313 auto opIt = *maybeOpIt;
315 elementResults.reserve(op.getNumElements());
316 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
317 auto elt = calculate(*opIt, castStatus);
320 elementResults.push_back(elt);
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, CalculationT &&calculate)
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the single attribute in operands and returns the result if possible.
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and returns the result if possible.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)