MLIR 23.0.0git
MathOps.cpp
Go to the documentation of this file.
1//===- MathOps.cpp - MLIR operations for math implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include <cmath>
#include <optional>
15
16using namespace mlir;
17using namespace mlir::math;
18
19//===----------------------------------------------------------------------===//
20// Common helpers
21//===----------------------------------------------------------------------===//
22
23/// Return the type of the same shape (scalar, vector or tensor) containing i1.
24static Type getI1SameShape(Type type) {
25 auto i1Type = IntegerType::get(type.getContext(), 1);
26 if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
30 return i1Type;
31}
32
33//===----------------------------------------------------------------------===//
34// TableGen'd op method definitions
35//===----------------------------------------------------------------------===//
36
37#define GET_OP_CLASSES
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
39
40//===----------------------------------------------------------------------===//
41// AbsFOp folder
42//===----------------------------------------------------------------------===//
43
44OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
45 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46 [](const APFloat &a) { return abs(a); });
47}
48
49//===----------------------------------------------------------------------===//
50// AbsIOp folder
51//===----------------------------------------------------------------------===//
52
53OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
54 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55 [](const APInt &a) { return a.abs(); });
56}
57
58//===----------------------------------------------------------------------===//
59// AcosOp folder
60//===----------------------------------------------------------------------===//
61
62OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
64 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
66 case 64:
67 return APFloat(acos(a.convertToDouble()));
68 case 32:
69 return APFloat(acosf(a.convertToFloat()));
70 default:
71 return {};
72 }
73 });
74}
75
76//===----------------------------------------------------------------------===//
77// AcoshOp folder
78//===----------------------------------------------------------------------===//
79
80OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
82 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
84 case 64:
85 return APFloat(acosh(a.convertToDouble()));
86 case 32:
87 return APFloat(acoshf(a.convertToFloat()));
88 default:
89 return {};
90 }
91 });
92}
93
94//===----------------------------------------------------------------------===//
95// AsinOp folder
96//===----------------------------------------------------------------------===//
97
98OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
100 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
102 case 64:
103 return APFloat(asin(a.convertToDouble()));
104 case 32:
105 return APFloat(asinf(a.convertToFloat()));
106 default:
107 return {};
108 }
109 });
110}
111
112//===----------------------------------------------------------------------===//
113// AsinhOp folder
114//===----------------------------------------------------------------------===//
115
116OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
118 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
120 case 64:
121 return APFloat(asinh(a.convertToDouble()));
122 case 32:
123 return APFloat(asinhf(a.convertToFloat()));
124 default:
125 return {};
126 }
127 });
128}
129
130//===----------------------------------------------------------------------===//
131// AtanOp folder
132//===----------------------------------------------------------------------===//
133
134OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
136 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
138 case 64:
139 return APFloat(atan(a.convertToDouble()));
140 case 32:
141 return APFloat(atanf(a.convertToFloat()));
142 default:
143 return {};
144 }
145 });
146}
147
148//===----------------------------------------------------------------------===//
149// AtanhOp folder
150//===----------------------------------------------------------------------===//
151
152OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
154 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
156 case 64:
157 return APFloat(atanh(a.convertToDouble()));
158 case 32:
159 return APFloat(atanhf(a.convertToFloat()));
160 default:
161 return {};
162 }
163 });
164}
165
166//===----------------------------------------------------------------------===//
167// Atan2Op folder
168//===----------------------------------------------------------------------===//
169
170OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
172 adaptor.getOperands(),
173 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
176
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
180
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
184
185 return {};
186 });
187}
188
189//===----------------------------------------------------------------------===//
190// CeilOp folder
191//===----------------------------------------------------------------------===//
192
193OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
195 adaptor.getOperands(), [](const APFloat &a) {
196 APFloat result(a);
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
198 return result;
199 });
200}
201
202//===----------------------------------------------------------------------===//
203// CopySignOp folder
204//===----------------------------------------------------------------------===//
205
206OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208 [](const APFloat &a, const APFloat &b) {
209 APFloat result(a);
210 result.copySign(b);
211 return result;
212 });
213}
214
215//===----------------------------------------------------------------------===//
216// CosOp folder
217//===----------------------------------------------------------------------===//
218
219OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
221 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
223 case 64:
224 return APFloat(cos(a.convertToDouble()));
225 case 32:
226 return APFloat(cosf(a.convertToFloat()));
227 default:
228 return {};
229 }
230 });
231}
232
233//===----------------------------------------------------------------------===//
234// CoshOp folder
235//===----------------------------------------------------------------------===//
236
237OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
239 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
241 case 64:
242 return APFloat(cosh(a.convertToDouble()));
243 case 32:
244 return APFloat(coshf(a.convertToFloat()));
245 default:
246 return {};
247 }
248 });
249}
250
251//===----------------------------------------------------------------------===//
252// SinOp folder
253//===----------------------------------------------------------------------===//
254
255OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
257 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
259 case 64:
260 return APFloat(sin(a.convertToDouble()));
261 case 32:
262 return APFloat(sinf(a.convertToFloat()));
263 default:
264 return {};
265 }
266 });
267}
268
269//===----------------------------------------------------------------------===//
270// SinhOp folder
271//===----------------------------------------------------------------------===//
272
273OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
275 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
277 case 64:
278 return APFloat(sinh(a.convertToDouble()));
279 case 32:
280 return APFloat(sinhf(a.convertToFloat()));
281 default:
282 return {};
283 }
284 });
285}
286
287//===----------------------------------------------------------------------===//
288// SinCosOp getShapeForUnroll
289//===----------------------------------------------------------------------===//
290
291std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292 if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
293 return llvm::to_vector<4>(vt.getShape());
294 return std::nullopt;
295}
296
297//===----------------------------------------------------------------------===//
298// CountLeadingZerosOp folder
299//===----------------------------------------------------------------------===//
300
301OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
303 adaptor.getOperands(),
304 [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
305}
306
307//===----------------------------------------------------------------------===//
308// CountTrailingZerosOp folder
309//===----------------------------------------------------------------------===//
310
311OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
313 adaptor.getOperands(),
314 [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
315}
316
317//===----------------------------------------------------------------------===//
318// CtPopOp folder
319//===----------------------------------------------------------------------===//
320
321OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
323 adaptor.getOperands(),
324 [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
325}
326
327//===----------------------------------------------------------------------===//
328// ErfOp folder
329//===----------------------------------------------------------------------===//
330
331OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
333 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
334 switch (a.getSizeInBits(a.getSemantics())) {
335 case 64:
336 return APFloat(erf(a.convertToDouble()));
337 case 32:
338 return APFloat(erff(a.convertToFloat()));
339 default:
340 return {};
341 }
342 });
343}
344
345//===----------------------------------------------------------------------===//
346// ErfcOp folder
347//===----------------------------------------------------------------------===//
348
349OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
351 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
352 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353 case APFloat::Semantics::S_IEEEdouble:
354 return APFloat(erfc(a.convertToDouble()));
355 case APFloat::Semantics::S_IEEEsingle:
356 return APFloat(erfcf(a.convertToFloat()));
357 default:
358 return {};
359 }
360 });
361}
362
363//===----------------------------------------------------------------------===//
364// IPowIOp folder
365//===----------------------------------------------------------------------===//
366
367OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
369 adaptor.getOperands(),
370 [](const APInt &base, const APInt &power) -> std::optional<APInt> {
371 unsigned width = base.getBitWidth();
372 auto zeroValue = APInt::getZero(width);
373 // i1 folding is ambiguous with signed semantics, don't fold.
374 if (width == 1)
375 return {};
376 APInt oneValue{width, 1ULL, /*isSigned=*/true};
377 APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
378
379 if (power.isZero())
380 return oneValue;
381
382 if (power.isNegative()) {
383 // Leave 0 raised to negative power not folded.
384 if (base.isZero())
385 return {};
386 if (base.isOne())
387 return oneValue;
388 // If abs(base) > 1, then the result is zero.
389 if (base.ne(minusOneValue))
390 return zeroValue;
391 // base == -1:
392 // -1: power is odd
393 // 1: power is even
394 if (power[0] == 1)
395 return minusOneValue;
396
397 return oneValue;
398 }
399
400 // power is positive.
401 APInt result = oneValue;
402 APInt curBase = base;
403 APInt curPower = power;
404 while (true) {
405 if (curPower[0] == 1)
406 result *= curBase;
407 curPower.lshrInPlace(1);
408 if (curPower.isZero())
409 return result;
410 curBase *= curBase;
411 }
412 });
413
414 return Attribute();
415}
416
417//===----------------------------------------------------------------------===//
418// LogOp folder
419//===----------------------------------------------------------------------===//
420
421OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
423 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
424 if (a.isNegative())
425 return {};
426
427 if (a.getSizeInBits(a.getSemantics()) == 64)
428 return APFloat(log(a.convertToDouble()));
429
430 if (a.getSizeInBits(a.getSemantics()) == 32)
431 return APFloat(logf(a.convertToFloat()));
432
433 return {};
434 });
435}
436
437//===----------------------------------------------------------------------===//
438// Log2Op folder
439//===----------------------------------------------------------------------===//
440
441OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
443 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
444 if (a.isNegative())
445 return {};
446
447 if (a.getSizeInBits(a.getSemantics()) == 64)
448 return APFloat(log2(a.convertToDouble()));
449
450 if (a.getSizeInBits(a.getSemantics()) == 32)
451 return APFloat(log2f(a.convertToFloat()));
452
453 return {};
454 });
455}
456
457//===----------------------------------------------------------------------===//
458// Log10Op folder
459//===----------------------------------------------------------------------===//
460
461OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
463 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
464 if (a.isNegative())
465 return {};
466
467 switch (a.getSizeInBits(a.getSemantics())) {
468 case 64:
469 return APFloat(log10(a.convertToDouble()));
470 case 32:
471 return APFloat(log10f(a.convertToFloat()));
472 default:
473 return {};
474 }
475 });
476}
477
478//===----------------------------------------------------------------------===//
479// Log1pOp folder
480//===----------------------------------------------------------------------===//
481
482OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
484 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
485 switch (a.getSizeInBits(a.getSemantics())) {
486 case 64:
487 if ((a + APFloat(1.0)).isNegative())
488 return {};
489 return APFloat(log1p(a.convertToDouble()));
490 case 32:
491 if ((a + APFloat(1.0f)).isNegative())
492 return {};
493 return APFloat(log1pf(a.convertToFloat()));
494 default:
495 return {};
496 }
497 });
498}
499
500//===----------------------------------------------------------------------===//
501// PowFOp folder
502//===----------------------------------------------------------------------===//
503
504OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
506 adaptor.getOperands(),
507 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
508 if (a.getSizeInBits(a.getSemantics()) == 64 &&
509 b.getSizeInBits(b.getSemantics()) == 64)
510 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
511
512 if (a.getSizeInBits(a.getSemantics()) == 32 &&
513 b.getSizeInBits(b.getSemantics()) == 32)
514 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
515
516 return {};
517 });
518}
519
520//===----------------------------------------------------------------------===//
521// RsqrtOp folder
522//===----------------------------------------------------------------------===//
523
524OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
526 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
527 if (a.isNegative())
528 return {};
529
530 APFloat one(a.getSemantics(), 1);
531 switch (a.getSizeInBits(a.getSemantics())) {
532 case 64:
533 return one / APFloat(sqrt(a.convertToDouble()));
534 case 32:
535 return one / APFloat(sqrtf(a.convertToFloat()));
536 default:
537 return {};
538 }
539 });
540}
541
542//===----------------------------------------------------------------------===//
543// SqrtOp folder
544//===----------------------------------------------------------------------===//
545
546OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
548 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
549 if (a.isNegative())
550 return {};
551
552 switch (a.getSizeInBits(a.getSemantics())) {
553 case 64:
554 return APFloat(sqrt(a.convertToDouble()));
555 case 32:
556 return APFloat(sqrtf(a.convertToFloat()));
557 default:
558 return {};
559 }
560 });
561}
562
563//===----------------------------------------------------------------------===//
564// ExpOp folder
565//===----------------------------------------------------------------------===//
566
567OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
569 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
570 switch (a.getSizeInBits(a.getSemantics())) {
571 case 64:
572 return APFloat(exp(a.convertToDouble()));
573 case 32:
574 return APFloat(expf(a.convertToFloat()));
575 default:
576 return {};
577 }
578 });
579}
580
581//===----------------------------------------------------------------------===//
582// Exp2Op folder
583//===----------------------------------------------------------------------===//
584
585OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
587 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
588 switch (a.getSizeInBits(a.getSemantics())) {
589 case 64:
590 return APFloat(exp2(a.convertToDouble()));
591 case 32:
592 return APFloat(exp2f(a.convertToFloat()));
593 default:
594 return {};
595 }
596 });
597}
598
599//===----------------------------------------------------------------------===//
600// ExpM1Op folder
601//===----------------------------------------------------------------------===//
602
603OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
605 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
606 switch (a.getSizeInBits(a.getSemantics())) {
607 case 64:
608 return APFloat(expm1(a.convertToDouble()));
609 case 32:
610 return APFloat(expm1f(a.convertToFloat()));
611 default:
612 return {};
613 }
614 });
615}
616
617//===----------------------------------------------------------------------===//
618// IsFiniteOp folder
619//===----------------------------------------------------------------------===//
620
621OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
622 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
623 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
624 }
625 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
627 cast<ShapedType>(getType()),
628 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
629 }
630 return {};
631}
632
633//===----------------------------------------------------------------------===//
634// IsInfOp folder
635//===----------------------------------------------------------------------===//
636
637OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
638 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
639 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
640 }
641 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
643 cast<ShapedType>(getType()),
644 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
645 }
646 return {};
647}
648
649//===----------------------------------------------------------------------===//
650// IsNaNOp folder
651//===----------------------------------------------------------------------===//
652
653OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
654 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
655 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
656 }
657 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
659 cast<ShapedType>(getType()),
660 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
661 }
662 return {};
663}
664
665//===----------------------------------------------------------------------===//
666// IsNormalOp folder
667//===----------------------------------------------------------------------===//
668
669OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
670 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
671 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
672 }
673 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
675 cast<ShapedType>(getType()),
676 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
677 }
678 return {};
679}
680
681//===----------------------------------------------------------------------===//
682// TanOp folder
683//===----------------------------------------------------------------------===//
684
685OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
687 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
688 switch (a.getSizeInBits(a.getSemantics())) {
689 case 64:
690 return APFloat(tan(a.convertToDouble()));
691 case 32:
692 return APFloat(tanf(a.convertToFloat()));
693 default:
694 return {};
695 }
696 });
697}
698
699//===----------------------------------------------------------------------===//
700// TanhOp folder
701//===----------------------------------------------------------------------===//
702
703OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
705 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
706 switch (a.getSizeInBits(a.getSemantics())) {
707 case 64:
708 return APFloat(tanh(a.convertToDouble()));
709 case 32:
710 return APFloat(tanhf(a.convertToFloat()));
711 default:
712 return {};
713 }
714 });
715}
716
717//===----------------------------------------------------------------------===//
718// RoundEvenOp folder
719//===----------------------------------------------------------------------===//
720
721OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
723 adaptor.getOperands(), [](const APFloat &a) {
724 APFloat result(a);
725 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
726 return result;
727 });
728}
729
730//===----------------------------------------------------------------------===//
731// FloorOp folder
732//===----------------------------------------------------------------------===//
733
734OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
736 adaptor.getOperands(), [](const APFloat &a) {
737 APFloat result(a);
738 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
739 return result;
740 });
741}
742
743//===----------------------------------------------------------------------===//
744// RoundOp folder
745//===----------------------------------------------------------------------===//
746
747OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
749 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
750 switch (a.getSizeInBits(a.getSemantics())) {
751 case 64:
752 return APFloat(round(a.convertToDouble()));
753 case 32:
754 return APFloat(roundf(a.convertToFloat()));
755 default:
756 return {};
757 }
758 });
759}
760
761//===----------------------------------------------------------------------===//
762// TruncOp folder
763//===----------------------------------------------------------------------===//
764
765OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
767 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
768 switch (a.getSizeInBits(a.getSemantics())) {
769 case 64:
770 return APFloat(trunc(a.convertToDouble()));
771 case 32:
772 return APFloat(truncf(a.convertToFloat()));
773 default:
774 return {};
775 }
776 });
777}
778
779/// Materialize an integer or floating point constant.
780Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
781 Attribute value, Type type,
782 Location loc) {
783 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
784 return ub::PoisonOp::create(builder, loc, type, poison);
785
786 return arith::ConstantOp::materialize(builder, value, type, loc);
787}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition MathOps.cpp:24
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)