MLIR  22.0.0git
MathOps.cpp
Go to the documentation of this file.
1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Builders.h"
14 #include <optional>
15 
16 using namespace mlir;
17 using namespace mlir::math;
18 
19 //===----------------------------------------------------------------------===//
20 // Common helpers
21 //===----------------------------------------------------------------------===//
22 
23 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
24 static Type getI1SameShape(Type type) {
25  auto i1Type = IntegerType::get(type.getContext(), 1);
26  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27  return shapedType.cloneWith(std::nullopt, i1Type);
28  if (llvm::isa<UnrankedTensorType>(type))
29  return UnrankedTensorType::get(i1Type);
30  return i1Type;
31 }
32 
33 //===----------------------------------------------------------------------===//
34 // TableGen'd op method definitions
35 //===----------------------------------------------------------------------===//
36 
37 #define GET_OP_CLASSES
38 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
39 
40 //===----------------------------------------------------------------------===//
41 // AbsFOp folder
42 //===----------------------------------------------------------------------===//
43 
44 OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
45  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46  [](const APFloat &a) { return abs(a); });
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // AbsIOp folder
51 //===----------------------------------------------------------------------===//
52 
53 OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
54  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55  [](const APInt &a) { return a.abs(); });
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // AcosOp folder
60 //===----------------------------------------------------------------------===//
61 
62 OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
63  return constFoldUnaryOpConditional<FloatAttr>(
64  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
65  switch (a.getSizeInBits(a.getSemantics())) {
66  case 64:
67  return APFloat(acos(a.convertToDouble()));
68  case 32:
69  return APFloat(acosf(a.convertToFloat()));
70  default:
71  return {};
72  }
73  });
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // AcoshOp folder
78 //===----------------------------------------------------------------------===//
79 
80 OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
81  return constFoldUnaryOpConditional<FloatAttr>(
82  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
83  switch (a.getSizeInBits(a.getSemantics())) {
84  case 64:
85  return APFloat(acosh(a.convertToDouble()));
86  case 32:
87  return APFloat(acoshf(a.convertToFloat()));
88  default:
89  return {};
90  }
91  });
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // AsinOp folder
96 //===----------------------------------------------------------------------===//
97 
98 OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
99  return constFoldUnaryOpConditional<FloatAttr>(
100  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
101  switch (a.getSizeInBits(a.getSemantics())) {
102  case 64:
103  return APFloat(asin(a.convertToDouble()));
104  case 32:
105  return APFloat(asinf(a.convertToFloat()));
106  default:
107  return {};
108  }
109  });
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // AsinhOp folder
114 //===----------------------------------------------------------------------===//
115 
116 OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
117  return constFoldUnaryOpConditional<FloatAttr>(
118  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
119  switch (a.getSizeInBits(a.getSemantics())) {
120  case 64:
121  return APFloat(asinh(a.convertToDouble()));
122  case 32:
123  return APFloat(asinhf(a.convertToFloat()));
124  default:
125  return {};
126  }
127  });
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // AtanOp folder
132 //===----------------------------------------------------------------------===//
133 
134 OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
135  return constFoldUnaryOpConditional<FloatAttr>(
136  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
137  switch (a.getSizeInBits(a.getSemantics())) {
138  case 64:
139  return APFloat(atan(a.convertToDouble()));
140  case 32:
141  return APFloat(atanf(a.convertToFloat()));
142  default:
143  return {};
144  }
145  });
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // AtanhOp folder
150 //===----------------------------------------------------------------------===//
151 
152 OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
153  return constFoldUnaryOpConditional<FloatAttr>(
154  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
155  switch (a.getSizeInBits(a.getSemantics())) {
156  case 64:
157  return APFloat(atanh(a.convertToDouble()));
158  case 32:
159  return APFloat(atanhf(a.convertToFloat()));
160  default:
161  return {};
162  }
163  });
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // Atan2Op folder
168 //===----------------------------------------------------------------------===//
169 
170 OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
171  return constFoldBinaryOpConditional<FloatAttr>(
172  adaptor.getOperands(),
173  [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
174  if (a.isZero() && b.isZero())
175  return llvm::APFloat::getNaN(a.getSemantics());
176 
177  if (a.getSizeInBits(a.getSemantics()) == 64 &&
178  b.getSizeInBits(b.getSemantics()) == 64)
179  return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
180 
181  if (a.getSizeInBits(a.getSemantics()) == 32 &&
182  b.getSizeInBits(b.getSemantics()) == 32)
183  return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
184 
185  return {};
186  });
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // CeilOp folder
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
194  return constFoldUnaryOp<FloatAttr>(
195  adaptor.getOperands(), [](const APFloat &a) {
196  APFloat result(a);
197  result.roundToIntegral(llvm::RoundingMode::TowardPositive);
198  return result;
199  });
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // CopySignOp folder
204 //===----------------------------------------------------------------------===//
205 
206 OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208  [](const APFloat &a, const APFloat &b) {
209  APFloat result(a);
210  result.copySign(b);
211  return result;
212  });
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // CosOp folder
217 //===----------------------------------------------------------------------===//
218 
219 OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
220  return constFoldUnaryOpConditional<FloatAttr>(
221  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
222  switch (a.getSizeInBits(a.getSemantics())) {
223  case 64:
224  return APFloat(cos(a.convertToDouble()));
225  case 32:
226  return APFloat(cosf(a.convertToFloat()));
227  default:
228  return {};
229  }
230  });
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // CoshOp folder
235 //===----------------------------------------------------------------------===//
236 
237 OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
238  return constFoldUnaryOpConditional<FloatAttr>(
239  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
240  switch (a.getSizeInBits(a.getSemantics())) {
241  case 64:
242  return APFloat(cosh(a.convertToDouble()));
243  case 32:
244  return APFloat(coshf(a.convertToFloat()));
245  default:
246  return {};
247  }
248  });
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // SinOp folder
253 //===----------------------------------------------------------------------===//
254 
255 OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
256  return constFoldUnaryOpConditional<FloatAttr>(
257  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
258  switch (a.getSizeInBits(a.getSemantics())) {
259  case 64:
260  return APFloat(sin(a.convertToDouble()));
261  case 32:
262  return APFloat(sinf(a.convertToFloat()));
263  default:
264  return {};
265  }
266  });
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // SinhOp folder
271 //===----------------------------------------------------------------------===//
272 
273 OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
274  return constFoldUnaryOpConditional<FloatAttr>(
275  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
276  switch (a.getSizeInBits(a.getSemantics())) {
277  case 64:
278  return APFloat(sinh(a.convertToDouble()));
279  case 32:
280  return APFloat(sinhf(a.convertToFloat()));
281  default:
282  return {};
283  }
284  });
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // SinCosOp getShapeForUnroll
289 //===----------------------------------------------------------------------===//
290 
291 std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292  if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
293  return llvm::to_vector<4>(vt.getShape());
294  return std::nullopt;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // CountLeadingZerosOp folder
299 //===----------------------------------------------------------------------===//
300 
301 OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
302  return constFoldUnaryOp<IntegerAttr>(
303  adaptor.getOperands(),
304  [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // CountTrailingZerosOp folder
309 //===----------------------------------------------------------------------===//
310 
311 OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
312  return constFoldUnaryOp<IntegerAttr>(
313  adaptor.getOperands(),
314  [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // CtPopOp folder
319 //===----------------------------------------------------------------------===//
320 
321 OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
322  return constFoldUnaryOp<IntegerAttr>(
323  adaptor.getOperands(),
324  [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // ErfOp folder
329 //===----------------------------------------------------------------------===//
330 
331 OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
332  return constFoldUnaryOpConditional<FloatAttr>(
333  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
334  switch (a.getSizeInBits(a.getSemantics())) {
335  case 64:
336  return APFloat(erf(a.convertToDouble()));
337  case 32:
338  return APFloat(erff(a.convertToFloat()));
339  default:
340  return {};
341  }
342  });
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // ErfcOp folder
347 //===----------------------------------------------------------------------===//
348 
349 OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
350  return constFoldUnaryOpConditional<FloatAttr>(
351  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
352  switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353  case APFloat::Semantics::S_IEEEdouble:
354  return APFloat(erfc(a.convertToDouble()));
355  case APFloat::Semantics::S_IEEEsingle:
356  return APFloat(erfcf(a.convertToFloat()));
357  default:
358  return {};
359  }
360  });
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // IPowIOp folder
365 //===----------------------------------------------------------------------===//
366 
367 OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
368  return constFoldBinaryOpConditional<IntegerAttr>(
369  adaptor.getOperands(),
370  [](const APInt &base, const APInt &power) -> std::optional<APInt> {
371  unsigned width = base.getBitWidth();
372  auto zeroValue = APInt::getZero(width);
373  APInt oneValue{width, 1ULL, /*isSigned=*/true};
374  APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
375 
376  if (power.isZero())
377  return oneValue;
378 
379  if (power.isNegative()) {
380  // Leave 0 raised to negative power not folded.
381  if (base.isZero())
382  return {};
383  if (base.eq(oneValue))
384  return oneValue;
385  // If abs(base) > 1, then the result is zero.
386  if (base.ne(minusOneValue))
387  return zeroValue;
388  // base == -1:
389  // -1: power is odd
390  // 1: power is even
391  if (power[0] == 1)
392  return minusOneValue;
393 
394  return oneValue;
395  }
396 
397  // power is positive.
398  APInt result = oneValue;
399  APInt curBase = base;
400  APInt curPower = power;
401  while (true) {
402  if (curPower[0] == 1)
403  result *= curBase;
404  curPower.lshrInPlace(1);
405  if (curPower.isZero())
406  return result;
407  curBase *= curBase;
408  }
409  });
410 
411  return Attribute();
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // LogOp folder
416 //===----------------------------------------------------------------------===//
417 
418 OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
419  return constFoldUnaryOpConditional<FloatAttr>(
420  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
421  if (a.isNegative())
422  return {};
423 
424  if (a.getSizeInBits(a.getSemantics()) == 64)
425  return APFloat(log(a.convertToDouble()));
426 
427  if (a.getSizeInBits(a.getSemantics()) == 32)
428  return APFloat(logf(a.convertToFloat()));
429 
430  return {};
431  });
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // Log2Op folder
436 //===----------------------------------------------------------------------===//
437 
438 OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
439  return constFoldUnaryOpConditional<FloatAttr>(
440  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
441  if (a.isNegative())
442  return {};
443 
444  if (a.getSizeInBits(a.getSemantics()) == 64)
445  return APFloat(log2(a.convertToDouble()));
446 
447  if (a.getSizeInBits(a.getSemantics()) == 32)
448  return APFloat(log2f(a.convertToFloat()));
449 
450  return {};
451  });
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // Log10Op folder
456 //===----------------------------------------------------------------------===//
457 
458 OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
459  return constFoldUnaryOpConditional<FloatAttr>(
460  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
461  if (a.isNegative())
462  return {};
463 
464  switch (a.getSizeInBits(a.getSemantics())) {
465  case 64:
466  return APFloat(log10(a.convertToDouble()));
467  case 32:
468  return APFloat(log10f(a.convertToFloat()));
469  default:
470  return {};
471  }
472  });
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Log1pOp folder
477 //===----------------------------------------------------------------------===//
478 
479 OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
480  return constFoldUnaryOpConditional<FloatAttr>(
481  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
482  switch (a.getSizeInBits(a.getSemantics())) {
483  case 64:
484  if ((a + APFloat(1.0)).isNegative())
485  return {};
486  return APFloat(log1p(a.convertToDouble()));
487  case 32:
488  if ((a + APFloat(1.0f)).isNegative())
489  return {};
490  return APFloat(log1pf(a.convertToFloat()));
491  default:
492  return {};
493  }
494  });
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // PowFOp folder
499 //===----------------------------------------------------------------------===//
500 
501 OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
502  return constFoldBinaryOpConditional<FloatAttr>(
503  adaptor.getOperands(),
504  [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
505  if (a.getSizeInBits(a.getSemantics()) == 64 &&
506  b.getSizeInBits(b.getSemantics()) == 64)
507  return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
508 
509  if (a.getSizeInBits(a.getSemantics()) == 32 &&
510  b.getSizeInBits(b.getSemantics()) == 32)
511  return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
512 
513  return {};
514  });
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // SqrtOp folder
519 //===----------------------------------------------------------------------===//
520 
521 OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
522  return constFoldUnaryOpConditional<FloatAttr>(
523  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
524  if (a.isNegative())
525  return {};
526 
527  switch (a.getSizeInBits(a.getSemantics())) {
528  case 64:
529  return APFloat(sqrt(a.convertToDouble()));
530  case 32:
531  return APFloat(sqrtf(a.convertToFloat()));
532  default:
533  return {};
534  }
535  });
536 }
537 
538 //===----------------------------------------------------------------------===//
539 // ExpOp folder
540 //===----------------------------------------------------------------------===//
541 
542 OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
543  return constFoldUnaryOpConditional<FloatAttr>(
544  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
545  switch (a.getSizeInBits(a.getSemantics())) {
546  case 64:
547  return APFloat(exp(a.convertToDouble()));
548  case 32:
549  return APFloat(expf(a.convertToFloat()));
550  default:
551  return {};
552  }
553  });
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // Exp2Op folder
558 //===----------------------------------------------------------------------===//
559 
560 OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
561  return constFoldUnaryOpConditional<FloatAttr>(
562  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
563  switch (a.getSizeInBits(a.getSemantics())) {
564  case 64:
565  return APFloat(exp2(a.convertToDouble()));
566  case 32:
567  return APFloat(exp2f(a.convertToFloat()));
568  default:
569  return {};
570  }
571  });
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // ExpM1Op folder
576 //===----------------------------------------------------------------------===//
577 
578 OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
579  return constFoldUnaryOpConditional<FloatAttr>(
580  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
581  switch (a.getSizeInBits(a.getSemantics())) {
582  case 64:
583  return APFloat(expm1(a.convertToDouble()));
584  case 32:
585  return APFloat(expm1f(a.convertToFloat()));
586  default:
587  return {};
588  }
589  });
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // IsFiniteOp folder
594 //===----------------------------------------------------------------------===//
595 
596 OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
597  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
598  return BoolAttr::get(val.getContext(), val.getValue().isFinite());
599  }
600  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
601  return DenseElementsAttr::get(
602  cast<ShapedType>(getType()),
603  APInt(1, splat.getSplatValue<APFloat>().isFinite()));
604  }
605  return {};
606 }
607 
608 //===----------------------------------------------------------------------===//
609 // IsInfOp folder
610 //===----------------------------------------------------------------------===//
611 
612 OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
613  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
614  return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
615  }
616  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
617  return DenseElementsAttr::get(
618  cast<ShapedType>(getType()),
619  APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
620  }
621  return {};
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // IsNaNOp folder
626 //===----------------------------------------------------------------------===//
627 
628 OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
629  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
630  return BoolAttr::get(val.getContext(), val.getValue().isNaN());
631  }
632  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
633  return DenseElementsAttr::get(
634  cast<ShapedType>(getType()),
635  APInt(1, splat.getSplatValue<APFloat>().isNaN()));
636  }
637  return {};
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // IsNormalOp folder
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
645  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
646  return BoolAttr::get(val.getContext(), val.getValue().isNormal());
647  }
648  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
649  return DenseElementsAttr::get(
650  cast<ShapedType>(getType()),
651  APInt(1, splat.getSplatValue<APFloat>().isNormal()));
652  }
653  return {};
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // TanOp folder
658 //===----------------------------------------------------------------------===//
659 
660 OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
661  return constFoldUnaryOpConditional<FloatAttr>(
662  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
663  switch (a.getSizeInBits(a.getSemantics())) {
664  case 64:
665  return APFloat(tan(a.convertToDouble()));
666  case 32:
667  return APFloat(tanf(a.convertToFloat()));
668  default:
669  return {};
670  }
671  });
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // TanhOp folder
676 //===----------------------------------------------------------------------===//
677 
678 OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
679  return constFoldUnaryOpConditional<FloatAttr>(
680  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
681  switch (a.getSizeInBits(a.getSemantics())) {
682  case 64:
683  return APFloat(tanh(a.convertToDouble()));
684  case 32:
685  return APFloat(tanhf(a.convertToFloat()));
686  default:
687  return {};
688  }
689  });
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // RoundEvenOp folder
694 //===----------------------------------------------------------------------===//
695 
696 OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
697  return constFoldUnaryOp<FloatAttr>(
698  adaptor.getOperands(), [](const APFloat &a) {
699  APFloat result(a);
700  result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
701  return result;
702  });
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // FloorOp folder
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
710  return constFoldUnaryOp<FloatAttr>(
711  adaptor.getOperands(), [](const APFloat &a) {
712  APFloat result(a);
713  result.roundToIntegral(llvm::RoundingMode::TowardNegative);
714  return result;
715  });
716 }
717 
718 //===----------------------------------------------------------------------===//
719 // RoundOp folder
720 //===----------------------------------------------------------------------===//
721 
722 OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
723  return constFoldUnaryOpConditional<FloatAttr>(
724  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
725  switch (a.getSizeInBits(a.getSemantics())) {
726  case 64:
727  return APFloat(round(a.convertToDouble()));
728  case 32:
729  return APFloat(roundf(a.convertToFloat()));
730  default:
731  return {};
732  }
733  });
734 }
735 
736 //===----------------------------------------------------------------------===//
737 // TruncOp folder
738 //===----------------------------------------------------------------------===//
739 
740 OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
741  return constFoldUnaryOpConditional<FloatAttr>(
742  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
743  switch (a.getSizeInBits(a.getSemantics())) {
744  case 64:
745  return APFloat(trunc(a.convertToDouble()));
746  case 32:
747  return APFloat(truncf(a.convertToFloat()));
748  default:
749  return {};
750  }
751  });
752 }
753 
754 /// Materialize an integer or floating point constant.
756  Attribute value, Type type,
757  Location loc) {
758  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
759  return ub::PoisonOp::create(builder, loc, type, poison);
760 
761  return arith::ConstantOp::materialize(builder, value, type, loc);
762 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:51
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: MathOps.cpp:24
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...