MLIR 23.0.0git
CommonFolders.h
Go to the documentation of this file.
1//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file declares various common operation folders. These folders
10// are intended to be used by dialects to support common folding behavior
11// without requiring each dialect to provide its own implementation.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_DIALECT_COMMONFOLDERS_H
16#define MLIR_DIALECT_COMMONFOLDERS_H
17
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
29
30namespace mlir {
namespace ub {
/// Forward declaration only: the folders below use `ub::PoisonAttr` as the
/// default poison attribute and `static_assert` that it is complete at the
/// point of instantiation, so this header itself does not need the UB dialect.
class PoisonAttr;
} // namespace ub
34/// Performs constant folding `calculate` with element-wise behavior on the two
35/// attributes in `operands` and returns the result if possible.
36/// Uses `resultType` for the type of the returned attribute.
37/// Optional PoisonAttr template argument allows to specify 'poison' attribute
38/// which will be directly propagated to result.
39template <class LAttrElementT, class RAttrElementT = LAttrElementT,
40 class LElementValueT = typename LAttrElementT::ValueType,
41 class RElementValueT = typename RAttrElementT::ValueType,
42 class PoisonAttr = ub::PoisonAttr,
43 class ResultAttrElementT = LAttrElementT,
44 class ResultElementValueT = typename ResultAttrElementT::ValueType,
46 LElementValueT, RElementValueT)>>
48 Type resultType,
49 CalculationT &&calculate) {
50 assert(operands.size() == 2 && "binary op takes two operands");
51 static_assert(
52 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
53 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
54 "void as template argument to opt-out from poison semantics.");
55 if constexpr (!std::is_void_v<PoisonAttr>) {
56 if (isa_and_nonnull<PoisonAttr>(operands[0]))
57 return operands[0];
58
59 if (isa_and_nonnull<PoisonAttr>(operands[1]))
60 return operands[1];
61 }
62
63 if (!resultType || !operands[0] || !operands[1])
64 return {};
65
66 if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
67 auto lhs = cast<LAttrElementT>(operands[0]);
68 auto rhs = cast<RAttrElementT>(operands[1]);
69 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
70 if (lhs.getType() != rhs.getType())
71 return {};
72
73 auto calRes = calculate(lhs.getValue(), rhs.getValue());
74
75 if (!calRes)
76 return {};
77
78 return ResultAttrElementT::get(resultType, *calRes);
79 }
80
81 if (isa<SplatElementsAttr>(operands[0]) &&
82 isa<SplatElementsAttr>(operands[1])) {
83 // Both operands are splats so we can avoid expanding the values out and
84 // just fold based on the splat value.
85 auto lhs = cast<SplatElementsAttr>(operands[0]);
86 auto rhs = cast<SplatElementsAttr>(operands[1]);
87 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
88 if (lhs.getType() != rhs.getType())
89 return {};
90
91 auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
92 rhs.getSplatValue<RElementValueT>());
93 if (!elementResult)
94 return {};
95
96 return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
97 }
98
99 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
100 // Operands are ElementsAttr-derived; perform an element-wise fold by
101 // expanding the values.
102 auto lhs = cast<ElementsAttr>(operands[0]);
103 auto rhs = cast<ElementsAttr>(operands[1]);
104 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
105 if (lhs.getType() != rhs.getType())
106 return {};
107
108 auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
109 auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
110 if (!maybeLhsIt || !maybeRhsIt)
111 return {};
112 auto lhsIt = *maybeLhsIt;
113 auto rhsIt = *maybeRhsIt;
115 elementResults.reserve(lhs.getNumElements());
116 for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
117 auto elementResult = calculate(*lhsIt, *rhsIt);
118 if (!elementResult)
119 return {};
120 elementResults.push_back(*elementResult);
121 }
122
123 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
124 }
125 return {};
126}
127
128/// Performs constant folding `calculate` with element-wise behavior on the two
129/// attributes in `operands` and returns the result if possible.
130/// Uses the operand element type for the element type of the returned
131/// attribute.
132/// Optional PoisonAttr template argument allows to specify 'poison' attribute
133/// which will be directly propagated to result.
134template <class LAttrElementT, class RAttrElementT = LAttrElementT,
135 class LElementValueT = typename LAttrElementT::ValueType,
136 class RElementValueT = typename RAttrElementT::ValueType,
137 class PoisonAttr = ub::PoisonAttr,
138 class ResultAttrElementT = LAttrElementT,
139 class ResultElementValueT = typename ResultAttrElementT::ValueType,
141 LElementValueT, RElementValueT)>>
143 CalculationT &&calculate) {
144 assert(operands.size() == 2 && "binary op takes two operands");
145 static_assert(
146 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
147 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
148 "void as template argument to opt-out from poison semantics.");
149 if constexpr (!std::is_void_v<PoisonAttr>) {
150 if (isa_and_nonnull<PoisonAttr>(operands[0]))
151 return operands[0];
152
153 if (isa_and_nonnull<PoisonAttr>(operands[1]))
154 return operands[1];
155 }
156
157 auto getAttrType = [](Attribute attr) -> Type {
158 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
159 return typed.getType();
160 return {};
161 };
162
163 Type lhsType = getAttrType(operands[0]);
164 Type rhsType = getAttrType(operands[1]);
165 if (!lhsType || !rhsType)
166 return {};
167 if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
168 if (lhsType != rhsType)
169 return {};
170
172 LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
173 ResultAttrElementT, ResultElementValueT, CalculationT>(
174 operands, lhsType, std::forward<CalculationT>(calculate));
175}
176
177template <class LAttrElementT, class RAttrElementT = LAttrElementT,
178 class LElementValueT = typename LAttrElementT::ValueType,
179 class RElementValueT = typename RAttrElementT::ValueType,
180 class PoisonAttr = void, //
181 class ResultAttrElementT = LAttrElementT,
182 class ResultElementValueT = typename ResultAttrElementT::ValueType,
183 class CalculationT =
184 function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
186 CalculationT &&calculate) {
187 return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
188 LElementValueT, RElementValueT,
189 PoisonAttr, ResultAttrElementT>(
190 operands, resultType,
191 [&](LElementValueT a, RElementValueT b)
192 -> std::optional<ResultElementValueT> { return calculate(a, b); });
193}
194
195template <class LAttrElementT, class RAttrElementT = LAttrElementT,
196 class LElementValueT = typename LAttrElementT::ValueType,
197 class RElementValueT = typename RAttrElementT::ValueType,
198 class PoisonAttr = ub::PoisonAttr,
199 class ResultAttrElementT = LAttrElementT,
200 class ResultElementValueT = typename ResultAttrElementT::ValueType,
201 class CalculationT =
202 function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
204 CalculationT &&calculate) {
205 return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
206 LElementValueT, RElementValueT,
207 PoisonAttr, ResultAttrElementT>(
208 operands,
209 [&](LElementValueT a, RElementValueT b)
210 -> std::optional<ResultElementValueT> { return calculate(a, b); });
211}
212
213/// Performs constant folding `calculate` with element-wise behavior on the one
214/// attributes in `operands` and returns the result if possible.
215/// Uses `resultType` for the type of the returned attribute.
216/// Optional PoisonAttr template argument allows to specify 'poison' attribute
217/// which will be directly propagated to result.
218template <class AttrElementT, //
219 class ElementValueT = typename AttrElementT::ValueType,
220 class PoisonAttr = ub::PoisonAttr,
221 class ResultAttrElementT = AttrElementT,
222 class ResultElementValueT = typename ResultAttrElementT::ValueType,
223 class CalculationT =
226 Type resultType,
227 CalculationT &&calculate) {
228 if (!resultType || !llvm::getSingleElement(operands))
229 return {};
230
231 static_assert(
232 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
233 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
234 "void as template argument to opt-out from poison semantics.");
235 if constexpr (!std::is_void_v<PoisonAttr>) {
236 if (isa<PoisonAttr>(operands[0]))
237 return operands[0];
238 }
239
240 if (isa<AttrElementT>(operands[0])) {
241 auto op = cast<AttrElementT>(operands[0]);
242
243 auto res = calculate(op.getValue());
244 if (!res)
245 return {};
246 return ResultAttrElementT::get(resultType, *res);
247 }
248 if (isa<SplatElementsAttr>(operands[0])) {
249 // Both operands are splats so we can avoid expanding the values out and
250 // just fold based on the splat value.
251 auto op = cast<SplatElementsAttr>(operands[0]);
252
253 auto elementResult = calculate(op.getSplatValue<ElementValueT>());
254 if (!elementResult)
255 return {};
256 return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
257 } else if (isa<ElementsAttr>(operands[0])) {
258 // Operands are ElementsAttr-derived; perform an element-wise fold by
259 // expanding the values.
260 auto op = cast<ElementsAttr>(operands[0]);
261
262 auto maybeOpIt = op.try_value_begin<ElementValueT>();
263 if (!maybeOpIt)
264 return {};
265 auto opIt = *maybeOpIt;
267 elementResults.reserve(op.getNumElements());
268 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
269 auto elementResult = calculate(*opIt);
270 if (!elementResult)
271 return {};
272 elementResults.push_back(*elementResult);
273 }
274 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
275 }
276 return {};
277}
278
279/// Performs constant folding `calculate` with element-wise behavior on the one
280/// attributes in `operands` and returns the result if possible.
281/// Uses the operand element type for the element type of the returned
282/// attribute.
283/// Optional PoisonAttr template argument allows to specify 'poison' attribute
284/// which will be directly propagated to result.
285template <class AttrElementT, //
286 class ElementValueT = typename AttrElementT::ValueType,
287 class PoisonAttr = ub::PoisonAttr,
288 class ResultAttrElementT = AttrElementT,
289 class ResultElementValueT = typename ResultAttrElementT::ValueType,
290 class CalculationT =
293 CalculationT &&calculate) {
294 if (!llvm::getSingleElement(operands))
295 return {};
296
297 static_assert(
298 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
299 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
300 "void as template argument to opt-out from poison semantics.");
301 if constexpr (!std::is_void_v<PoisonAttr>) {
302 if (isa<PoisonAttr>(operands[0]))
303 return operands[0];
304 }
305
306 auto getAttrType = [](Attribute attr) -> Type {
307 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
308 return typed.getType();
309 return {};
310 };
311
312 Type operandType = getAttrType(operands[0]);
313 if (!operandType)
314 return {};
315
316 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
317 ResultAttrElementT, ResultElementValueT,
318 CalculationT>(
319 operands, operandType, std::forward<CalculationT>(calculate));
320}
321
322template <class AttrElementT, //
323 class ElementValueT = typename AttrElementT::ValueType,
324 class PoisonAttr = ub::PoisonAttr,
325 class ResultAttrElementT = AttrElementT,
326 class ResultElementValueT = typename ResultAttrElementT::ValueType,
327 class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
329 CalculationT &&calculate) {
330 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
331 ResultAttrElementT>(
332 operands, resultType,
333 [&](ElementValueT a) -> std::optional<ResultElementValueT> {
334 return calculate(a);
335 });
336}
337
338template <class AttrElementT, //
339 class ElementValueT = typename AttrElementT::ValueType,
340 class PoisonAttr = ub::PoisonAttr,
341 class ResultAttrElementT = AttrElementT,
342 class ResultElementValueT = typename ResultAttrElementT::ValueType,
343 class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
345 CalculationT &&calculate) {
346 return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
347 ResultAttrElementT>(
348 operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
349 return calculate(a);
350 });
351}
352
353template <
354 class AttrElementT, class TargetAttrElementT,
355 class ElementValueT = typename AttrElementT::ValueType,
356 class TargetElementValueT = typename TargetAttrElementT::ValueType,
357 class PoisonAttr = ub::PoisonAttr,
358 class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
360 CalculationT &&calculate) {
361 if (!llvm::getSingleElement(operands))
362 return {};
363
364 static_assert(
365 std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
366 "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
367 "void as template argument to opt-out from poison semantics.");
368 if constexpr (!std::is_void_v<PoisonAttr>) {
369 if (isa<PoisonAttr>(operands[0]))
370 return operands[0];
371 }
372
373 if (isa<AttrElementT>(operands[0])) {
374 auto op = cast<AttrElementT>(operands[0]);
375 bool castStatus = true;
376 auto res = calculate(op.getValue(), castStatus);
377 if (!castStatus)
378 return {};
379 return TargetAttrElementT::get(resType, res);
380 }
381 if (isa<SplatElementsAttr>(operands[0])) {
382 // The operand is a splat so we can avoid expanding the values out and
383 // just fold based on the splat value.
384 auto op = cast<SplatElementsAttr>(operands[0]);
385 bool castStatus = true;
386 auto elementResult =
387 calculate(op.getSplatValue<ElementValueT>(), castStatus);
388 if (!castStatus)
389 return {};
390 auto shapedResType = cast<ShapedType>(resType);
391 if (!shapedResType.hasStaticShape())
392 return {};
393 return DenseElementsAttr::get(shapedResType, elementResult);
394 }
395 if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
396 // Operand is ElementsAttr-derived; perform an element-wise fold by
397 // expanding the value.
398 bool castStatus = true;
399 auto maybeOpIt = op.try_value_begin<ElementValueT>();
400 if (!maybeOpIt)
401 return {};
402 auto opIt = *maybeOpIt;
404 elementResults.reserve(op.getNumElements());
405 for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
406 auto elt = calculate(*opIt, castStatus);
407 if (!castStatus)
408 return {};
409 elementResults.push_back(elt);
410 }
411
412 return DenseElementsAttr::get(cast<ShapedType>(resType), elementResults);
413 }
414 return {};
415}
416} // namespace mlir
417
418#endif // MLIR_DIALECT_COMMONFOLDERS_H
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Attributes are known-constant values of operations.
Definition Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Include the generated interface declarations.
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147