25 auto i1Type = IntegerType::get(type.
getContext(), 1);
26 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
46 [](
const APFloat &a) { return abs(a); });
55 [](
const APInt &a) { return a.abs(); });
64 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
67 return APFloat(acos(a.convertToDouble()));
69 return APFloat(acosf(a.convertToFloat()));
82 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
85 return APFloat(acosh(a.convertToDouble()));
87 return APFloat(acoshf(a.convertToFloat()));
100 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
103 return APFloat(asin(a.convertToDouble()));
105 return APFloat(asinf(a.convertToFloat()));
118 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
121 return APFloat(asinh(a.convertToDouble()));
123 return APFloat(asinhf(a.convertToFloat()));
136 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
139 return APFloat(atan(a.convertToDouble()));
141 return APFloat(atanf(a.convertToFloat()));
154 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
157 return APFloat(atanh(a.convertToDouble()));
159 return APFloat(atanhf(a.convertToFloat()));
172 adaptor.getOperands(),
173 [](
const APFloat &a,
const APFloat &
b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
195 adaptor.getOperands(), [](
const APFloat &a) {
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
206OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
208 [](
const APFloat &a,
const APFloat &
b) {
221 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
224 return APFloat(cos(a.convertToDouble()));
226 return APFloat(cosf(a.convertToFloat()));
239 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
242 return APFloat(cosh(a.convertToDouble()));
244 return APFloat(coshf(a.convertToFloat()));
257 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
260 return APFloat(sin(a.convertToDouble()));
262 return APFloat(sinf(a.convertToFloat()));
275 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
278 return APFloat(sinh(a.convertToDouble()));
280 return APFloat(sinhf(a.convertToFloat()));
291std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292 if (
auto vt = mlir::dyn_cast<VectorType>(getOperand().
getType()))
293 return llvm::to_vector<4>(vt.getShape());
301OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
303 adaptor.getOperands(),
304 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
311OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
313 adaptor.getOperands(),
314 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
323 adaptor.getOperands(),
324 [](
const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
333 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
334 switch (a.getSizeInBits(a.getSemantics())) {
336 return APFloat(erf(a.convertToDouble()));
338 return APFloat(erff(a.convertToFloat()));
351 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
352 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353 case APFloat::Semantics::S_IEEEdouble:
354 return APFloat(erfc(a.convertToDouble()));
355 case APFloat::Semantics::S_IEEEsingle:
356 return APFloat(erfcf(a.convertToFloat()));
369 adaptor.getOperands(),
370 [](
const APInt &base,
const APInt &power) -> std::optional<APInt> {
371 unsigned width = base.getBitWidth();
372 auto zeroValue = APInt::getZero(width);
376 APInt oneValue{width, 1ULL, true};
377 APInt minusOneValue{width, -1ULL, true};
382 if (power.isNegative()) {
389 if (base.ne(minusOneValue))
395 return minusOneValue;
402 APInt curBase = base;
403 APInt curPower = power;
405 if (curPower[0] == 1)
407 curPower.lshrInPlace(1);
408 if (curPower.isZero())
423 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
427 if (a.getSizeInBits(a.getSemantics()) == 64)
428 return APFloat(log(a.convertToDouble()));
430 if (a.getSizeInBits(a.getSemantics()) == 32)
431 return APFloat(logf(a.convertToFloat()));
443 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
447 if (a.getSizeInBits(a.getSemantics()) == 64)
448 return APFloat(log2(a.convertToDouble()));
450 if (a.getSizeInBits(a.getSemantics()) == 32)
451 return APFloat(log2f(a.convertToFloat()));
463 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
467 switch (a.getSizeInBits(a.getSemantics())) {
469 return APFloat(log10(a.convertToDouble()));
471 return APFloat(log10f(a.convertToFloat()));
484 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
485 switch (a.getSizeInBits(a.getSemantics())) {
487 if ((a + APFloat(1.0)).isNegative())
489 return APFloat(log1p(a.convertToDouble()));
491 if ((a + APFloat(1.0f)).isNegative())
493 return APFloat(log1pf(a.convertToFloat()));
506 adaptor.getOperands(),
507 [](
const APFloat &a,
const APFloat &
b) -> std::optional<APFloat> {
508 if (a.getSizeInBits(a.getSemantics()) == 64 &&
509 b.getSizeInBits(b.getSemantics()) == 64)
510 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
512 if (a.getSizeInBits(a.getSemantics()) == 32 &&
513 b.getSizeInBits(b.getSemantics()) == 32)
514 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
526 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
530 switch (a.getSizeInBits(a.getSemantics())) {
532 return APFloat(sqrt(a.convertToDouble()));
534 return APFloat(sqrtf(a.convertToFloat()));
547 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
548 switch (a.getSizeInBits(a.getSemantics())) {
550 return APFloat(exp(a.convertToDouble()));
552 return APFloat(expf(a.convertToFloat()));
565 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
566 switch (a.getSizeInBits(a.getSemantics())) {
568 return APFloat(exp2(a.convertToDouble()));
570 return APFloat(exp2f(a.convertToFloat()));
583 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
584 switch (a.getSizeInBits(a.getSemantics())) {
586 return APFloat(expm1(a.convertToDouble()));
588 return APFloat(expm1f(a.convertToFloat()));
599OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
600 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
601 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
603 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
606 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
616 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
617 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
619 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
622 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
632 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
633 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
635 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
638 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
647OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
648 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
649 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
651 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
654 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
665 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
666 switch (a.getSizeInBits(a.getSemantics())) {
668 return APFloat(tan(a.convertToDouble()));
670 return APFloat(tanf(a.convertToFloat()));
683 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
684 switch (a.getSizeInBits(a.getSemantics())) {
686 return APFloat(tanh(a.convertToDouble()));
688 return APFloat(tanhf(a.convertToFloat()));
699OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
701 adaptor.getOperands(), [](
const APFloat &a) {
703 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
714 adaptor.getOperands(), [](
const APFloat &a) {
716 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
727 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
728 switch (a.getSizeInBits(a.getSemantics())) {
730 return APFloat(round(a.convertToDouble()));
732 return APFloat(roundf(a.convertToFloat()));
745 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
746 switch (a.getSizeInBits(a.getSemantics())) {
748 return APFloat(trunc(a.convertToDouble()));
750 return APFloat(truncf(a.convertToFloat()));
761 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
762 return ub::PoisonOp::create(builder, loc, type, poison);
764 return arith::ConstantOp::materialize(builder, value, type, loc);
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)