26 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
37 #define GET_OP_CLASSES
38 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
45 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46 [](
const APFloat &a) { return abs(a); });
54 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55 [](
const APInt &a) { return a.abs(); });
63 return constFoldUnaryOpConditional<FloatAttr>(
64 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
67 return APFloat(acos(a.convertToDouble()));
69 return APFloat(acosf(a.convertToFloat()));
81 return constFoldUnaryOpConditional<FloatAttr>(
82 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
85 return APFloat(acosh(a.convertToDouble()));
87 return APFloat(acoshf(a.convertToFloat()));
99 return constFoldUnaryOpConditional<FloatAttr>(
100 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
103 return APFloat(asin(a.convertToDouble()));
105 return APFloat(asinf(a.convertToFloat()));
117 return constFoldUnaryOpConditional<FloatAttr>(
118 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
121 return APFloat(asinh(a.convertToDouble()));
123 return APFloat(asinhf(a.convertToFloat()));
135 return constFoldUnaryOpConditional<FloatAttr>(
136 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
139 return APFloat(atan(a.convertToDouble()));
141 return APFloat(atanf(a.convertToFloat()));
153 return constFoldUnaryOpConditional<FloatAttr>(
154 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
157 return APFloat(atanh(a.convertToDouble()));
159 return APFloat(atanhf(a.convertToFloat()));
171 return constFoldBinaryOpConditional<FloatAttr>(
172 adaptor.getOperands(),
173 [](
const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
194 return constFoldUnaryOp<FloatAttr>(
195 adaptor.getOperands(), [](
const APFloat &a) {
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
206 OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208 [](
const APFloat &a,
const APFloat &b) {
220 return constFoldUnaryOpConditional<FloatAttr>(
221 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
224 return APFloat(cos(a.convertToDouble()));
226 return APFloat(cosf(a.convertToFloat()));
238 return constFoldUnaryOpConditional<FloatAttr>(
239 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
242 return APFloat(cosh(a.convertToDouble()));
244 return APFloat(coshf(a.convertToFloat()));
256 return constFoldUnaryOpConditional<FloatAttr>(
257 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
260 return APFloat(sin(a.convertToDouble()));
262 return APFloat(sinf(a.convertToFloat()));
274 return constFoldUnaryOpConditional<FloatAttr>(
275 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
278 return APFloat(sinh(a.convertToDouble()));
280 return APFloat(sinhf(a.convertToFloat()));
291 OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
292 return constFoldUnaryOp<IntegerAttr>(
293 adaptor.getOperands(),
294 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
301 OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
302 return constFoldUnaryOp<IntegerAttr>(
303 adaptor.getOperands(),
304 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
312 return constFoldUnaryOp<IntegerAttr>(
313 adaptor.getOperands(),
314 [](
const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
322 return constFoldUnaryOpConditional<FloatAttr>(
323 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
324 switch (a.getSizeInBits(a.getSemantics())) {
326 return APFloat(erf(a.convertToDouble()));
328 return APFloat(erff(a.convertToFloat()));
340 return constFoldUnaryOpConditional<FloatAttr>(
341 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
342 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
343 case APFloat::Semantics::S_IEEEdouble:
344 return APFloat(erfc(a.convertToDouble()));
345 case APFloat::Semantics::S_IEEEsingle:
346 return APFloat(erfcf(a.convertToFloat()));
358 return constFoldBinaryOpConditional<IntegerAttr>(
359 adaptor.getOperands(),
360 [](
const APInt &base,
const APInt &power) -> std::optional<APInt> {
361 unsigned width = base.getBitWidth();
362 auto zeroValue = APInt::getZero(width);
363 APInt oneValue{width, 1ULL, true};
364 APInt minusOneValue{width, -1ULL, true};
369 if (power.isNegative()) {
373 if (base.eq(oneValue))
376 if (base.ne(minusOneValue))
382 return minusOneValue;
388 APInt result = oneValue;
389 APInt curBase = base;
390 APInt curPower = power;
392 if (curPower[0] == 1)
394 curPower.lshrInPlace(1);
395 if (curPower.isZero())
409 return constFoldUnaryOpConditional<FloatAttr>(
410 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
414 if (a.getSizeInBits(a.getSemantics()) == 64)
415 return APFloat(log(a.convertToDouble()));
417 if (a.getSizeInBits(a.getSemantics()) == 32)
418 return APFloat(logf(a.convertToFloat()));
429 return constFoldUnaryOpConditional<FloatAttr>(
430 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
434 if (a.getSizeInBits(a.getSemantics()) == 64)
435 return APFloat(log2(a.convertToDouble()));
437 if (a.getSizeInBits(a.getSemantics()) == 32)
438 return APFloat(log2f(a.convertToFloat()));
449 return constFoldUnaryOpConditional<FloatAttr>(
450 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
454 switch (a.getSizeInBits(a.getSemantics())) {
456 return APFloat(log10(a.convertToDouble()));
458 return APFloat(log10f(a.convertToFloat()));
470 return constFoldUnaryOpConditional<FloatAttr>(
471 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
472 switch (a.getSizeInBits(a.getSemantics())) {
474 if ((a + APFloat(1.0)).isNegative())
476 return APFloat(log1p(a.convertToDouble()));
478 if ((a + APFloat(1.0f)).isNegative())
480 return APFloat(log1pf(a.convertToFloat()));
492 return constFoldBinaryOpConditional<FloatAttr>(
493 adaptor.getOperands(),
494 [](
const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
495 if (a.getSizeInBits(a.getSemantics()) == 64 &&
496 b.getSizeInBits(b.getSemantics()) == 64)
497 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
499 if (a.getSizeInBits(a.getSemantics()) == 32 &&
500 b.getSizeInBits(b.getSemantics()) == 32)
501 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
512 return constFoldUnaryOpConditional<FloatAttr>(
513 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
517 switch (a.getSizeInBits(a.getSemantics())) {
519 return APFloat(sqrt(a.convertToDouble()));
521 return APFloat(sqrtf(a.convertToFloat()));
533 return constFoldUnaryOpConditional<FloatAttr>(
534 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
535 switch (a.getSizeInBits(a.getSemantics())) {
537 return APFloat(exp(a.convertToDouble()));
539 return APFloat(expf(a.convertToFloat()));
551 return constFoldUnaryOpConditional<FloatAttr>(
552 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
553 switch (a.getSizeInBits(a.getSemantics())) {
555 return APFloat(exp2(a.convertToDouble()));
557 return APFloat(exp2f(a.convertToFloat()));
569 return constFoldUnaryOpConditional<FloatAttr>(
570 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
571 switch (a.getSizeInBits(a.getSemantics())) {
573 return APFloat(expm1(a.convertToDouble()));
575 return APFloat(expm1f(a.convertToFloat()));
586 OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
587 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
588 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
590 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
593 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
603 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
604 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
606 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
609 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
619 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
620 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
622 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
625 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
634 OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
635 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
636 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
638 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
641 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
651 return constFoldUnaryOpConditional<FloatAttr>(
652 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
653 switch (a.getSizeInBits(a.getSemantics())) {
655 return APFloat(tan(a.convertToDouble()));
657 return APFloat(tanf(a.convertToFloat()));
669 return constFoldUnaryOpConditional<FloatAttr>(
670 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
671 switch (a.getSizeInBits(a.getSemantics())) {
673 return APFloat(tanh(a.convertToDouble()));
675 return APFloat(tanhf(a.convertToFloat()));
686 OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
687 return constFoldUnaryOp<FloatAttr>(
688 adaptor.getOperands(), [](
const APFloat &a) {
690 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
700 return constFoldUnaryOp<FloatAttr>(
701 adaptor.getOperands(), [](
const APFloat &a) {
703 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
713 return constFoldUnaryOpConditional<FloatAttr>(
714 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
715 switch (a.getSizeInBits(a.getSemantics())) {
717 return APFloat(round(a.convertToDouble()));
719 return APFloat(roundf(a.convertToFloat()));
731 return constFoldUnaryOpConditional<FloatAttr>(
732 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
733 switch (a.getSizeInBits(a.getSemantics())) {
735 return APFloat(trunc(a.convertToDouble()));
737 return APFloat(truncf(a.convertToFloat()));
748 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
749 return builder.
create<ub::PoisonOp>(loc, type, poison);
751 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...