25 auto i1Type = IntegerType::get(type.
getContext(), 1);
26 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
46 [](
const APFloat &a) { return abs(a); });
55 [](
const APInt &a) { return a.abs(); });
64 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
67 return APFloat(acos(a.convertToDouble()));
69 return APFloat(acosf(a.convertToFloat()));
82 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
85 return APFloat(acosh(a.convertToDouble()));
87 return APFloat(acoshf(a.convertToFloat()));
100 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
103 return APFloat(asin(a.convertToDouble()));
105 return APFloat(asinf(a.convertToFloat()));
118 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
121 return APFloat(asinh(a.convertToDouble()));
123 return APFloat(asinhf(a.convertToFloat()));
136 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
139 return APFloat(atan(a.convertToDouble()));
141 return APFloat(atanf(a.convertToFloat()));
154 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
157 return APFloat(atanh(a.convertToDouble()));
159 return APFloat(atanhf(a.convertToFloat()));
172 adaptor.getOperands(),
173 [](
const APFloat &a,
const APFloat &
b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
195 adaptor.getOperands(), [](
const APFloat &a) {
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
206OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
208 [](
const APFloat &a,
const APFloat &
b) {
221 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
224 return APFloat(cos(a.convertToDouble()));
226 return APFloat(cosf(a.convertToFloat()));
239 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
242 return APFloat(cosh(a.convertToDouble()));
244 return APFloat(coshf(a.convertToFloat()));
257 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
260 return APFloat(sin(a.convertToDouble()));
262 return APFloat(sinf(a.convertToFloat()));
275 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
278 return APFloat(sinh(a.convertToDouble()));
280 return APFloat(sinhf(a.convertToFloat()));
291std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292 if (
auto vt = mlir::dyn_cast<VectorType>(getOperand().
getType()))
293 return llvm::to_vector<4>(vt.getShape());
301OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
303 adaptor.getOperands(),
304 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
311OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
313 adaptor.getOperands(),
314 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
323 adaptor.getOperands(),
324 [](
const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
333 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
334 switch (a.getSizeInBits(a.getSemantics())) {
336 return APFloat(erf(a.convertToDouble()));
338 return APFloat(erff(a.convertToFloat()));
351 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
352 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353 case APFloat::Semantics::S_IEEEdouble:
354 return APFloat(erfc(a.convertToDouble()));
355 case APFloat::Semantics::S_IEEEsingle:
356 return APFloat(erfcf(a.convertToFloat()));
369 adaptor.getOperands(),
370 [](
const APInt &base,
const APInt &power) -> std::optional<APInt> {
371 unsigned width = base.getBitWidth();
372 auto zeroValue = APInt::getZero(width);
373 APInt oneValue{width, 1ULL, true};
374 APInt minusOneValue{width, -1ULL, true};
379 if (power.isNegative()) {
383 if (base.eq(oneValue))
386 if (base.ne(minusOneValue))
392 return minusOneValue;
399 APInt curBase = base;
400 APInt curPower = power;
402 if (curPower[0] == 1)
404 curPower.lshrInPlace(1);
405 if (curPower.isZero())
420 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
424 if (a.getSizeInBits(a.getSemantics()) == 64)
425 return APFloat(log(a.convertToDouble()));
427 if (a.getSizeInBits(a.getSemantics()) == 32)
428 return APFloat(logf(a.convertToFloat()));
440 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
444 if (a.getSizeInBits(a.getSemantics()) == 64)
445 return APFloat(log2(a.convertToDouble()));
447 if (a.getSizeInBits(a.getSemantics()) == 32)
448 return APFloat(log2f(a.convertToFloat()));
460 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
464 switch (a.getSizeInBits(a.getSemantics())) {
466 return APFloat(log10(a.convertToDouble()));
468 return APFloat(log10f(a.convertToFloat()));
481 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
482 switch (a.getSizeInBits(a.getSemantics())) {
484 if ((a + APFloat(1.0)).isNegative())
486 return APFloat(log1p(a.convertToDouble()));
488 if ((a + APFloat(1.0f)).isNegative())
490 return APFloat(log1pf(a.convertToFloat()));
503 adaptor.getOperands(),
504 [](
const APFloat &a,
const APFloat &
b) -> std::optional<APFloat> {
505 if (a.getSizeInBits(a.getSemantics()) == 64 &&
506 b.getSizeInBits(b.getSemantics()) == 64)
507 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
509 if (a.getSizeInBits(a.getSemantics()) == 32 &&
510 b.getSizeInBits(b.getSemantics()) == 32)
511 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
523 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
527 switch (a.getSizeInBits(a.getSemantics())) {
529 return APFloat(sqrt(a.convertToDouble()));
531 return APFloat(sqrtf(a.convertToFloat()));
544 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
545 switch (a.getSizeInBits(a.getSemantics())) {
547 return APFloat(exp(a.convertToDouble()));
549 return APFloat(expf(a.convertToFloat()));
562 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
563 switch (a.getSizeInBits(a.getSemantics())) {
565 return APFloat(exp2(a.convertToDouble()));
567 return APFloat(exp2f(a.convertToFloat()));
580 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
581 switch (a.getSizeInBits(a.getSemantics())) {
583 return APFloat(expm1(a.convertToDouble()));
585 return APFloat(expm1f(a.convertToFloat()));
596OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
597 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
598 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
600 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
603 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
613 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
614 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
616 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
619 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
629 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
630 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
632 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
635 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
644OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
645 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
646 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
648 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
651 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
662 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
663 switch (a.getSizeInBits(a.getSemantics())) {
665 return APFloat(tan(a.convertToDouble()));
667 return APFloat(tanf(a.convertToFloat()));
680 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
681 switch (a.getSizeInBits(a.getSemantics())) {
683 return APFloat(tanh(a.convertToDouble()));
685 return APFloat(tanhf(a.convertToFloat()));
696OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
698 adaptor.getOperands(), [](
const APFloat &a) {
700 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
711 adaptor.getOperands(), [](
const APFloat &a) {
713 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
724 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
725 switch (a.getSizeInBits(a.getSemantics())) {
727 return APFloat(round(a.convertToDouble()));
729 return APFloat(roundf(a.convertToFloat()));
742 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
743 switch (a.getSizeInBits(a.getSemantics())) {
745 return APFloat(trunc(a.convertToDouble()));
747 return APFloat(truncf(a.convertToFloat()));
758 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
759 return ub::PoisonOp::create(builder, loc, type, poison);
761 return arith::ConstantOp::materialize(builder, value, type, loc);
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)