MLIR 23.0.0git
MathOps.cpp
Go to the documentation of this file.
1//===- MathOps.cpp - MLIR operations for math implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Builders.h"
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::math;
18
19//===----------------------------------------------------------------------===//
20// Common helpers
21//===----------------------------------------------------------------------===//
22
23/// Return the type of the same shape (scalar, vector or tensor) containing i1.
24static Type getI1SameShape(Type type) {
25 auto i1Type = IntegerType::get(type.getContext(), 1);
26 if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
30 return i1Type;
31}
32
33//===----------------------------------------------------------------------===//
34// TableGen'd op method definitions
35//===----------------------------------------------------------------------===//
36
37#define GET_OP_CLASSES
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
39
40//===----------------------------------------------------------------------===//
41// AbsFOp folder
42//===----------------------------------------------------------------------===//
43
44OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
45 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46 [](const APFloat &a) { return abs(a); });
47}
48
49//===----------------------------------------------------------------------===//
50// AbsIOp folder
51//===----------------------------------------------------------------------===//
52
53OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
54 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55 [](const APInt &a) { return a.abs(); });
56}
57
58//===----------------------------------------------------------------------===//
59// AcosOp folder
60//===----------------------------------------------------------------------===//
61
62OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
64 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
65 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
66 case APFloat::Semantics::S_IEEEdouble:
67 return APFloat(acos(a.convertToDouble()));
68 case APFloat::Semantics::S_IEEEsingle:
69 return APFloat(acosf(a.convertToFloat()));
70 default:
71 return {};
72 }
73 });
74}
75
76//===----------------------------------------------------------------------===//
77// AcoshOp folder
78//===----------------------------------------------------------------------===//
79
80OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
82 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
83 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
84 case APFloat::Semantics::S_IEEEdouble:
85 return APFloat(acosh(a.convertToDouble()));
86 case APFloat::Semantics::S_IEEEsingle:
87 return APFloat(acoshf(a.convertToFloat()));
88 default:
89 return {};
90 }
91 });
92}
93
94//===----------------------------------------------------------------------===//
95// AsinOp folder
96//===----------------------------------------------------------------------===//
97
98OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
100 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
101 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
102 case APFloat::Semantics::S_IEEEdouble:
103 return APFloat(asin(a.convertToDouble()));
104 case APFloat::Semantics::S_IEEEsingle:
105 return APFloat(asinf(a.convertToFloat()));
106 default:
107 return {};
108 }
109 });
110}
111
112//===----------------------------------------------------------------------===//
113// AsinhOp folder
114//===----------------------------------------------------------------------===//
115
116OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
118 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
119 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
120 case APFloat::Semantics::S_IEEEdouble:
121 return APFloat(asinh(a.convertToDouble()));
122 case APFloat::Semantics::S_IEEEsingle:
123 return APFloat(asinhf(a.convertToFloat()));
124 default:
125 return {};
126 }
127 });
128}
129
130//===----------------------------------------------------------------------===//
131// AtanOp folder
132//===----------------------------------------------------------------------===//
133
134OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
136 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
137 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
138 case APFloat::Semantics::S_IEEEdouble:
139 return APFloat(atan(a.convertToDouble()));
140 case APFloat::Semantics::S_IEEEsingle:
141 return APFloat(atanf(a.convertToFloat()));
142 default:
143 return {};
144 }
145 });
146}
147
148//===----------------------------------------------------------------------===//
149// AtanhOp folder
150//===----------------------------------------------------------------------===//
151
152OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
154 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
155 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
156 case APFloat::Semantics::S_IEEEdouble:
157 return APFloat(atanh(a.convertToDouble()));
158 case APFloat::Semantics::S_IEEEsingle:
159 return APFloat(atanhf(a.convertToFloat()));
160 default:
161 return {};
162 }
163 });
164}
165
166//===----------------------------------------------------------------------===//
167// Atan2Op folder
168//===----------------------------------------------------------------------===//
169
170OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
172 adaptor.getOperands(),
173 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
176
177 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
178 case APFloat::Semantics::S_IEEEdouble:
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
180 case APFloat::Semantics::S_IEEEsingle:
181 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
182 default:
183 return {};
184 }
185 });
186}
187
188//===----------------------------------------------------------------------===//
189// CbrtOp folder
190//===----------------------------------------------------------------------===//
191
192OpFoldResult math::CbrtOp::fold(FoldAdaptor adaptor) {
194 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
195 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
196 case APFloat::Semantics::S_IEEEdouble:
197 return APFloat(cbrt(a.convertToDouble()));
198 case APFloat::Semantics::S_IEEEsingle:
199 return APFloat(cbrtf(a.convertToFloat()));
200 default:
201 return {};
202 }
203 });
204}
205
206//===----------------------------------------------------------------------===//
207// CeilOp folder
208//===----------------------------------------------------------------------===//
209
210OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
212 adaptor.getOperands(), [](const APFloat &a) {
213 APFloat result(a);
214 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
215 return result;
216 });
217}
218
219//===----------------------------------------------------------------------===//
220// CopySignOp folder
221//===----------------------------------------------------------------------===//
222
223OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
224 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
225 [](const APFloat &a, const APFloat &b) {
226 APFloat result(a);
227 result.copySign(b);
228 return result;
229 });
230}
231
232//===----------------------------------------------------------------------===//
233// CosOp folder
234//===----------------------------------------------------------------------===//
235
236OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
238 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
239 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
240 case APFloat::Semantics::S_IEEEdouble:
241 return APFloat(cos(a.convertToDouble()));
242 case APFloat::Semantics::S_IEEEsingle:
243 return APFloat(cosf(a.convertToFloat()));
244 default:
245 return {};
246 }
247 });
248}
249
250//===----------------------------------------------------------------------===//
251// CoshOp folder
252//===----------------------------------------------------------------------===//
253
254OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
256 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
257 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
258 case APFloat::Semantics::S_IEEEdouble:
259 return APFloat(cosh(a.convertToDouble()));
260 case APFloat::Semantics::S_IEEEsingle:
261 return APFloat(coshf(a.convertToFloat()));
262 default:
263 return {};
264 }
265 });
266}
267
268//===----------------------------------------------------------------------===//
269// SinOp folder
270//===----------------------------------------------------------------------===//
271
272OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
274 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
275 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
276 case APFloat::Semantics::S_IEEEdouble:
277 return APFloat(sin(a.convertToDouble()));
278 case APFloat::Semantics::S_IEEEsingle:
279 return APFloat(sinf(a.convertToFloat()));
280 default:
281 return {};
282 }
283 });
284}
285
286//===----------------------------------------------------------------------===//
287// SinhOp folder
288//===----------------------------------------------------------------------===//
289
290OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
292 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
293 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
294 case APFloat::Semantics::S_IEEEdouble:
295 return APFloat(sinh(a.convertToDouble()));
296 case APFloat::Semantics::S_IEEEsingle:
297 return APFloat(sinhf(a.convertToFloat()));
298 default:
299 return {};
300 }
301 });
302}
303
304//===----------------------------------------------------------------------===//
305// SinCosOp
306//===----------------------------------------------------------------------===//
307
308std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
309 if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
310 return llvm::to_vector<4>(vt.getShape());
311 return std::nullopt;
312}
313
314LogicalResult math::SincosOp::fold(FoldAdaptor adaptor,
316 auto foldSincos = [](const APFloat &a, double (*fnDouble)(double),
317 float (*fnFloat)(float)) -> std::optional<APFloat> {
318 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
319 case APFloat::Semantics::S_IEEEdouble:
320 return APFloat(fnDouble(a.convertToDouble()));
321 case APFloat::Semantics::S_IEEEsingle:
322 return APFloat(fnFloat(a.convertToFloat()));
323 default:
324 return {};
325 }
326 };
327
329 adaptor.getOperands(),
330 [&](const APFloat &a) { return foldSincos(a, sin, sinf); });
332 adaptor.getOperands(),
333 [&](const APFloat &a) { return foldSincos(a, cos, cosf); });
334
335 if (sinRes && cosRes) {
336 result.push_back(sinRes);
337 result.push_back(cosRes);
338 return success();
339 }
340 return failure();
341}
342
343//===----------------------------------------------------------------------===//
344// CountLeadingZerosOp folder
345//===----------------------------------------------------------------------===//
346
347OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
349 adaptor.getOperands(),
350 [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
351}
352
353//===----------------------------------------------------------------------===//
354// CountTrailingZerosOp folder
355//===----------------------------------------------------------------------===//
356
357OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
359 adaptor.getOperands(),
360 [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
361}
362
363//===----------------------------------------------------------------------===//
364// CtPopOp folder
365//===----------------------------------------------------------------------===//
366
367OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
369 adaptor.getOperands(),
370 [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
371}
372
373//===----------------------------------------------------------------------===//
374// ErfOp folder
375//===----------------------------------------------------------------------===//
376
377OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
379 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
380 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
381 case APFloat::Semantics::S_IEEEdouble:
382 return APFloat(erf(a.convertToDouble()));
383 case APFloat::Semantics::S_IEEEsingle:
384 return APFloat(erff(a.convertToFloat()));
385 default:
386 return {};
387 }
388 });
389}
390
391//===----------------------------------------------------------------------===//
392// ErfcOp folder
393//===----------------------------------------------------------------------===//
394
395OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
397 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
398 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
399 case APFloat::Semantics::S_IEEEdouble:
400 return APFloat(erfc(a.convertToDouble()));
401 case APFloat::Semantics::S_IEEEsingle:
402 return APFloat(erfcf(a.convertToFloat()));
403 default:
404 return {};
405 }
406 });
407}
408
409//===----------------------------------------------------------------------===//
410// IPowIOp folder
411//===----------------------------------------------------------------------===//
412
413OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
415 adaptor.getOperands(),
416 [](const APInt &base, const APInt &power) -> std::optional<APInt> {
417 unsigned width = base.getBitWidth();
418 auto zeroValue = APInt::getZero(width);
419 // i1 folding is ambiguous with signed semantics, don't fold.
420 if (width == 1)
421 return {};
422 APInt oneValue{width, 1ULL, /*isSigned=*/true};
423 APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
424
425 if (power.isZero())
426 return oneValue;
427
428 if (power.isNegative()) {
429 // Leave 0 raised to negative power not folded.
430 if (base.isZero())
431 return {};
432 if (base.isOne())
433 return oneValue;
434 // If abs(base) > 1, then the result is zero.
435 if (base.ne(minusOneValue))
436 return zeroValue;
437 // base == -1:
438 // -1: power is odd
439 // 1: power is even
440 if (power[0] == 1)
441 return minusOneValue;
442
443 return oneValue;
444 }
445
446 // power is positive.
447 APInt result = oneValue;
448 APInt curBase = base;
449 APInt curPower = power;
450 while (true) {
451 if (curPower[0] == 1)
452 result *= curBase;
453 curPower.lshrInPlace(1);
454 if (curPower.isZero())
455 return result;
456 curBase *= curBase;
457 }
458 });
459
460 return Attribute();
461}
462
463//===----------------------------------------------------------------------===//
464// LogOp folder
465//===----------------------------------------------------------------------===//
466
467OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
469 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
470 if (a.isNegative())
471 return {};
472
473 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
474 case APFloat::Semantics::S_IEEEdouble:
475 return APFloat(log(a.convertToDouble()));
476 case APFloat::Semantics::S_IEEEsingle:
477 return APFloat(logf(a.convertToFloat()));
478 default:
479 return {};
480 }
481 });
482}
483
484//===----------------------------------------------------------------------===//
485// Log2Op folder
486//===----------------------------------------------------------------------===//
487
488OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
490 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
491 if (a.isNegative())
492 return {};
493
494 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
495 case APFloat::Semantics::S_IEEEdouble:
496 return APFloat(log2(a.convertToDouble()));
497 case APFloat::Semantics::S_IEEEsingle:
498 return APFloat(log2f(a.convertToFloat()));
499 default:
500 return {};
501 }
502 });
503}
504
505//===----------------------------------------------------------------------===//
506// Log10Op folder
507//===----------------------------------------------------------------------===//
508
509OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
511 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
512 if (a.isNegative())
513 return {};
514
515 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
516 case APFloat::Semantics::S_IEEEdouble:
517 return APFloat(log10(a.convertToDouble()));
518 case APFloat::Semantics::S_IEEEsingle:
519 return APFloat(log10f(a.convertToFloat()));
520 default:
521 return {};
522 }
523 });
524}
525
526//===----------------------------------------------------------------------===//
527// Log1pOp folder
528//===----------------------------------------------------------------------===//
529
530OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
532 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
533 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
534 case APFloat::Semantics::S_IEEEdouble:
535 if ((a + APFloat(1.0)).isNegative())
536 return {};
537 return APFloat(log1p(a.convertToDouble()));
538 case APFloat::Semantics::S_IEEEsingle:
539 if ((a + APFloat(1.0f)).isNegative())
540 return {};
541 return APFloat(log1pf(a.convertToFloat()));
542 default:
543 return {};
544 }
545 });
546}
547
548//===----------------------------------------------------------------------===//
549// PowFOp folder
550//===----------------------------------------------------------------------===//
551
552OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
554 adaptor.getOperands(),
555 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
556 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
557 case APFloat::Semantics::S_IEEEdouble:
558 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
559 case APFloat::Semantics::S_IEEEsingle:
560 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
561 default:
562 return {};
563 }
564 });
565}
566
567//===----------------------------------------------------------------------===//
568// RsqrtOp folder
569//===----------------------------------------------------------------------===//
570
571OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
573 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
574 if (a.isNegative())
575 return {};
576
577 APFloat one(a.getSemantics(), 1);
578 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
579 case APFloat::Semantics::S_IEEEdouble:
580 return one / APFloat(sqrt(a.convertToDouble()));
581 case APFloat::Semantics::S_IEEEsingle:
582 return one / APFloat(sqrtf(a.convertToFloat()));
583 default:
584 return {};
585 }
586 });
587}
588
589//===----------------------------------------------------------------------===//
590// SqrtOp folder
591//===----------------------------------------------------------------------===//
592
593OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
595 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
596 if (a.isNegative())
597 return {};
598
599 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
600 case APFloat::Semantics::S_IEEEdouble:
601 return APFloat(sqrt(a.convertToDouble()));
602 case APFloat::Semantics::S_IEEEsingle:
603 return APFloat(sqrtf(a.convertToFloat()));
604 default:
605 return {};
606 }
607 });
608}
609
610//===----------------------------------------------------------------------===//
611// ExpOp folder
612//===----------------------------------------------------------------------===//
613
614OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
616 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
617 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
618 case APFloat::Semantics::S_IEEEdouble:
619 return APFloat(exp(a.convertToDouble()));
620 case APFloat::Semantics::S_IEEEsingle:
621 return APFloat(expf(a.convertToFloat()));
622 default:
623 return {};
624 }
625 });
626}
627
628//===----------------------------------------------------------------------===//
629// Exp2Op folder
630//===----------------------------------------------------------------------===//
631
632OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
634 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
635 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
636 case APFloat::Semantics::S_IEEEdouble:
637 return APFloat(exp2(a.convertToDouble()));
638 case APFloat::Semantics::S_IEEEsingle:
639 return APFloat(exp2f(a.convertToFloat()));
640 default:
641 return {};
642 }
643 });
644}
645
646//===----------------------------------------------------------------------===//
647// ExpM1Op folder
648//===----------------------------------------------------------------------===//
649
650OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
652 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
653 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
654 case APFloat::Semantics::S_IEEEdouble:
655 return APFloat(expm1(a.convertToDouble()));
656 case APFloat::Semantics::S_IEEEsingle:
657 return APFloat(expm1f(a.convertToFloat()));
658 default:
659 return {};
660 }
661 });
662}
663
664//===----------------------------------------------------------------------===//
665// IsFiniteOp folder
666//===----------------------------------------------------------------------===//
667
668OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
669 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
670 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
671 }
672 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
674 cast<ShapedType>(getType()),
675 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
676 }
677 return {};
678}
679
680//===----------------------------------------------------------------------===//
681// IsInfOp folder
682//===----------------------------------------------------------------------===//
683
684OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
685 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
686 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
687 }
688 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
690 cast<ShapedType>(getType()),
691 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
692 }
693 return {};
694}
695
696//===----------------------------------------------------------------------===//
697// IsNaNOp folder
698//===----------------------------------------------------------------------===//
699
700OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
701 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
702 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
703 }
704 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
706 cast<ShapedType>(getType()),
707 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
708 }
709 return {};
710}
711
712//===----------------------------------------------------------------------===//
713// IsNormalOp folder
714//===----------------------------------------------------------------------===//
715
716OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
717 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
718 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
719 }
720 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
722 cast<ShapedType>(getType()),
723 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
724 }
725 return {};
726}
727
728//===----------------------------------------------------------------------===//
729// TanOp folder
730//===----------------------------------------------------------------------===//
731
732OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
734 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
735 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
736 case APFloat::Semantics::S_IEEEdouble:
737 return APFloat(tan(a.convertToDouble()));
738 case APFloat::Semantics::S_IEEEsingle:
739 return APFloat(tanf(a.convertToFloat()));
740 default:
741 return {};
742 }
743 });
744}
745
746//===----------------------------------------------------------------------===//
747// TanhOp folder
748//===----------------------------------------------------------------------===//
749
750OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
752 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
753 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
754 case APFloat::Semantics::S_IEEEdouble:
755 return APFloat(tanh(a.convertToDouble()));
756 case APFloat::Semantics::S_IEEEsingle:
757 return APFloat(tanhf(a.convertToFloat()));
758 default:
759 return {};
760 }
761 });
762}
763
764//===----------------------------------------------------------------------===//
765// RoundEvenOp folder
766//===----------------------------------------------------------------------===//
767
768OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
770 adaptor.getOperands(), [](const APFloat &a) {
771 APFloat result(a);
772 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
773 return result;
774 });
775}
776
777//===----------------------------------------------------------------------===//
778// FloorOp folder
779//===----------------------------------------------------------------------===//
780
781OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
783 adaptor.getOperands(), [](const APFloat &a) {
784 APFloat result(a);
785 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
786 return result;
787 });
788}
789
790//===----------------------------------------------------------------------===//
791// RoundOp folder
792//===----------------------------------------------------------------------===//
793
794OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
796 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
797 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
798 case APFloat::Semantics::S_IEEEdouble:
799 return APFloat(round(a.convertToDouble()));
800 case APFloat::Semantics::S_IEEEsingle:
801 return APFloat(roundf(a.convertToFloat()));
802 default:
803 return {};
804 }
805 });
806}
807
808//===----------------------------------------------------------------------===//
809// TruncOp folder
810//===----------------------------------------------------------------------===//
811
812OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
814 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
815 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
816 case APFloat::Semantics::S_IEEEdouble:
817 return APFloat(trunc(a.convertToDouble()));
818 case APFloat::Semantics::S_IEEEsingle:
819 return APFloat(truncf(a.convertToFloat()));
820 default:
821 return {};
822 }
823 });
824}
825
826//===----------------------------------------------------------------------===//
827// FPowIOp folder
828//===----------------------------------------------------------------------===//
829
/// Constant-fold math.fpowi (float base raised to an integer exponent).
/// The integer exponent is first converted, as a signed value, into the base's
/// float semantics. The fold proceeds only when that conversion reports opOK;
/// otherwise std::nullopt is returned and the op is left unfolded.
/// Only IEEE double and single bases are evaluated, via host pow/powf.
/// NOTE(review): original line 831 is missing from this listing. It is the
/// constant-folding helper call that the following arguments belong to;
/// restore it from the upstream file.
830OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
832 adaptor.getOperands(),
833 [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
834 const llvm::fltSemantics &sem = base.getSemantics();
835 // Fold when the exponent is exactly representable in the
836 // floating-point type of the base.
837 APFloat fExp(sem);
838 if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
839 APFloat::rmNearestTiesToEven) !=
840 APFloat::opOK)
841 return {};
842
// Evaluate on the host only for IEEE double/single; other semantics bail.
843 switch (APFloat::SemanticsToEnum(sem)) {
844 case APFloat::Semantics::S_IEEEdouble:
845 return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
846 case APFloat::Semantics::S_IEEEsingle:
847 return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
848 default:
849 return {};
850 }
851 });
852}
853
854/// Materialize an integer or floating point constant.
855Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
856 Attribute value, Type type,
857 Location loc) {
858 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
859 return ub::PoisonOp::create(builder, loc, type, poison);
860
861 return arith::ConstantOp::materialize(builder, value, type, loc);
862}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition MathOps.cpp:24
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
Include the generated interface declarations.
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)