15 #ifndef MLIR_DIALECT_COMMONFOLDERS_H
16 #define MLIR_DIALECT_COMMONFOLDERS_H
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
/// Folds a two-operand, element-wise operation over constant attributes by
/// applying `calculate` to matching element values.
/// NOTE(review): the declaration line carrying the function name and the
/// `operands` / `resultType` parameters is not present in this listing; the body
/// reads both names, and the summary entries at the end of the listing suggest
/// this is `constFoldBinaryOpConditional` - confirm against the real header.
///
/// Template parameters visible here:
///  - AttrElementT:  attribute class tested for in the scalar case.
///  - ElementValueT: element value type, defaults to AttrElementT::ValueType.
///  - PoisonAttr:    defaults to ub::PoisonAttr; `void` compiles the poison
///                   checks out.
///  - CalculationT:  type of `calculate`; its default involves the signature
///                   std::optional<ElementValueT>(ElementValueT, ElementValueT).
33 template <
class AttrElementT,
34 class ElementValueT =
typename AttrElementT::ValueType,
35 class PoisonAttr = ub::PoisonAttr,
37 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
40 CalculationT &&calculate) {
// Precondition: exactly two operands.
41 assert(operands.size() == 2 &&
"binary op takes two operands");
// Compile-time requirement (its opening line is missing from this listing):
// PoisonAttr must be either `void` or a complete type.
43 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
44 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
45 "void as template argument to opt-out from poison semantics.");
// Poison handling, compiled in only when PoisonAttr is not void. Each operand
// is tested (null-tolerantly) for PoisonAttr; the statements guarded by the
// two checks are not shown in this listing.
46 if constexpr (!std::is_void_v<PoisonAttr>) {
47 if (isa_and_nonnull<PoisonAttr>(operands[0]))
50 if (isa_and_nonnull<PoisonAttr>(operands[1]))
// Guard on a null result type or a null operand (guarded statement not shown).
54 if (!resultType || !operands[0] || !operands[1])
// Case 1: both operands are AttrElementT. Their types are compared, then
// `calculate` is applied to the two attribute values.
57 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
58 auto lhs = cast<AttrElementT>(operands[0]);
59 auto rhs = cast<AttrElementT>(operands[1]);
60 if (lhs.getType() != rhs.getType())
63 auto calRes = calculate(lhs.getValue(), rhs.getValue());
// Case 2: both operands are splats. `calculate` is applied once, to the two
// splat values.
71 if (isa<SplatElementsAttr>(operands[0]) &&
72 isa<SplatElementsAttr>(operands[1])) {
75 auto lhs = cast<SplatElementsAttr>(operands[0]);
76 auto rhs = cast<SplatElementsAttr>(operands[1]);
77 if (lhs.getType() != rhs.getType())
80 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
81 rhs.getSplatValue<ElementValueT>());
// Case 3: both operands are ElementsAttr. The two element ranges are walked in
// lockstep and `calculate` is applied pairwise.
88 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
91 auto lhs = cast<ElementsAttr>(operands[0]);
92 auto rhs = cast<ElementsAttr>(operands[1]);
93 if (lhs.getType() != rhs.getType())
// Element iterators are requested as ElementValueT; both requests are checked
// for failure before being dereferenced.
96 auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
97 auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
98 if (!maybeLhsIt || !maybeRhsIt)
100 auto lhsIt = *maybeLhsIt;
101 auto rhsIt = *maybeRhsIt;
// Room for one result per element of lhs (the declaration of `elementResults`
// is not shown in this listing).
103 elementResults.reserve(lhs.getNumElements());
104 for (
size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
105 auto elementResult = calculate(*lhsIt, *rhsIt);
// NOTE(review): `elementResult` is dereferenced here; any emptiness check
// between the call and this line is not visible in this listing - confirm.
108 elementResults.push_back(*elementResult);
/// Two-operand folder variant that derives the type from the operands
/// themselves and then forwards `operands`, that type and `calculate` to
/// another call.
/// NOTE(review): the declaration line with the function name and the callee of
/// the final forwarding call are not present in this listing; presumably this
/// is the overload of `constFoldBinaryOpConditional` without a result-type
/// parameter - confirm against the real header.
122 template <
class AttrElementT,
123 class ElementValueT =
typename AttrElementT::ValueType,
124 class PoisonAttr = ub::PoisonAttr,
126 std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
128 CalculationT &&calculate) {
// Precondition: exactly two operands.
129 assert(operands.size() == 2 &&
"binary op takes two operands");
// Compile-time requirement (its opening line is missing from this listing):
// PoisonAttr must be either `void` or a complete type.
131 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
132 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
133 "void as template argument to opt-out from poison semantics.");
// Poison checks on both operands, compiled in only when PoisonAttr is not
// void; the guarded statements are not shown in this listing.
134 if constexpr (!std::is_void_v<PoisonAttr>) {
135 if (isa_and_nonnull<PoisonAttr>(operands[0]))
138 if (isa_and_nonnull<PoisonAttr>(operands[1]))
// Body of the `getResultType` helper used below (its opening line is not
// shown): a non-null TypedAttr yields its own type.
143 if (
auto typed = dyn_cast_or_null<TypedAttr>(attr))
144 return typed.getType();
// Derive a type from each operand; a missing type on either side and a
// mismatch between the two are checked separately (guarded statements not
// shown).
148 Type lhsType = getResultType(operands[0]);
149 Type rhsType = getResultType(operands[1]);
150 if (!lhsType || !rhsType)
152 if (lhsType != rhsType)
// Argument list of the forwarding call: the operands, the derived type and
// the perfectly-forwarded calculation.
157 operands, lhsType, std::forward<CalculationT>(calculate));
/// Convenience wrapper for calculations that always produce a value: wraps
/// `calculate` in a lambda returning std::optional<ElementValueT> and passes
/// it, together with `operands` and `resultType`, to
/// constFoldBinaryOpConditional.
/// Note that PoisonAttr defaults to `void` in this variant.
/// NOTE(review): the declaration line with the function name is not present in
/// this listing; the summary entries at the end of the listing suggest
/// `constFoldBinaryOp` - confirm.
160 template <
class AttrElementT,
161 class ElementValueT =
typename AttrElementT::ValueType,
162 class PoisonAttr = void,
164 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
166 CalculationT &&calculate) {
167 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
168 operands, resultType,
// Adapter: the plain result of `calculate` is converted to an optional.
169 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170 return calculate(a, b);
/// Convenience wrapper for calculations that always produce a value: wraps
/// `calculate` in a lambda returning std::optional<ElementValueT> and hands it
/// to constFoldBinaryOpConditional.
/// PoisonAttr defaults to ub::PoisonAttr in this variant.
/// NOTE(review): the declaration line with the function name and the first
/// argument line of the forwarded call are not present in this listing;
/// presumably this is the `constFoldBinaryOp` overload without a result-type
/// parameter - confirm.
174 template <
class AttrElementT,
175 class ElementValueT =
typename AttrElementT::ValueType,
176 class PoisonAttr = ub::PoisonAttr,
178 function_ref<ElementValueT(ElementValueT, ElementValueT)>>
180 CalculationT &&calculate) {
181 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
// Adapter: the plain result of `calculate` is converted to an optional.
183 [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184 return calculate(a, b);
/// Folds a one-operand, element-wise operation over a constant attribute by
/// applying `calculate` to each element value.
/// NOTE(review): the declaration line with the function name is not present in
/// this listing; the summary entries at the end of the listing suggest
/// `constFoldUnaryOpConditional` - confirm.
///
/// `calculate` defaults to
/// function_ref<std::optional<ElementValueT>(ElementValueT)>.
192 template <
class AttrElementT,
193 class ElementValueT =
typename AttrElementT::ValueType,
194 class PoisonAttr = ub::PoisonAttr,
196 function_ref<std::optional<ElementValueT>(ElementValueT)>>
198 CalculationT &&calculate) {
// The sole operand is fetched through llvm::getSingleElement and tested for
// being null (guarded statement not shown in this listing).
199 if (!llvm::getSingleElement(operands))
// Compile-time requirement (its opening line is missing from this listing):
// PoisonAttr must be either `void` or a complete type.
203 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
204 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
205 "void as template argument to opt-out from poison semantics.");
// Poison check, compiled in only when PoisonAttr is not void. Plain `isa` is
// used here, after the null test above.
206 if constexpr (!std::is_void_v<PoisonAttr>) {
207 if (isa<PoisonAttr>(operands[0]))
// Case 1: the operand is an AttrElementT - compute on its value.
211 if (isa<AttrElementT>(operands[0])) {
212 auto op = cast<AttrElementT>(operands[0]);
214 auto res = calculate(op.getValue());
// Case 2: the operand is a splat - compute once on the splat value.
219 if (isa<SplatElementsAttr>(operands[0])) {
222 auto op = cast<SplatElementsAttr>(operands[0]);
224 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
228 }
// Case 3: any other ElementsAttr - compute element by element.
else if (isa<ElementsAttr>(operands[0])) {
231 auto op = cast<ElementsAttr>(operands[0]);
// NOTE(review): the failure check on `maybeOpIt` is not visible in this
// listing before it is dereferenced - confirm it exists in the real header.
233 auto maybeOpIt = op.try_value_begin<ElementValueT>();
236 auto opIt = *maybeOpIt;
// Room for one result per element (the declaration of `elementResults` is
// not shown in this listing).
238 elementResults.reserve(op.getNumElements());
239 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
240 auto elementResult = calculate(*opIt);
// NOTE(review): `elementResult` is dereferenced here; any emptiness check
// between the call and this line is not visible in this listing - confirm.
243 elementResults.push_back(*elementResult);
/// Convenience wrapper for one-operand calculations that always produce a
/// value: passes `operands` and a lambda returning
/// std::optional<ElementValueT> to constFoldUnaryOpConditional.
/// `calculate` defaults to function_ref<ElementValueT(ElementValueT)>.
/// NOTE(review): the declaration line with the function name and the lambda
/// body are not present in this listing; the summary entries at the end of the
/// listing suggest `constFoldUnaryOp` - confirm.
250 template <
class AttrElementT,
251 class ElementValueT =
typename AttrElementT::ValueType,
252 class PoisonAttr = ub::PoisonAttr,
253 class CalculationT =
function_ref<ElementValueT(ElementValueT)>>
255 CalculationT &&calculate) {
256 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
257 operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
/// Folds a one-operand conversion over a constant attribute: `calculate` is
/// applied to each source element value together with a `castStatus` flag.
/// NOTE(review): the `template <` opener and the declaration line with the
/// function name and the `operands` / `resType` parameters are not present in
/// this listing; the summary entries at the end of the listing suggest
/// `constFoldCastOp` - confirm.
///
/// Template parameters visible here:
///  - AttrElementT / ElementValueT:             source attribute and value types.
///  - TargetAttrElementT / TargetElementValueT: target attribute and value types.
///  - PoisonAttr:   defaults to ub::PoisonAttr; `void` compiles the poison
///                  check out.
///  - CalculationT: defaults to
///                  function_ref<TargetElementValueT(ElementValueT, bool)>
///                  as written in this listing.
263 class AttrElementT,
class TargetAttrElementT,
264 class ElementValueT =
typename AttrElementT::ValueType,
265 class TargetElementValueT =
typename TargetAttrElementT::ValueType,
266 class PoisonAttr = ub::PoisonAttr,
267 class CalculationT =
function_ref<TargetElementValueT(ElementValueT,
bool)>>
269 CalculationT &&calculate) {
// The sole operand is fetched through llvm::getSingleElement and tested for
// being null (guarded statement not shown in this listing).
270 if (!llvm::getSingleElement(operands))
// Compile-time requirement (its opening line is missing from this listing):
// PoisonAttr must be either `void` or a complete type.
274 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
275 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
276 "void as template argument to opt-out from poison semantics.");
// Poison check, compiled in only when PoisonAttr is not void.
277 if constexpr (!std::is_void_v<PoisonAttr>) {
278 if (isa<PoisonAttr>(operands[0]))
// Case 1: the operand is an AttrElementT - convert its value. `castStatus`
// starts out true and is handed to `calculate`; how it is consulted afterwards
// is not visible in this listing.
282 if (isa<AttrElementT>(operands[0])) {
283 auto op = cast<AttrElementT>(operands[0]);
284 bool castStatus =
true;
285 auto res = calculate(op.getValue(), castStatus);
// Case 2: the operand is a splat - convert the splat value once.
290 if (isa<SplatElementsAttr>(operands[0])) {
293 auto op = cast<SplatElementsAttr>(operands[0]);
294 bool castStatus =
true;
296 calculate(op.getSplatValue<ElementValueT>(), castStatus);
// The result type is treated as a ShapedType and its static-shape property is
// checked (guarded statement not shown).
299 auto shapedResType = cast<ShapedType>(resType);
300 if (!shapedResType.hasStaticShape())
// Case 3: the operand is some other ElementsAttr - convert element by element,
// sharing a single `castStatus` flag across all elements.
304 if (
auto op = dyn_cast<ElementsAttr>(operands[0])) {
307 bool castStatus =
true;
// NOTE(review): the failure check on `maybeOpIt` is not visible in this
// listing before it is dereferenced - confirm it exists in the real header.
308 auto maybeOpIt = op.try_value_begin<ElementValueT>();
311 auto opIt = *maybeOpIt;
// Room for one result per element (the declaration of `elementResults` is
// not shown in this listing).
313 elementResults.reserve(op.getNumElements());
314 for (
size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
315 auto elt = calculate(*opIt, castStatus);
318 elementResults.push_back(elt);
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, CalculationT &&calculate)
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, CalculationT &&calculate)
Performs constant folding `calculate` with element-wise behavior on the single attribute in `operands` and ...
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)