MLIR 22.0.0git
CommonFolders.h
Go to the documentation of this file.
1//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file declares various common operation folders. These folders
10// are intended to be used by dialects to support common folding behavior
11// without requiring each dialect to provide its own implementation.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_DIALECT_COMMONFOLDERS_H
16#define MLIR_DIALECT_COMMONFOLDERS_H
17
18#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Types.h"
23#include "llvm/ADT/ArrayRef.h"
24#include "llvm/ADT/STLExtras.h"
25
26#include <cassert>
27#include <cstddef>
28#include <optional>
29
30namespace mlir {
31namespace ub {
32class PoisonAttr;
33}
34/// Performs constant folding `calculate` with element-wise behavior on the two
35/// attributes in `operands` and returns the result if possible.
36/// Uses `resultType` for the type of the returned attribute.
37/// Optional PoisonAttr template argument allows to specify 'poison' attribute
38/// which will be directly propagated to result.
39template <class AttrElementT, //
40 class ElementValueT = typename AttrElementT::ValueType,
41 class PoisonAttr = ub::PoisonAttr,
42 class ResultAttrElementT = AttrElementT,
43 class ResultElementValueT = typename ResultAttrElementT::ValueType,
44 class CalculationT = function_ref<
45 std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
47 Type resultType,
48 CalculationT &&calculate) {
49 assert(operands.size() == 2 && "binary op takes two operands");
50 static_assert(
51 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
52 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
53 "void as template argument to opt-out from poison semantics.");
54 if constexpr (!std::is_void_v<PoisonAttr>) {
55 if (isa_and_nonnull<PoisonAttr>(operands[0]))
56 return operands[0];
57
58 if (isa_and_nonnull<PoisonAttr>(operands[1]))
59 return operands[1];
60 }
61
62 if (!resultType || !operands[0] || !operands[1])
63 return {};
64
65 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
66 auto lhs = cast<AttrElementT>(operands[0]);
67 auto rhs = cast<AttrElementT>(operands[1]);
68 if (lhs.getType() != rhs.getType())
69 return {};
70
71 auto calRes = calculate(lhs.getValue(), rhs.getValue());
72
73 if (!calRes)
74 return {};
75
76 return ResultAttrElementT::get(resultType, *calRes);
77 }
78
79 if (isa<SplatElementsAttr>(operands[0]) &&
80 isa<SplatElementsAttr>(operands[1])) {
81 // Both operands are splats so we can avoid expanding the values out and
82 // just fold based on the splat value.
83 auto lhs = cast<SplatElementsAttr>(operands[0]);
84 auto rhs = cast<SplatElementsAttr>(operands[1]);
85 if (lhs.getType() != rhs.getType())
86 return {};
87
88 auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
89 rhs.getSplatValue<ElementValueT>());
90 if (!elementResult)
91 return {};
92
93 return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
94 }
95
96 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
97 // Operands are ElementsAttr-derived; perform an element-wise fold by
98 // expanding the values.
99 auto lhs = cast<ElementsAttr>(operands[0]);
100 auto rhs = cast<ElementsAttr>(operands[1]);
101 if (lhs.getType() != rhs.getType())
102 return {};
103
104 auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
105 auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
106 if (!maybeLhsIt || !maybeRhsIt)
107 return {};
108 auto lhsIt = *maybeLhsIt;
109 auto rhsIt = *maybeRhsIt;
111 elementResults.reserve(lhs.getNumElements());
112 for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
113 auto elementResult = calculate(*lhsIt, *rhsIt);
114 if (!elementResult)
115 return {};
116 elementResults.push_back(*elementResult);
117 }
118
119 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
120 }
121 return {};
122}
123
124/// Performs constant folding `calculate` with element-wise behavior on the two
125/// attributes in `operands` and returns the result if possible.
126/// Uses the operand element type for the element type of the returned
127/// attribute.
128/// Optional PoisonAttr template argument allows to specify 'poison' attribute
129/// which will be directly propagated to result.
130template <class AttrElementT, //
131 class ElementValueT = typename AttrElementT::ValueType,
132 class PoisonAttr = ub::PoisonAttr,
133 class ResultAttrElementT = AttrElementT,
134 class ResultElementValueT = typename ResultAttrElementT::ValueType,
135 class CalculationT = function_ref<
136 std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
138 CalculationT &&calculate) {
139 assert(operands.size() == 2 && "binary op takes two operands");
140 static_assert(
141 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
142 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
143 "void as template argument to opt-out from poison semantics.");
144 if constexpr (!std::is_void_v<PoisonAttr>) {
145 if (isa_and_nonnull<PoisonAttr>(operands[0]))
146 return operands[0];
147
148 if (isa_and_nonnull<PoisonAttr>(operands[1]))
149 return operands[1];
150 }
151
152 auto getAttrType = [](Attribute attr) -> Type {
153 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
154 return typed.getType();
155 return {};
156 };
157
158 Type lhsType = getAttrType(operands[0]);
159 Type rhsType = getAttrType(operands[1]);
160 if (!lhsType || !rhsType)
161 return {};
162 if (lhsType != rhsType)
163 return {};
164
165 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166 ResultAttrElementT, ResultElementValueT,
167 CalculationT>(
168 operands, lhsType, std::forward<CalculationT>(calculate));
169}
170
171template <class AttrElementT,
172 class ElementValueT = typename AttrElementT::ValueType,
173 class PoisonAttr = void, //
174 class ResultAttrElementT = AttrElementT,
175 class ResultElementValueT = typename ResultAttrElementT::ValueType,
176 class CalculationT =
177 function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
179 CalculationT &&calculate) {
180 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181 ResultAttrElementT>(
182 operands, resultType,
183 [&](ElementValueT a, ElementValueT b)
184 -> std::optional<ResultElementValueT> { return calculate(a, b); });
185}
186
187template <class AttrElementT, //
188 class ElementValueT = typename AttrElementT::ValueType,
189 class PoisonAttr = ub::PoisonAttr,
190 class ResultAttrElementT = AttrElementT,
191 class ResultElementValueT = typename ResultAttrElementT::ValueType,
192 class CalculationT =
193 function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
195 CalculationT &&calculate) {
196 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197 ResultAttrElementT>(
198 operands,
199 [&](ElementValueT a, ElementValueT b)
200 -> std::optional<ResultElementValueT> { return calculate(a, b); });
201}
202
203/// Performs constant folding `calculate` with element-wise behavior on the one
204/// attributes in `operands` and returns the result if possible.
205/// Uses `resultType` for the type of the returned attribute.
206/// Optional PoisonAttr template argument allows to specify 'poison' attribute
207/// which will be directly propagated to result.
208template <class AttrElementT, //
209 class ElementValueT = typename AttrElementT::ValueType,
210 class PoisonAttr = ub::PoisonAttr,
211 class ResultAttrElementT = AttrElementT,
212 class ResultElementValueT = typename ResultAttrElementT::ValueType,
213 class CalculationT =
216 Type resultType,
217 CalculationT &&calculate) {
218 if (!resultType || !llvm::getSingleElement(operands))
219 return {};
220
221 static_assert(
222 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
223 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
224 "void as template argument to opt-out from poison semantics.");
225 if constexpr (!std::is_void_v<PoisonAttr>) {
226 if (isa<PoisonAttr>(operands[0]))
227 return operands[0];
228 }
229
230 if (isa<AttrElementT>(operands[0])) {
231 auto op = cast<AttrElementT>(operands[0]);
232
233 auto res = calculate(op.getValue());
234 if (!res)
235 return {};
236 return ResultAttrElementT::get(resultType, *res);
237 }
238 if (isa<SplatElementsAttr>(operands[0])) {
239 // Both operands are splats so we can avoid expanding the values out and
240 // just fold based on the splat value.
241 auto op = cast<SplatElementsAttr>(operands[0]);
242
243 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
244 if (!elementResult)
245 return {};
246 return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
247 } else if (isa<ElementsAttr>(operands[0])) {
248 // Operands are ElementsAttr-derived; perform an element-wise fold by
249 // expanding the values.
250 auto op = cast<ElementsAttr>(operands[0]);
251
252 auto maybeOpIt = op.try_value_begin<ElementValueT>();
253 if (!maybeOpIt)
254 return {};
255 auto opIt = *maybeOpIt;
257 elementResults.reserve(op.getNumElements());
258 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
259 auto elementResult = calculate(*opIt);
260 if (!elementResult)
261 return {};
262 elementResults.push_back(*elementResult);
263 }
264 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
265 }
266 return {};
267}
268
269/// Performs constant folding `calculate` with element-wise behavior on the one
270/// attributes in `operands` and returns the result if possible.
271/// Uses the operand element type for the element type of the returned
272/// attribute.
273/// Optional PoisonAttr template argument allows to specify 'poison' attribute
274/// which will be directly propagated to result.
275template <class AttrElementT, //
276 class ElementValueT = typename AttrElementT::ValueType,
277 class PoisonAttr = ub::PoisonAttr,
278 class ResultAttrElementT = AttrElementT,
279 class ResultElementValueT = typename ResultAttrElementT::ValueType,
280 class CalculationT =
283 CalculationT &&calculate) {
284 if (!llvm::getSingleElement(operands))
285 return {};
286
287 static_assert(
288 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290 "void as template argument to opt-out from poison semantics.");
291 if constexpr (!std::is_void_v<PoisonAttr>) {
292 if (isa<PoisonAttr>(operands[0]))
293 return operands[0];
294 }
295
296 auto getAttrType = [](Attribute attr) -> Type {
297 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
298 return typed.getType();
299 return {};
300 };
301
302 Type operandType = getAttrType(operands[0]);
303 if (!operandType)
304 return {};
305
306 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307 ResultAttrElementT, ResultElementValueT,
308 CalculationT>(
309 operands, operandType, std::forward<CalculationT>(calculate));
310}
311
312template <class AttrElementT, //
313 class ElementValueT = typename AttrElementT::ValueType,
314 class PoisonAttr = ub::PoisonAttr,
315 class ResultAttrElementT = AttrElementT,
316 class ResultElementValueT = typename ResultAttrElementT::ValueType,
317 class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
319 CalculationT &&calculate) {
320 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321 ResultAttrElementT>(
322 operands, resultType,
323 [&](ElementValueT a) -> std::optional<ResultElementValueT> {
324 return calculate(a);
325 });
326}
327
328template <class AttrElementT, //
329 class ElementValueT = typename AttrElementT::ValueType,
330 class PoisonAttr = ub::PoisonAttr,
331 class ResultAttrElementT = AttrElementT,
332 class ResultElementValueT = typename ResultAttrElementT::ValueType,
333 class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
335 CalculationT &&calculate) {
336 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337 ResultAttrElementT>(
338 operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
339 return calculate(a);
340 });
341}
342
343template <
344 class AttrElementT, class TargetAttrElementT,
345 class ElementValueT = typename AttrElementT::ValueType,
346 class TargetElementValueT = typename TargetAttrElementT::ValueType,
347 class PoisonAttr = ub::PoisonAttr,
348 class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
350 CalculationT &&calculate) {
351 if (!llvm::getSingleElement(operands))
352 return {};
353
354 static_assert(
355 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
356 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
357 "void as template argument to opt-out from poison semantics.");
358 if constexpr (!std::is_void_v<PoisonAttr>) {
359 if (isa<PoisonAttr>(operands[0]))
360 return operands[0];
361 }
362
363 if (isa<AttrElementT>(operands[0])) {
364 auto op = cast<AttrElementT>(operands[0]);
365 bool castStatus = true;
366 auto res = calculate(op.getValue(), castStatus);
367 if (!castStatus)
368 return {};
369 return TargetAttrElementT::get(resType, res);
370 }
371 if (isa<SplatElementsAttr>(operands[0])) {
372 // The operand is a splat so we can avoid expanding the values out and
373 // just fold based on the splat value.
374 auto op = cast<SplatElementsAttr>(operands[0]);
375 bool castStatus = true;
376 auto elementResult =
377 calculate(op.getSplatValue<ElementValueT>(), castStatus);
378 if (!castStatus)
379 return {};
380 auto shapedResType = cast<ShapedType>(resType);
381 if (!shapedResType.hasStaticShape())
382 return {};
383 return DenseElementsAttr::get(shapedResType, elementResult);
384 }
385 if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
386 // Operand is ElementsAttr-derived; perform an element-wise fold by
387 // expanding the value.
388 bool castStatus = true;
389 auto maybeOpIt = op.try_value_begin<ElementValueT>();
390 if (!maybeOpIt)
391 return {};
392 auto opIt = *maybeOpIt;
394 elementResults.reserve(op.getNumElements());
395 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
396 auto elt = calculate(*opIt, castStatus);
397 if (!castStatus)
398 return {};
399 elementResults.push_back(elt);
400 }
401
402 return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
403 }
404 return {};
405}
406} // namespace mlir
407
408#endif // MLIR_DIALECT_COMMONFOLDERS_H
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Attributes are known-constant values of operations.
Definition Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Include the generated interface declarations.
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding `calculate` with element-wise behavior on the two attributes in `operands` and returns the result if possible.
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding `calculate` with element-wise behavior on the single attribute in `operands` and returns the result if possible.
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152