MLIR  21.0.0git
MathOps.cpp
Go to the documentation of this file.
1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Builders.h"
14 #include <optional>
15 
16 using namespace mlir;
17 using namespace mlir::math;
18 
19 //===----------------------------------------------------------------------===//
20 // Common helpers
21 //===----------------------------------------------------------------------===//
22 
23 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
24 static Type getI1SameShape(Type type) {
25  auto i1Type = IntegerType::get(type.getContext(), 1);
26  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27  return shapedType.cloneWith(std::nullopt, i1Type);
28  if (llvm::isa<UnrankedTensorType>(type))
29  return UnrankedTensorType::get(i1Type);
30  return i1Type;
31 }
32 
33 //===----------------------------------------------------------------------===//
34 // TableGen'd op method definitions
35 //===----------------------------------------------------------------------===//
36 
37 #define GET_OP_CLASSES
38 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
39 
40 //===----------------------------------------------------------------------===//
41 // AbsFOp folder
42 //===----------------------------------------------------------------------===//
43 
44 OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
45  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46  [](const APFloat &a) { return abs(a); });
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // AbsIOp folder
51 //===----------------------------------------------------------------------===//
52 
53 OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
54  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55  [](const APInt &a) { return a.abs(); });
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // AcosOp folder
60 //===----------------------------------------------------------------------===//
61 
62 OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
63  return constFoldUnaryOpConditional<FloatAttr>(
64  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
65  switch (a.getSizeInBits(a.getSemantics())) {
66  case 64:
67  return APFloat(acos(a.convertToDouble()));
68  case 32:
69  return APFloat(acosf(a.convertToFloat()));
70  default:
71  return {};
72  }
73  });
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // AcoshOp folder
78 //===----------------------------------------------------------------------===//
79 
80 OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
81  return constFoldUnaryOpConditional<FloatAttr>(
82  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
83  switch (a.getSizeInBits(a.getSemantics())) {
84  case 64:
85  return APFloat(acosh(a.convertToDouble()));
86  case 32:
87  return APFloat(acoshf(a.convertToFloat()));
88  default:
89  return {};
90  }
91  });
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // AsinOp folder
96 //===----------------------------------------------------------------------===//
97 
98 OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
99  return constFoldUnaryOpConditional<FloatAttr>(
100  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
101  switch (a.getSizeInBits(a.getSemantics())) {
102  case 64:
103  return APFloat(asin(a.convertToDouble()));
104  case 32:
105  return APFloat(asinf(a.convertToFloat()));
106  default:
107  return {};
108  }
109  });
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // AsinhOp folder
114 //===----------------------------------------------------------------------===//
115 
116 OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
117  return constFoldUnaryOpConditional<FloatAttr>(
118  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
119  switch (a.getSizeInBits(a.getSemantics())) {
120  case 64:
121  return APFloat(asinh(a.convertToDouble()));
122  case 32:
123  return APFloat(asinhf(a.convertToFloat()));
124  default:
125  return {};
126  }
127  });
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // AtanOp folder
132 //===----------------------------------------------------------------------===//
133 
134 OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
135  return constFoldUnaryOpConditional<FloatAttr>(
136  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
137  switch (a.getSizeInBits(a.getSemantics())) {
138  case 64:
139  return APFloat(atan(a.convertToDouble()));
140  case 32:
141  return APFloat(atanf(a.convertToFloat()));
142  default:
143  return {};
144  }
145  });
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // AtanhOp folder
150 //===----------------------------------------------------------------------===//
151 
152 OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
153  return constFoldUnaryOpConditional<FloatAttr>(
154  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
155  switch (a.getSizeInBits(a.getSemantics())) {
156  case 64:
157  return APFloat(atanh(a.convertToDouble()));
158  case 32:
159  return APFloat(atanhf(a.convertToFloat()));
160  default:
161  return {};
162  }
163  });
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // Atan2Op folder
168 //===----------------------------------------------------------------------===//
169 
170 OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
171  return constFoldBinaryOpConditional<FloatAttr>(
172  adaptor.getOperands(),
173  [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
174  if (a.isZero() && b.isZero())
175  return llvm::APFloat::getNaN(a.getSemantics());
176 
177  if (a.getSizeInBits(a.getSemantics()) == 64 &&
178  b.getSizeInBits(b.getSemantics()) == 64)
179  return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
180 
181  if (a.getSizeInBits(a.getSemantics()) == 32 &&
182  b.getSizeInBits(b.getSemantics()) == 32)
183  return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
184 
185  return {};
186  });
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // CeilOp folder
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
194  return constFoldUnaryOp<FloatAttr>(
195  adaptor.getOperands(), [](const APFloat &a) {
196  APFloat result(a);
197  result.roundToIntegral(llvm::RoundingMode::TowardPositive);
198  return result;
199  });
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // CopySignOp folder
204 //===----------------------------------------------------------------------===//
205 
206 OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208  [](const APFloat &a, const APFloat &b) {
209  APFloat result(a);
210  result.copySign(b);
211  return result;
212  });
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // CosOp folder
217 //===----------------------------------------------------------------------===//
218 
219 OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
220  return constFoldUnaryOpConditional<FloatAttr>(
221  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
222  switch (a.getSizeInBits(a.getSemantics())) {
223  case 64:
224  return APFloat(cos(a.convertToDouble()));
225  case 32:
226  return APFloat(cosf(a.convertToFloat()));
227  default:
228  return {};
229  }
230  });
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // CoshOp folder
235 //===----------------------------------------------------------------------===//
236 
237 OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
238  return constFoldUnaryOpConditional<FloatAttr>(
239  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
240  switch (a.getSizeInBits(a.getSemantics())) {
241  case 64:
242  return APFloat(cosh(a.convertToDouble()));
243  case 32:
244  return APFloat(coshf(a.convertToFloat()));
245  default:
246  return {};
247  }
248  });
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // SinOp folder
253 //===----------------------------------------------------------------------===//
254 
255 OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
256  return constFoldUnaryOpConditional<FloatAttr>(
257  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
258  switch (a.getSizeInBits(a.getSemantics())) {
259  case 64:
260  return APFloat(sin(a.convertToDouble()));
261  case 32:
262  return APFloat(sinf(a.convertToFloat()));
263  default:
264  return {};
265  }
266  });
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // SinhOp folder
271 //===----------------------------------------------------------------------===//
272 
273 OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
274  return constFoldUnaryOpConditional<FloatAttr>(
275  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
276  switch (a.getSizeInBits(a.getSemantics())) {
277  case 64:
278  return APFloat(sinh(a.convertToDouble()));
279  case 32:
280  return APFloat(sinhf(a.convertToFloat()));
281  default:
282  return {};
283  }
284  });
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // CountLeadingZerosOp folder
289 //===----------------------------------------------------------------------===//
290 
291 OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
292  return constFoldUnaryOp<IntegerAttr>(
293  adaptor.getOperands(),
294  [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // CountTrailingZerosOp folder
299 //===----------------------------------------------------------------------===//
300 
301 OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
302  return constFoldUnaryOp<IntegerAttr>(
303  adaptor.getOperands(),
304  [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // CtPopOp folder
309 //===----------------------------------------------------------------------===//
310 
311 OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
312  return constFoldUnaryOp<IntegerAttr>(
313  adaptor.getOperands(),
314  [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // ErfOp folder
319 //===----------------------------------------------------------------------===//
320 
321 OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
322  return constFoldUnaryOpConditional<FloatAttr>(
323  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
324  switch (a.getSizeInBits(a.getSemantics())) {
325  case 64:
326  return APFloat(erf(a.convertToDouble()));
327  case 32:
328  return APFloat(erff(a.convertToFloat()));
329  default:
330  return {};
331  }
332  });
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // ErfcOp folder
337 //===----------------------------------------------------------------------===//
338 
339 OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
340  return constFoldUnaryOpConditional<FloatAttr>(
341  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
342  switch (APFloat::SemanticsToEnum(a.getSemantics())) {
343  case APFloat::Semantics::S_IEEEdouble:
344  return APFloat(erfc(a.convertToDouble()));
345  case APFloat::Semantics::S_IEEEsingle:
346  return APFloat(erfcf(a.convertToFloat()));
347  default:
348  return {};
349  }
350  });
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // IPowIOp folder
355 //===----------------------------------------------------------------------===//
356 
357 OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
358  return constFoldBinaryOpConditional<IntegerAttr>(
359  adaptor.getOperands(),
360  [](const APInt &base, const APInt &power) -> std::optional<APInt> {
361  unsigned width = base.getBitWidth();
362  auto zeroValue = APInt::getZero(width);
363  APInt oneValue{width, 1ULL, /*isSigned=*/true};
364  APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
365 
366  if (power.isZero())
367  return oneValue;
368 
369  if (power.isNegative()) {
370  // Leave 0 raised to negative power not folded.
371  if (base.isZero())
372  return {};
373  if (base.eq(oneValue))
374  return oneValue;
375  // If abs(base) > 1, then the result is zero.
376  if (base.ne(minusOneValue))
377  return zeroValue;
378  // base == -1:
379  // -1: power is odd
380  // 1: power is even
381  if (power[0] == 1)
382  return minusOneValue;
383 
384  return oneValue;
385  }
386 
387  // power is positive.
388  APInt result = oneValue;
389  APInt curBase = base;
390  APInt curPower = power;
391  while (true) {
392  if (curPower[0] == 1)
393  result *= curBase;
394  curPower.lshrInPlace(1);
395  if (curPower.isZero())
396  return result;
397  curBase *= curBase;
398  }
399  });
400 
401  return Attribute();
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // LogOp folder
406 //===----------------------------------------------------------------------===//
407 
408 OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
409  return constFoldUnaryOpConditional<FloatAttr>(
410  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
411  if (a.isNegative())
412  return {};
413 
414  if (a.getSizeInBits(a.getSemantics()) == 64)
415  return APFloat(log(a.convertToDouble()));
416 
417  if (a.getSizeInBits(a.getSemantics()) == 32)
418  return APFloat(logf(a.convertToFloat()));
419 
420  return {};
421  });
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // Log2Op folder
426 //===----------------------------------------------------------------------===//
427 
428 OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
429  return constFoldUnaryOpConditional<FloatAttr>(
430  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
431  if (a.isNegative())
432  return {};
433 
434  if (a.getSizeInBits(a.getSemantics()) == 64)
435  return APFloat(log2(a.convertToDouble()));
436 
437  if (a.getSizeInBits(a.getSemantics()) == 32)
438  return APFloat(log2f(a.convertToFloat()));
439 
440  return {};
441  });
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // Log10Op folder
446 //===----------------------------------------------------------------------===//
447 
448 OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
449  return constFoldUnaryOpConditional<FloatAttr>(
450  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
451  if (a.isNegative())
452  return {};
453 
454  switch (a.getSizeInBits(a.getSemantics())) {
455  case 64:
456  return APFloat(log10(a.convertToDouble()));
457  case 32:
458  return APFloat(log10f(a.convertToFloat()));
459  default:
460  return {};
461  }
462  });
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // Log1pOp folder
467 //===----------------------------------------------------------------------===//
468 
469 OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
470  return constFoldUnaryOpConditional<FloatAttr>(
471  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
472  switch (a.getSizeInBits(a.getSemantics())) {
473  case 64:
474  if ((a + APFloat(1.0)).isNegative())
475  return {};
476  return APFloat(log1p(a.convertToDouble()));
477  case 32:
478  if ((a + APFloat(1.0f)).isNegative())
479  return {};
480  return APFloat(log1pf(a.convertToFloat()));
481  default:
482  return {};
483  }
484  });
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // PowFOp folder
489 //===----------------------------------------------------------------------===//
490 
491 OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
492  return constFoldBinaryOpConditional<FloatAttr>(
493  adaptor.getOperands(),
494  [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
495  if (a.getSizeInBits(a.getSemantics()) == 64 &&
496  b.getSizeInBits(b.getSemantics()) == 64)
497  return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
498 
499  if (a.getSizeInBits(a.getSemantics()) == 32 &&
500  b.getSizeInBits(b.getSemantics()) == 32)
501  return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
502 
503  return {};
504  });
505 }
506 
507 //===----------------------------------------------------------------------===//
508 // SqrtOp folder
509 //===----------------------------------------------------------------------===//
510 
511 OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
512  return constFoldUnaryOpConditional<FloatAttr>(
513  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
514  if (a.isNegative())
515  return {};
516 
517  switch (a.getSizeInBits(a.getSemantics())) {
518  case 64:
519  return APFloat(sqrt(a.convertToDouble()));
520  case 32:
521  return APFloat(sqrtf(a.convertToFloat()));
522  default:
523  return {};
524  }
525  });
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // ExpOp folder
530 //===----------------------------------------------------------------------===//
531 
532 OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
533  return constFoldUnaryOpConditional<FloatAttr>(
534  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
535  switch (a.getSizeInBits(a.getSemantics())) {
536  case 64:
537  return APFloat(exp(a.convertToDouble()));
538  case 32:
539  return APFloat(expf(a.convertToFloat()));
540  default:
541  return {};
542  }
543  });
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // Exp2Op folder
548 //===----------------------------------------------------------------------===//
549 
550 OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
551  return constFoldUnaryOpConditional<FloatAttr>(
552  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
553  switch (a.getSizeInBits(a.getSemantics())) {
554  case 64:
555  return APFloat(exp2(a.convertToDouble()));
556  case 32:
557  return APFloat(exp2f(a.convertToFloat()));
558  default:
559  return {};
560  }
561  });
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // ExpM1Op folder
566 //===----------------------------------------------------------------------===//
567 
568 OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
569  return constFoldUnaryOpConditional<FloatAttr>(
570  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
571  switch (a.getSizeInBits(a.getSemantics())) {
572  case 64:
573  return APFloat(expm1(a.convertToDouble()));
574  case 32:
575  return APFloat(expm1f(a.convertToFloat()));
576  default:
577  return {};
578  }
579  });
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // IsFiniteOp folder
584 //===----------------------------------------------------------------------===//
585 
586 OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
587  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
588  return BoolAttr::get(val.getContext(), val.getValue().isFinite());
589  }
590  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
591  return DenseElementsAttr::get(
592  cast<ShapedType>(getType()),
593  APInt(1, splat.getSplatValue<APFloat>().isFinite()));
594  }
595  return {};
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // IsInfOp folder
600 //===----------------------------------------------------------------------===//
601 
602 OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
603  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
604  return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
605  }
606  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
607  return DenseElementsAttr::get(
608  cast<ShapedType>(getType()),
609  APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
610  }
611  return {};
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // IsNaNOp folder
616 //===----------------------------------------------------------------------===//
617 
618 OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
619  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
620  return BoolAttr::get(val.getContext(), val.getValue().isNaN());
621  }
622  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
623  return DenseElementsAttr::get(
624  cast<ShapedType>(getType()),
625  APInt(1, splat.getSplatValue<APFloat>().isNaN()));
626  }
627  return {};
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // IsNormalOp folder
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
635  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
636  return BoolAttr::get(val.getContext(), val.getValue().isNormal());
637  }
638  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
639  return DenseElementsAttr::get(
640  cast<ShapedType>(getType()),
641  APInt(1, splat.getSplatValue<APFloat>().isNormal()));
642  }
643  return {};
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // TanOp folder
648 //===----------------------------------------------------------------------===//
649 
650 OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
651  return constFoldUnaryOpConditional<FloatAttr>(
652  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
653  switch (a.getSizeInBits(a.getSemantics())) {
654  case 64:
655  return APFloat(tan(a.convertToDouble()));
656  case 32:
657  return APFloat(tanf(a.convertToFloat()));
658  default:
659  return {};
660  }
661  });
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // TanhOp folder
666 //===----------------------------------------------------------------------===//
667 
668 OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
669  return constFoldUnaryOpConditional<FloatAttr>(
670  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
671  switch (a.getSizeInBits(a.getSemantics())) {
672  case 64:
673  return APFloat(tanh(a.convertToDouble()));
674  case 32:
675  return APFloat(tanhf(a.convertToFloat()));
676  default:
677  return {};
678  }
679  });
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // RoundEvenOp folder
684 //===----------------------------------------------------------------------===//
685 
686 OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
687  return constFoldUnaryOp<FloatAttr>(
688  adaptor.getOperands(), [](const APFloat &a) {
689  APFloat result(a);
690  result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
691  return result;
692  });
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // FloorOp folder
697 //===----------------------------------------------------------------------===//
698 
699 OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
700  return constFoldUnaryOp<FloatAttr>(
701  adaptor.getOperands(), [](const APFloat &a) {
702  APFloat result(a);
703  result.roundToIntegral(llvm::RoundingMode::TowardNegative);
704  return result;
705  });
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // RoundOp folder
710 //===----------------------------------------------------------------------===//
711 
712 OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
713  return constFoldUnaryOpConditional<FloatAttr>(
714  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
715  switch (a.getSizeInBits(a.getSemantics())) {
716  case 64:
717  return APFloat(round(a.convertToDouble()));
718  case 32:
719  return APFloat(roundf(a.convertToFloat()));
720  default:
721  return {};
722  }
723  });
724 }
725 
726 //===----------------------------------------------------------------------===//
727 // TruncOp folder
728 //===----------------------------------------------------------------------===//
729 
730 OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
731  return constFoldUnaryOpConditional<FloatAttr>(
732  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
733  switch (a.getSizeInBits(a.getSemantics())) {
734  case 64:
735  return APFloat(trunc(a.convertToDouble()));
736  case 32:
737  return APFloat(truncf(a.convertToFloat()));
738  default:
739  return {};
740  }
741  });
742 }
743 
744 /// Materialize an integer or floating point constant.
746  Attribute value, Type type,
747  Location loc) {
748  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
749  return builder.create<ub::PoisonOp>(loc, type, poison);
750 
751  return arith::ConstantOp::materialize(builder, value, type, loc);
752 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: MathOps.cpp:24
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.