MLIR  20.0.0git
MathOps.cpp
Go to the documentation of this file.
1 //===- MathOps.cpp - MLIR operations for math implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Builders.h"
14 #include <optional>
15 
16 using namespace mlir;
17 using namespace mlir::math;
18 
19 //===----------------------------------------------------------------------===//
20 // TableGen'd op method definitions
21 //===----------------------------------------------------------------------===//
22 
23 #define GET_OP_CLASSES
24 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
25 
26 //===----------------------------------------------------------------------===//
27 // AbsFOp folder
28 //===----------------------------------------------------------------------===//
29 
30 OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
31  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
32  [](const APFloat &a) { return abs(a); });
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // AbsIOp folder
37 //===----------------------------------------------------------------------===//
38 
39 OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
40  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
41  [](const APInt &a) { return a.abs(); });
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // AcosOp folder
46 //===----------------------------------------------------------------------===//
47 
48 OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
49  return constFoldUnaryOpConditional<FloatAttr>(
50  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
51  switch (a.getSizeInBits(a.getSemantics())) {
52  case 64:
53  return APFloat(acos(a.convertToDouble()));
54  case 32:
55  return APFloat(acosf(a.convertToFloat()));
56  default:
57  return {};
58  }
59  });
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // AcoshOp folder
64 //===----------------------------------------------------------------------===//
65 
66 OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
67  return constFoldUnaryOpConditional<FloatAttr>(
68  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
69  switch (a.getSizeInBits(a.getSemantics())) {
70  case 64:
71  return APFloat(acosh(a.convertToDouble()));
72  case 32:
73  return APFloat(acoshf(a.convertToFloat()));
74  default:
75  return {};
76  }
77  });
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // AsinOp folder
82 //===----------------------------------------------------------------------===//
83 
84 OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
85  return constFoldUnaryOpConditional<FloatAttr>(
86  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
87  switch (a.getSizeInBits(a.getSemantics())) {
88  case 64:
89  return APFloat(asin(a.convertToDouble()));
90  case 32:
91  return APFloat(asinf(a.convertToFloat()));
92  default:
93  return {};
94  }
95  });
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // AsinhOp folder
100 //===----------------------------------------------------------------------===//
101 
102 OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
103  return constFoldUnaryOpConditional<FloatAttr>(
104  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
105  switch (a.getSizeInBits(a.getSemantics())) {
106  case 64:
107  return APFloat(asinh(a.convertToDouble()));
108  case 32:
109  return APFloat(asinhf(a.convertToFloat()));
110  default:
111  return {};
112  }
113  });
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // AtanOp folder
118 //===----------------------------------------------------------------------===//
119 
120 OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
121  return constFoldUnaryOpConditional<FloatAttr>(
122  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
123  switch (a.getSizeInBits(a.getSemantics())) {
124  case 64:
125  return APFloat(atan(a.convertToDouble()));
126  case 32:
127  return APFloat(atanf(a.convertToFloat()));
128  default:
129  return {};
130  }
131  });
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // AtanhOp folder
136 //===----------------------------------------------------------------------===//
137 
138 OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
139  return constFoldUnaryOpConditional<FloatAttr>(
140  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
141  switch (a.getSizeInBits(a.getSemantics())) {
142  case 64:
143  return APFloat(atanh(a.convertToDouble()));
144  case 32:
145  return APFloat(atanhf(a.convertToFloat()));
146  default:
147  return {};
148  }
149  });
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // Atan2Op folder
154 //===----------------------------------------------------------------------===//
155 
156 OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
157  return constFoldBinaryOpConditional<FloatAttr>(
158  adaptor.getOperands(),
159  [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
160  if (a.isZero() && b.isZero())
161  return llvm::APFloat::getNaN(a.getSemantics());
162 
163  if (a.getSizeInBits(a.getSemantics()) == 64 &&
164  b.getSizeInBits(b.getSemantics()) == 64)
165  return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
166 
167  if (a.getSizeInBits(a.getSemantics()) == 32 &&
168  b.getSizeInBits(b.getSemantics()) == 32)
169  return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
170 
171  return {};
172  });
173 }
174 
175 //===----------------------------------------------------------------------===//
176 // CeilOp folder
177 //===----------------------------------------------------------------------===//
178 
179 OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
180  return constFoldUnaryOp<FloatAttr>(
181  adaptor.getOperands(), [](const APFloat &a) {
182  APFloat result(a);
183  result.roundToIntegral(llvm::RoundingMode::TowardPositive);
184  return result;
185  });
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // CopySignOp folder
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
193  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
194  [](const APFloat &a, const APFloat &b) {
195  APFloat result(a);
196  result.copySign(b);
197  return result;
198  });
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // CosOp folder
203 //===----------------------------------------------------------------------===//
204 
205 OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
206  return constFoldUnaryOpConditional<FloatAttr>(
207  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
208  switch (a.getSizeInBits(a.getSemantics())) {
209  case 64:
210  return APFloat(cos(a.convertToDouble()));
211  case 32:
212  return APFloat(cosf(a.convertToFloat()));
213  default:
214  return {};
215  }
216  });
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // CoshOp folder
221 //===----------------------------------------------------------------------===//
222 
223 OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
224  return constFoldUnaryOpConditional<FloatAttr>(
225  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
226  switch (a.getSizeInBits(a.getSemantics())) {
227  case 64:
228  return APFloat(cosh(a.convertToDouble()));
229  case 32:
230  return APFloat(coshf(a.convertToFloat()));
231  default:
232  return {};
233  }
234  });
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // SinOp folder
239 //===----------------------------------------------------------------------===//
240 
241 OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
242  return constFoldUnaryOpConditional<FloatAttr>(
243  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
244  switch (a.getSizeInBits(a.getSemantics())) {
245  case 64:
246  return APFloat(sin(a.convertToDouble()));
247  case 32:
248  return APFloat(sinf(a.convertToFloat()));
249  default:
250  return {};
251  }
252  });
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // SinhOp folder
257 //===----------------------------------------------------------------------===//
258 
259 OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
260  return constFoldUnaryOpConditional<FloatAttr>(
261  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
262  switch (a.getSizeInBits(a.getSemantics())) {
263  case 64:
264  return APFloat(sinh(a.convertToDouble()));
265  case 32:
266  return APFloat(sinhf(a.convertToFloat()));
267  default:
268  return {};
269  }
270  });
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // CountLeadingZerosOp folder
275 //===----------------------------------------------------------------------===//
276 
277 OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
278  return constFoldUnaryOp<IntegerAttr>(
279  adaptor.getOperands(),
280  [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // CountTrailingZerosOp folder
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
288  return constFoldUnaryOp<IntegerAttr>(
289  adaptor.getOperands(),
290  [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // CtPopOp folder
295 //===----------------------------------------------------------------------===//
296 
297 OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
298  return constFoldUnaryOp<IntegerAttr>(
299  adaptor.getOperands(),
300  [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // ErfOp folder
305 //===----------------------------------------------------------------------===//
306 
307 OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
308  return constFoldUnaryOpConditional<FloatAttr>(
309  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
310  switch (a.getSizeInBits(a.getSemantics())) {
311  case 64:
312  return APFloat(erf(a.convertToDouble()));
313  case 32:
314  return APFloat(erff(a.convertToFloat()));
315  default:
316  return {};
317  }
318  });
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // IPowIOp folder
323 //===----------------------------------------------------------------------===//
324 
325 OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
326  return constFoldBinaryOpConditional<IntegerAttr>(
327  adaptor.getOperands(),
328  [](const APInt &base, const APInt &power) -> std::optional<APInt> {
329  unsigned width = base.getBitWidth();
330  auto zeroValue = APInt::getZero(width);
331  APInt oneValue{width, 1ULL, /*isSigned=*/true};
332  APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
333 
334  if (power.isZero())
335  return oneValue;
336 
337  if (power.isNegative()) {
338  // Leave 0 raised to negative power not folded.
339  if (base.isZero())
340  return {};
341  if (base.eq(oneValue))
342  return oneValue;
343  // If abs(base) > 1, then the result is zero.
344  if (base.ne(minusOneValue))
345  return zeroValue;
346  // base == -1:
347  // -1: power is odd
348  // 1: power is even
349  if (power[0] == 1)
350  return minusOneValue;
351 
352  return oneValue;
353  }
354 
355  // power is positive.
356  APInt result = oneValue;
357  APInt curBase = base;
358  APInt curPower = power;
359  while (true) {
360  if (curPower[0] == 1)
361  result *= curBase;
362  curPower.lshrInPlace(1);
363  if (curPower.isZero())
364  return result;
365  curBase *= curBase;
366  }
367  });
368 
369  return Attribute();
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // LogOp folder
374 //===----------------------------------------------------------------------===//
375 
376 OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
377  return constFoldUnaryOpConditional<FloatAttr>(
378  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
379  if (a.isNegative())
380  return {};
381 
382  if (a.getSizeInBits(a.getSemantics()) == 64)
383  return APFloat(log(a.convertToDouble()));
384 
385  if (a.getSizeInBits(a.getSemantics()) == 32)
386  return APFloat(logf(a.convertToFloat()));
387 
388  return {};
389  });
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // Log2Op folder
394 //===----------------------------------------------------------------------===//
395 
396 OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
397  return constFoldUnaryOpConditional<FloatAttr>(
398  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
399  if (a.isNegative())
400  return {};
401 
402  if (a.getSizeInBits(a.getSemantics()) == 64)
403  return APFloat(log2(a.convertToDouble()));
404 
405  if (a.getSizeInBits(a.getSemantics()) == 32)
406  return APFloat(log2f(a.convertToFloat()));
407 
408  return {};
409  });
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // Log10Op folder
414 //===----------------------------------------------------------------------===//
415 
416 OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
417  return constFoldUnaryOpConditional<FloatAttr>(
418  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
419  if (a.isNegative())
420  return {};
421 
422  switch (a.getSizeInBits(a.getSemantics())) {
423  case 64:
424  return APFloat(log10(a.convertToDouble()));
425  case 32:
426  return APFloat(log10f(a.convertToFloat()));
427  default:
428  return {};
429  }
430  });
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // Log1pOp folder
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
438  return constFoldUnaryOpConditional<FloatAttr>(
439  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
440  switch (a.getSizeInBits(a.getSemantics())) {
441  case 64:
442  if ((a + APFloat(1.0)).isNegative())
443  return {};
444  return APFloat(log1p(a.convertToDouble()));
445  case 32:
446  if ((a + APFloat(1.0f)).isNegative())
447  return {};
448  return APFloat(log1pf(a.convertToFloat()));
449  default:
450  return {};
451  }
452  });
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // PowFOp folder
457 //===----------------------------------------------------------------------===//
458 
459 OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
460  return constFoldBinaryOpConditional<FloatAttr>(
461  adaptor.getOperands(),
462  [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
463  if (a.getSizeInBits(a.getSemantics()) == 64 &&
464  b.getSizeInBits(b.getSemantics()) == 64)
465  return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
466 
467  if (a.getSizeInBits(a.getSemantics()) == 32 &&
468  b.getSizeInBits(b.getSemantics()) == 32)
469  return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
470 
471  return {};
472  });
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // SqrtOp folder
477 //===----------------------------------------------------------------------===//
478 
479 OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
480  return constFoldUnaryOpConditional<FloatAttr>(
481  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
482  if (a.isNegative())
483  return {};
484 
485  switch (a.getSizeInBits(a.getSemantics())) {
486  case 64:
487  return APFloat(sqrt(a.convertToDouble()));
488  case 32:
489  return APFloat(sqrtf(a.convertToFloat()));
490  default:
491  return {};
492  }
493  });
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // ExpOp folder
498 //===----------------------------------------------------------------------===//
499 
500 OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
501  return constFoldUnaryOpConditional<FloatAttr>(
502  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
503  switch (a.getSizeInBits(a.getSemantics())) {
504  case 64:
505  return APFloat(exp(a.convertToDouble()));
506  case 32:
507  return APFloat(expf(a.convertToFloat()));
508  default:
509  return {};
510  }
511  });
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // Exp2Op folder
516 //===----------------------------------------------------------------------===//
517 
518 OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
519  return constFoldUnaryOpConditional<FloatAttr>(
520  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
521  switch (a.getSizeInBits(a.getSemantics())) {
522  case 64:
523  return APFloat(exp2(a.convertToDouble()));
524  case 32:
525  return APFloat(exp2f(a.convertToFloat()));
526  default:
527  return {};
528  }
529  });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // ExpM1Op folder
534 //===----------------------------------------------------------------------===//
535 
536 OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
537  return constFoldUnaryOpConditional<FloatAttr>(
538  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
539  switch (a.getSizeInBits(a.getSemantics())) {
540  case 64:
541  return APFloat(expm1(a.convertToDouble()));
542  case 32:
543  return APFloat(expm1f(a.convertToFloat()));
544  default:
545  return {};
546  }
547  });
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // TanOp folder
552 //===----------------------------------------------------------------------===//
553 
554 OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
555  return constFoldUnaryOpConditional<FloatAttr>(
556  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
557  switch (a.getSizeInBits(a.getSemantics())) {
558  case 64:
559  return APFloat(tan(a.convertToDouble()));
560  case 32:
561  return APFloat(tanf(a.convertToFloat()));
562  default:
563  return {};
564  }
565  });
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // TanhOp folder
570 //===----------------------------------------------------------------------===//
571 
572 OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
573  return constFoldUnaryOpConditional<FloatAttr>(
574  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
575  switch (a.getSizeInBits(a.getSemantics())) {
576  case 64:
577  return APFloat(tanh(a.convertToDouble()));
578  case 32:
579  return APFloat(tanhf(a.convertToFloat()));
580  default:
581  return {};
582  }
583  });
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // RoundEvenOp folder
588 //===----------------------------------------------------------------------===//
589 
590 OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
591  return constFoldUnaryOp<FloatAttr>(
592  adaptor.getOperands(), [](const APFloat &a) {
593  APFloat result(a);
594  result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
595  return result;
596  });
597 }
598 
599 //===----------------------------------------------------------------------===//
600 // FloorOp folder
601 //===----------------------------------------------------------------------===//
602 
603 OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
604  return constFoldUnaryOp<FloatAttr>(
605  adaptor.getOperands(), [](const APFloat &a) {
606  APFloat result(a);
607  result.roundToIntegral(llvm::RoundingMode::TowardNegative);
608  return result;
609  });
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // RoundOp folder
614 //===----------------------------------------------------------------------===//
615 
616 OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
617  return constFoldUnaryOpConditional<FloatAttr>(
618  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
619  switch (a.getSizeInBits(a.getSemantics())) {
620  case 64:
621  return APFloat(round(a.convertToDouble()));
622  case 32:
623  return APFloat(roundf(a.convertToFloat()));
624  default:
625  return {};
626  }
627  });
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // TruncOp folder
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
635  return constFoldUnaryOpConditional<FloatAttr>(
636  adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
637  switch (a.getSizeInBits(a.getSemantics())) {
638  case 64:
639  return APFloat(trunc(a.convertToDouble()));
640  case 32:
641  return APFloat(truncf(a.convertToFloat()));
642  default:
643  return {};
644  }
645  });
646 }
647 
648 /// Materialize an integer or floating point constant.
650  Attribute value, Type type,
651  Location loc) {
652  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
653  return builder.create<ub::PoisonOp>(loc, type, poison);
654 
655  return arith::ConstantOp::materialize(builder, value, type, loc);
656 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Include the generated interface declarations.