26 if (
auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
37 #define GET_OP_CLASSES
38 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
45 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46 [](
const APFloat &a) { return abs(a); });
54 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55 [](
const APInt &a) { return a.abs(); });
63 return constFoldUnaryOpConditional<FloatAttr>(
64 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
67 return APFloat(acos(a.convertToDouble()));
69 return APFloat(acosf(a.convertToFloat()));
81 return constFoldUnaryOpConditional<FloatAttr>(
82 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
85 return APFloat(acosh(a.convertToDouble()));
87 return APFloat(acoshf(a.convertToFloat()));
99 return constFoldUnaryOpConditional<FloatAttr>(
100 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
103 return APFloat(asin(a.convertToDouble()));
105 return APFloat(asinf(a.convertToFloat()));
117 return constFoldUnaryOpConditional<FloatAttr>(
118 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
121 return APFloat(asinh(a.convertToDouble()));
123 return APFloat(asinhf(a.convertToFloat()));
135 return constFoldUnaryOpConditional<FloatAttr>(
136 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
139 return APFloat(atan(a.convertToDouble()));
141 return APFloat(atanf(a.convertToFloat()));
153 return constFoldUnaryOpConditional<FloatAttr>(
154 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
157 return APFloat(atanh(a.convertToDouble()));
159 return APFloat(atanhf(a.convertToFloat()));
171 return constFoldBinaryOpConditional<FloatAttr>(
172 adaptor.getOperands(),
173 [](
const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
194 return constFoldUnaryOp<FloatAttr>(
195 adaptor.getOperands(), [](
const APFloat &a) {
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
206 OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208 [](
const APFloat &a,
const APFloat &b) {
220 return constFoldUnaryOpConditional<FloatAttr>(
221 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
224 return APFloat(cos(a.convertToDouble()));
226 return APFloat(cosf(a.convertToFloat()));
238 return constFoldUnaryOpConditional<FloatAttr>(
239 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
242 return APFloat(cosh(a.convertToDouble()));
244 return APFloat(coshf(a.convertToFloat()));
256 return constFoldUnaryOpConditional<FloatAttr>(
257 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
260 return APFloat(sin(a.convertToDouble()));
262 return APFloat(sinf(a.convertToFloat()));
274 return constFoldUnaryOpConditional<FloatAttr>(
275 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
278 return APFloat(sinh(a.convertToDouble()));
280 return APFloat(sinhf(a.convertToFloat()));
291 std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292 if (
auto vt = mlir::dyn_cast<VectorType>(getOperand().
getType()))
293 return llvm::to_vector<4>(vt.getShape());
301 OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
302 return constFoldUnaryOp<IntegerAttr>(
303 adaptor.getOperands(),
304 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
311 OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
312 return constFoldUnaryOp<IntegerAttr>(
313 adaptor.getOperands(),
314 [](
const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
322 return constFoldUnaryOp<IntegerAttr>(
323 adaptor.getOperands(),
324 [](
const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
332 return constFoldUnaryOpConditional<FloatAttr>(
333 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
334 switch (a.getSizeInBits(a.getSemantics())) {
336 return APFloat(erf(a.convertToDouble()));
338 return APFloat(erff(a.convertToFloat()));
350 return constFoldUnaryOpConditional<FloatAttr>(
351 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
352 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353 case APFloat::Semantics::S_IEEEdouble:
354 return APFloat(erfc(a.convertToDouble()));
355 case APFloat::Semantics::S_IEEEsingle:
356 return APFloat(erfcf(a.convertToFloat()));
368 return constFoldBinaryOpConditional<IntegerAttr>(
369 adaptor.getOperands(),
370 [](
const APInt &base,
const APInt &power) -> std::optional<APInt> {
371 unsigned width = base.getBitWidth();
372 auto zeroValue = APInt::getZero(width);
373 APInt oneValue{width, 1ULL, true};
374 APInt minusOneValue{width, -1ULL, true};
379 if (power.isNegative()) {
383 if (base.eq(oneValue))
386 if (base.ne(minusOneValue))
392 return minusOneValue;
398 APInt result = oneValue;
399 APInt curBase = base;
400 APInt curPower = power;
402 if (curPower[0] == 1)
404 curPower.lshrInPlace(1);
405 if (curPower.isZero())
419 return constFoldUnaryOpConditional<FloatAttr>(
420 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
424 if (a.getSizeInBits(a.getSemantics()) == 64)
425 return APFloat(log(a.convertToDouble()));
427 if (a.getSizeInBits(a.getSemantics()) == 32)
428 return APFloat(logf(a.convertToFloat()));
439 return constFoldUnaryOpConditional<FloatAttr>(
440 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
444 if (a.getSizeInBits(a.getSemantics()) == 64)
445 return APFloat(log2(a.convertToDouble()));
447 if (a.getSizeInBits(a.getSemantics()) == 32)
448 return APFloat(log2f(a.convertToFloat()));
459 return constFoldUnaryOpConditional<FloatAttr>(
460 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
464 switch (a.getSizeInBits(a.getSemantics())) {
466 return APFloat(log10(a.convertToDouble()));
468 return APFloat(log10f(a.convertToFloat()));
480 return constFoldUnaryOpConditional<FloatAttr>(
481 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
482 switch (a.getSizeInBits(a.getSemantics())) {
484 if ((a + APFloat(1.0)).isNegative())
486 return APFloat(log1p(a.convertToDouble()));
488 if ((a + APFloat(1.0f)).isNegative())
490 return APFloat(log1pf(a.convertToFloat()));
502 return constFoldBinaryOpConditional<FloatAttr>(
503 adaptor.getOperands(),
504 [](
const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
505 if (a.getSizeInBits(a.getSemantics()) == 64 &&
506 b.getSizeInBits(b.getSemantics()) == 64)
507 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
509 if (a.getSizeInBits(a.getSemantics()) == 32 &&
510 b.getSizeInBits(b.getSemantics()) == 32)
511 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
522 return constFoldUnaryOpConditional<FloatAttr>(
523 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
527 switch (a.getSizeInBits(a.getSemantics())) {
529 return APFloat(sqrt(a.convertToDouble()));
531 return APFloat(sqrtf(a.convertToFloat()));
543 return constFoldUnaryOpConditional<FloatAttr>(
544 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
545 switch (a.getSizeInBits(a.getSemantics())) {
547 return APFloat(exp(a.convertToDouble()));
549 return APFloat(expf(a.convertToFloat()));
561 return constFoldUnaryOpConditional<FloatAttr>(
562 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
563 switch (a.getSizeInBits(a.getSemantics())) {
565 return APFloat(exp2(a.convertToDouble()));
567 return APFloat(exp2f(a.convertToFloat()));
579 return constFoldUnaryOpConditional<FloatAttr>(
580 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
581 switch (a.getSizeInBits(a.getSemantics())) {
583 return APFloat(expm1(a.convertToDouble()));
585 return APFloat(expm1f(a.convertToFloat()));
596 OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
597 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
598 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
600 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
603 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
613 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
614 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
616 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
619 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
629 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
630 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
632 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
635 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
644 OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
645 if (
auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
646 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
648 if (
auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
651 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
661 return constFoldUnaryOpConditional<FloatAttr>(
662 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
663 switch (a.getSizeInBits(a.getSemantics())) {
665 return APFloat(tan(a.convertToDouble()));
667 return APFloat(tanf(a.convertToFloat()));
679 return constFoldUnaryOpConditional<FloatAttr>(
680 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
681 switch (a.getSizeInBits(a.getSemantics())) {
683 return APFloat(tanh(a.convertToDouble()));
685 return APFloat(tanhf(a.convertToFloat()));
696 OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
697 return constFoldUnaryOp<FloatAttr>(
698 adaptor.getOperands(), [](
const APFloat &a) {
700 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
710 return constFoldUnaryOp<FloatAttr>(
711 adaptor.getOperands(), [](
const APFloat &a) {
713 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
723 return constFoldUnaryOpConditional<FloatAttr>(
724 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
725 switch (a.getSizeInBits(a.getSemantics())) {
727 return APFloat(round(a.convertToDouble()));
729 return APFloat(roundf(a.convertToFloat()));
741 return constFoldUnaryOpConditional<FloatAttr>(
742 adaptor.getOperands(), [](
const APFloat &a) -> std::optional<APFloat> {
743 switch (a.getSizeInBits(a.getSemantics())) {
745 return APFloat(trunc(a.convertToDouble()));
747 return APFloat(truncf(a.convertToFloat()));
758 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
759 return ub::PoisonOp::create(builder, loc, type, poison);
761 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...