25 auto i1Type = IntegerType::get(type.
getContext(), 1);
26 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
46 [](
const APFloat &a) { return abs(a); });
55 [](
const APInt &a) { return a.abs(); });
64 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
67 return APFloat(acos(a.convertToDouble()));
69 return APFloat(acosf(a.convertToFloat()));
82 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
85 return APFloat(acosh(a.convertToDouble()));
87 return APFloat(acoshf(a.convertToFloat()));
100 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
103 return APFloat(asin(a.convertToDouble()));
105 return APFloat(asinf(a.convertToFloat()));
118 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
121 return APFloat(asinh(a.convertToDouble()));
123 return APFloat(asinhf(a.convertToFloat()));
136 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
139 return APFloat(atan(a.convertToDouble()));
141 return APFloat(atanf(a.convertToFloat()));
154 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
157 return APFloat(atanh(a.convertToDouble()));
159 return APFloat(atanhf(a.convertToFloat()));
172 adaptor.getOperands(),
173 [](
const APFloat &a,
const APFloat &
b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
195 adaptor.getOperands(), [](
const APFloat &a) {
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
206OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
208 [](
const APFloat &a,
const APFloat &
b) {
221 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
224 return APFloat(cos(a.convertToDouble()));
226 return APFloat(cosf(a.convertToFloat()));
239 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
242 return APFloat(cosh(a.convertToDouble()));
244 return APFloat(coshf(a.convertToFloat()));
257 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
260 return APFloat(sin(a.convertToDouble()));
262 return APFloat(sinf(a.convertToFloat()));
275 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
278 return APFloat(sinh(a.convertToDouble()));
280 return APFloat(sinhf(a.convertToFloat()));
291std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292 if (
auto vt = mlir::dyn_cast<VectorType>(getOperand().
getType()))
293 return llvm::to_vector<4>(vt.getShape());
301OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
303 adaptor.getOperands(),
304 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
311OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
313 adaptor.getOperands(),
314 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
323 adaptor.getOperands(),
324 [](
const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
333 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
334 switch (a.getSizeInBits(a.getSemantics())) {
336 return APFloat(erf(a.convertToDouble()));
338 return APFloat(erff(a.convertToFloat()));
351 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
352 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353 case APFloat::Semantics::S_IEEEdouble:
354 return APFloat(erfc(a.convertToDouble()));
355 case APFloat::Semantics::S_IEEEsingle:
356 return APFloat(erfcf(a.convertToFloat()));
369 adaptor.getOperands(),
370 [](
const APInt &base,
const APInt &power) -> std::optional<APInt> {
371 unsigned width = base.getBitWidth();
372 auto zeroValue = APInt::getZero(width);
376 APInt oneValue{width, 1ULL, true};
377 APInt minusOneValue{width, -1ULL, true};
382 if (power.isNegative()) {
389 if (base.ne(minusOneValue))
395 return minusOneValue;
402 APInt curBase = base;
403 APInt curPower = power;
405 if (curPower[0] == 1)
407 curPower.lshrInPlace(1);
408 if (curPower.isZero())
423 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
427 if (a.getSizeInBits(a.getSemantics()) == 64)
428 return APFloat(log(a.convertToDouble()));
430 if (a.getSizeInBits(a.getSemantics()) == 32)
431 return APFloat(logf(a.convertToFloat()));
443 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
447 if (a.getSizeInBits(a.getSemantics()) == 64)
448 return APFloat(log2(a.convertToDouble()));
450 if (a.getSizeInBits(a.getSemantics()) == 32)
451 return APFloat(log2f(a.convertToFloat()));
463 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
467 switch (a.getSizeInBits(a.getSemantics())) {
469 return APFloat(log10(a.convertToDouble()));
471 return APFloat(log10f(a.convertToFloat()));
484 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
485 switch (a.getSizeInBits(a.getSemantics())) {
487 if ((a + APFloat(1.0)).isNegative())
489 return APFloat(log1p(a.convertToDouble()));
491 if ((a + APFloat(1.0f)).isNegative())
493 return APFloat(log1pf(a.convertToFloat()));
506 adaptor.getOperands(),
507 [](
const APFloat &a,
const APFloat &
b) -> std::optional<APFloat> {
508 if (a.getSizeInBits(a.getSemantics()) == 64 &&
509 b.getSizeInBits(b.getSemantics()) == 64)
510 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
512 if (a.getSizeInBits(a.getSemantics()) == 32 &&
513 b.getSizeInBits(b.getSemantics()) == 32)
514 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
526 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
530 APFloat one(a.getSemantics(), 1);
531 switch (a.getSizeInBits(a.getSemantics())) {
533 return one / APFloat(sqrt(a.convertToDouble()));
535 return one / APFloat(sqrtf(a.convertToFloat()));
548 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
552 switch (a.getSizeInBits(a.getSemantics())) {
554 return APFloat(sqrt(a.convertToDouble()));
556 return APFloat(sqrtf(a.convertToFloat()));
569 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
570 switch (a.getSizeInBits(a.getSemantics())) {
572 return APFloat(exp(a.convertToDouble()));
574 return APFloat(expf(a.convertToFloat()));
587 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
588 switch (a.getSizeInBits(a.getSemantics())) {
590 return APFloat(exp2(a.convertToDouble()));
592 return APFloat(exp2f(a.convertToFloat()));
605 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
606 switch (a.getSizeInBits(a.getSemantics())) {
608 return APFloat(expm1(a.convertToDouble()));
610 return APFloat(expm1f(a.convertToFloat()));
621OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
622 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
623 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
625 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
628 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
638 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
639 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
641 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
644 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
654 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
655 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
657 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
660 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
669OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
670 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
671 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
673 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
676 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
687 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
688 switch (a.getSizeInBits(a.getSemantics())) {
690 return APFloat(tan(a.convertToDouble()));
692 return APFloat(tanf(a.convertToFloat()));
705 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
706 switch (a.getSizeInBits(a.getSemantics())) {
708 return APFloat(tanh(a.convertToDouble()));
710 return APFloat(tanhf(a.convertToFloat()));
721OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
723 adaptor.getOperands(), [](
const APFloat &a) {
725 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
736 adaptor.getOperands(), [](
const APFloat &a) {
738 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
749 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
750 switch (a.getSizeInBits(a.getSemantics())) {
752 return APFloat(round(a.convertToDouble()));
754 return APFloat(roundf(a.convertToFloat()));
767 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
768 switch (a.getSizeInBits(a.getSemantics())) {
770 return APFloat(trunc(a.convertToDouble()));
772 return APFloat(truncf(a.convertToFloat()));
783 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
784 return ub::PoisonOp::create(builder, loc, type, poison);
786 return arith::ConstantOp::materialize(builder, value, type, loc);
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)