26#include "llvm/ADT/APFloat.h" 
   27#include "llvm/ADT/APInt.h" 
   28#include "llvm/ADT/APSInt.h" 
   29#include "llvm/ADT/FloatingPointMode.h" 
   30#include "llvm/ADT/STLExtras.h" 
   31#include "llvm/ADT/SmallVector.h" 
   32#include "llvm/ADT/TypeSwitch.h" 
   44                    function_ref<APInt(
const APInt &, 
const APInt &)> binFn) {
 
   45  APInt lhsVal = llvm::cast<IntegerAttr>(
lhs).getValue();
 
   46  APInt rhsVal = llvm::cast<IntegerAttr>(
rhs).getValue();
 
   47  APInt value = binFn(lhsVal, rhsVal);
 
   48  return IntegerAttr::get(res.
getType(), value);
 
 
   67static IntegerOverflowFlagsAttr
 
   69                   IntegerOverflowFlagsAttr val2) {
 
   70  return IntegerOverflowFlagsAttr::get(val1.getContext(),
 
   71                                       val1.getValue() & val2.getValue());
 
 
   77  case arith::CmpIPredicate::eq:
 
   78    return arith::CmpIPredicate::ne;
 
   79  case arith::CmpIPredicate::ne:
 
   80    return arith::CmpIPredicate::eq;
 
   81  case arith::CmpIPredicate::slt:
 
   82    return arith::CmpIPredicate::sge;
 
   83  case arith::CmpIPredicate::sle:
 
   84    return arith::CmpIPredicate::sgt;
 
   85  case arith::CmpIPredicate::sgt:
 
   86    return arith::CmpIPredicate::sle;
 
   87  case arith::CmpIPredicate::sge:
 
   88    return arith::CmpIPredicate::slt;
 
   89  case arith::CmpIPredicate::ult:
 
   90    return arith::CmpIPredicate::uge;
 
   91  case arith::CmpIPredicate::ule:
 
   92    return arith::CmpIPredicate::ugt;
 
   93  case arith::CmpIPredicate::ugt:
 
   94    return arith::CmpIPredicate::ule;
 
   95  case arith::CmpIPredicate::uge:
 
   96    return arith::CmpIPredicate::ult;
 
   98  llvm_unreachable(
"unknown cmpi predicate kind");
 
 
  107static llvm::RoundingMode
 
  109  switch (roundingMode) {
 
  110  case RoundingMode::downward:
 
  111    return llvm::RoundingMode::TowardNegative;
 
  112  case RoundingMode::to_nearest_away:
 
  113    return llvm::RoundingMode::NearestTiesToAway;
 
  114  case RoundingMode::to_nearest_even:
 
  115    return llvm::RoundingMode::NearestTiesToEven;
 
  116  case RoundingMode::toward_zero:
 
  117    return llvm::RoundingMode::TowardZero;
 
  118  case RoundingMode::upward:
 
  119    return llvm::RoundingMode::TowardPositive;
 
  121  llvm_unreachable(
"Unhandled rounding mode");
 
 
  125  return arith::CmpIPredicateAttr::get(pred.getContext(),
 
 
  151  ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
 
 
  162#include "ArithCanonicalization.inc" 
  171  auto i1Type = IntegerType::get(type.
getContext(), 1);
 
  172  if (
auto shapedType = dyn_cast<ShapedType>(type))
 
  173    return shapedType.cloneWith(std::nullopt, i1Type);
 
  174  if (llvm::isa<UnrankedTensorType>(type))
 
  175    return UnrankedTensorType::get(i1Type);
 
 
  183void arith::ConstantOp::getAsmResultNames(
 
  186  if (
auto intCst = dyn_cast<IntegerAttr>(getValue())) {
 
  187    auto intType = dyn_cast<IntegerType>(type);
 
  190    if (intType && intType.getWidth() == 1)
 
  191      return setNameFn(getResult(), (intCst.getInt() ? 
"true" : 
"false"));
 
  194    SmallString<32> specialNameBuffer;
 
  195    llvm::raw_svector_ostream specialName(specialNameBuffer);
 
  196    specialName << 
'c' << intCst.getValue();
 
  198      specialName << 
'_' << type;
 
  199    setNameFn(getResult(), specialName.str());
 
  201    setNameFn(getResult(), 
"cst");
 
  207LogicalResult arith::ConstantOp::verify() {
 
  210  if (llvm::isa<IntegerType>(type) &&
 
  211      !llvm::cast<IntegerType>(type).isSignless())
 
  212    return emitOpError(
"integer return type must be signless");
 
  214  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
 
  216        "value must be an integer, float, or elements attribute");
 
  222  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
 
  224        "intializing scalable vectors with elements attribute is not supported" 
  225        " unless it's a vector splat");
 
  229bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
 
  231  auto typedAttr = dyn_cast<TypedAttr>(value);
 
  232  if (!typedAttr || typedAttr.getType() != type)
 
  235  if (llvm::isa<IntegerType>(type) &&
 
  236      !llvm::cast<IntegerType>(type).isSignless())
 
  239  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
 
  242ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
 
  243                                          Type type, Location loc) {
 
  244  if (isBuildableWith(value, type))
 
  245    return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
 
  249OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { 
return getValue(); }
 
  254  arith::ConstantOp::build(builder, 
result, type,
 
 
  264  auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
 
  265  assert(
result && 
"builder didn't return the right type");
 
 
  277  arith::ConstantOp::build(builder, 
result, type,
 
 
  286  auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
 
  287  assert(
result && 
"builder didn't return the right type");
 
 
  298  arith::ConstantOp::build(builder, 
result, type,
 
 
  304                                                  const APInt &
value) {
 
  307  auto result = dyn_cast<ConstantIntOp>(builder.
create(state));
 
  308  assert(
result && 
"builder didn't return the right type");
 
 
  314                                                  const APInt &
value) {
 
 
  319  if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
 
  320    return constOp.getType().isSignlessInteger();
 
 
  325                                   FloatType type, 
const APFloat &
value) {
 
  326  arith::ConstantOp::build(builder, 
result, type,
 
 
  333                                                      const APFloat &
value) {
 
  336  auto result = dyn_cast<ConstantFloatOp>(builder.
create(state));
 
  337  assert(
result && 
"builder didn't return the right type");
 
 
  343                               const APFloat &
value) {
 
 
  348  if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
 
  349    return llvm::isa<FloatType>(constOp.getType());
 
 
  364  auto result = dyn_cast<ConstantIndexOp>(builder.
create(state));
 
  365  assert(
result && 
"builder didn't return the right type");
 
 
  375  if (
auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
 
  376    return constOp.getType().isIndex();
 
 
  384         "type doesn't have a zero representation");
 
  386  assert(zeroAttr && 
"unsupported type for zero attribute");
 
  387  return arith::ConstantOp::create(builder, loc, zeroAttr);
 
 
  400  if (
auto sub = getLhs().getDefiningOp<SubIOp>())
 
  401    if (getRhs() == sub.getRhs())
 
  405  if (
auto sub = getRhs().getDefiningOp<SubIOp>())
 
  406    if (getLhs() == sub.getRhs())
 
  410      adaptor.getOperands(),
 
  411      [](APInt a, 
const APInt &
b) { return std::move(a) + b; });
 
  416  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
 
  417               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
 
  424std::optional<SmallVector<int64_t, 4>>
 
  425arith::AddUIExtendedOp::getShapeForUnroll() {
 
  426  if (
auto vt = dyn_cast<VectorType>(
getType(0)))
 
  427    return llvm::to_vector<4>(vt.getShape());
 
  434  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
 
 
  438arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
 
  439                             SmallVectorImpl<OpFoldResult> &results) {
 
  440  Type overflowTy = getOverflow().getType();
 
  446    results.push_back(getLhs());
 
  447    results.push_back(falseValue);
 
  456          adaptor.getOperands(),
 
  457          [](APInt a, 
const APInt &
b) { return std::move(a) + b; })) {
 
  459        ArrayRef({sumAttr, adaptor.getLhs()}),
 
  465    results.push_back(sumAttr);
 
  466    results.push_back(overflowAttr);
 
  473void arith::AddUIExtendedOp::getCanonicalizationPatterns(
 
  474    RewritePatternSet &
patterns, MLIRContext *context) {
 
  475  patterns.add<AddUIExtendedToAddI>(context);
 
  482OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
 
  484  if (getOperand(0) == getOperand(1)) {
 
  485    auto shapedType = dyn_cast<ShapedType>(
getType());
 
  487    if (!shapedType || shapedType.hasStaticShape())
 
  494  if (
auto add = getLhs().getDefiningOp<AddIOp>()) {
 
  496    if (getRhs() == 
add.getRhs())
 
  499    if (getRhs() == 
add.getLhs())
 
  504      adaptor.getOperands(),
 
  505      [](APInt a, 
const APInt &
b) { return std::move(a) - b; });
 
  508void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
  509                                                MLIRContext *context) {
 
  510  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
 
  511               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
 
  512               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
 
  519OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
 
  530      adaptor.getOperands(),
 
  531      [](
const APInt &a, 
const APInt &
b) { return a * b; });
 
  534void arith::MulIOp::getAsmResultNames(
 
  536  if (!isa<IndexType>(
getType()))
 
  541  auto isVscale = [](Operation *op) {
 
  542    return op && op->getName().getStringRef() == 
"vector.vscale";
 
  545  IntegerAttr baseValue;
 
  546  auto isVscaleExpr = [&](Value a, Value 
b) {
 
  548           isVscale(
b.getDefiningOp());
 
  551  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
 
  555  SmallString<32> specialNameBuffer;
 
  556  llvm::raw_svector_ostream specialName(specialNameBuffer);
 
  557  specialName << 
'c' << baseValue.getInt() << 
"_vscale";
 
  558  setNameFn(getResult(), specialName.str());
 
  561void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
  562                                                MLIRContext *context) {
 
  563  patterns.add<MulIMulIConstant>(context);
 
  570std::optional<SmallVector<int64_t, 4>>
 
  571arith::MulSIExtendedOp::getShapeForUnroll() {
 
  572  if (
auto vt = dyn_cast<VectorType>(
getType(0)))
 
  573    return llvm::to_vector<4>(vt.getShape());
 
  578arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
 
  579                             SmallVectorImpl<OpFoldResult> &results) {
 
  582    Attribute zero = adaptor.getRhs();
 
  583    results.push_back(zero);
 
  584    results.push_back(zero);
 
  590          adaptor.getOperands(),
 
  591          [](
const APInt &a, 
const APInt &
b) { return a * b; })) {
 
  594        adaptor.getOperands(), [](
const APInt &a, 
const APInt &
b) {
 
  595          return llvm::APIntOps::mulhs(a, b);
 
  597    assert(highAttr && 
"Unexpected constant-folding failure");
 
  599    results.push_back(lowAttr);
 
  600    results.push_back(highAttr);
 
  607void arith::MulSIExtendedOp::getCanonicalizationPatterns(
 
  608    RewritePatternSet &
patterns, MLIRContext *context) {
 
  609  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
 
  616std::optional<SmallVector<int64_t, 4>>
 
  617arith::MulUIExtendedOp::getShapeForUnroll() {
 
  618  if (
auto vt = dyn_cast<VectorType>(
getType(0)))
 
  619    return llvm::to_vector<4>(vt.getShape());
 
  624arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
 
  625                             SmallVectorImpl<OpFoldResult> &results) {
 
  628    Attribute zero = adaptor.getRhs();
 
  629    results.push_back(zero);
 
  630    results.push_back(zero);
 
  638    results.push_back(getLhs());
 
  639    results.push_back(zero);
 
  645          adaptor.getOperands(),
 
  646          [](
const APInt &a, 
const APInt &
b) { return a * b; })) {
 
  649        adaptor.getOperands(), [](
const APInt &a, 
const APInt &
b) {
 
  650          return llvm::APIntOps::mulhu(a, b);
 
  652    assert(highAttr && 
"Unexpected constant-folding failure");
 
  654    results.push_back(lowAttr);
 
  655    results.push_back(highAttr);
 
  662void arith::MulUIExtendedOp::getCanonicalizationPatterns(
 
  663    RewritePatternSet &
patterns, MLIRContext *context) {
 
  664  patterns.add<MulUIExtendedToMulI>(context);
 
  673                        arith::IntegerOverflowFlags ovfFlags) {
 
  674  auto mul = 
lhs.getDefiningOp<mlir::arith::MulIOp>();
 
  675  if (!
mul || !bitEnumContainsAll(
mul.getOverflowFlags(), ovfFlags))
 
 
  687OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
 
  693  if (Value val = 
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
 
  699                                               [&](APInt a, 
const APInt &
b) {
 
  707  return div0 ? Attribute() : 
result;
 
  727OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
 
  733  if (Value val = 
foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
 
  737  bool overflowOrDiv0 = 
false;
 
  739      adaptor.getOperands(), [&](APInt a, 
const APInt &
b) {
 
  740        if (overflowOrDiv0 || !b) {
 
  741          overflowOrDiv0 = true;
 
  744        return a.sdiv_ov(
b, overflowOrDiv0);
 
  747  return overflowOrDiv0 ? Attribute() : 
result;
 
  774  APInt one(a.getBitWidth(), 1, 
true); 
 
  775  APInt val = a.ssub_ov(one, overflow).sdiv_ov(
b, overflow);
 
  776  return val.sadd_ov(one, overflow);
 
 
  783OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
 
  788  bool overflowOrDiv0 = 
false;
 
  790      adaptor.getOperands(), [&](APInt a, 
const APInt &
b) {
 
  791        if (overflowOrDiv0 || !b) {
 
  792          overflowOrDiv0 = true;
 
  795        APInt quotient = a.udiv(
b);
 
  798        APInt one(a.getBitWidth(), 1, 
true);
 
  799        return quotient.uadd_ov(one, overflowOrDiv0);
 
  802  return overflowOrDiv0 ? Attribute() : 
result;
 
  813OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
 
  821  bool overflowOrDiv0 = 
false;
 
  823      adaptor.getOperands(), [&](APInt a, 
const APInt &
b) {
 
  824        if (overflowOrDiv0 || !b) {
 
  825          overflowOrDiv0 = true;
 
  831        unsigned bits = a.getBitWidth();
 
  832        APInt zero = APInt::getZero(bits);
 
  833        bool aGtZero = a.sgt(zero);
 
  834        bool bGtZero = 
b.sgt(zero);
 
  835        if (aGtZero && bGtZero) {
 
  842        bool overflowNegA = 
false;
 
  843        bool overflowNegB = 
false;
 
  844        bool overflowDiv = 
false;
 
  845        bool overflowNegRes = 
false;
 
  846        if (!aGtZero && !bGtZero) {
 
  848          APInt posA = zero.ssub_ov(a, overflowNegA);
 
  849          APInt posB = zero.ssub_ov(
b, overflowNegB);
 
  851          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
 
  854        if (!aGtZero && bGtZero) {
 
  856          APInt posA = zero.ssub_ov(a, overflowNegA);
 
  857          APInt 
div = posA.sdiv_ov(
b, overflowDiv);
 
  858          APInt res = zero.ssub_ov(
div, overflowNegRes);
 
  859          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
 
  863        APInt posB = zero.ssub_ov(
b, overflowNegB);
 
  864        APInt 
div = a.sdiv_ov(posB, overflowDiv);
 
  865        APInt res = zero.ssub_ov(
div, overflowNegRes);
 
  867        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
 
  871  return overflowOrDiv0 ? Attribute() : 
result;
 
  882OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
 
  888  bool overflowOrDiv = 
false;
 
  890      adaptor.getOperands(), [&](APInt a, 
const APInt &
b) {
 
  892          overflowOrDiv = true;
 
  895        return a.sfloordiv_ov(
b, overflowOrDiv);
 
  898  return overflowOrDiv ? Attribute() : 
result;
 
  905OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
 
  913                                               [&](APInt a, 
const APInt &
b) {
 
  914                                                 if (div0 || b.isZero()) {
 
  921  return div0 ? Attribute() : 
result;
 
  928OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
 
  936                                               [&](APInt a, 
const APInt &
b) {
 
  937                                                 if (div0 || b.isZero()) {
 
  944  return div0 ? Attribute() : 
result;
 
  953  for (
bool reversePrev : {
false, 
true}) {
 
  954    auto prev = (reversePrev ? op.getRhs() : op.getLhs())
 
  955                    .getDefiningOp<arith::AndIOp>();
 
  959    Value other = (reversePrev ? op.getLhs() : op.getRhs());
 
  960    if (other != prev.getLhs() && other != prev.getRhs())
 
  963    return prev.getResult();
 
 
  968OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
 
  975      intValue.isAllOnes())
 
  980      intValue.isAllOnes())
 
  985      intValue.isAllOnes())
 
  993      adaptor.getOperands(),
 
  994      [](APInt a, 
const APInt &
b) { return std::move(a) & b; });
 
 1001OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
 
 1004    if (rhsVal.isZero())
 
 1007    if (rhsVal.isAllOnes())
 
 1008      return adaptor.getRhs();
 
 1015      intValue.isAllOnes())
 
 1016    return getRhs().getDefiningOp<XOrIOp>().getRhs();
 
 1020      intValue.isAllOnes())
 
 1021    return getLhs().getDefiningOp<XOrIOp>().getRhs();
 
 1024      adaptor.getOperands(),
 
 1025      [](APInt a, 
const APInt &
b) { return std::move(a) | b; });
 
 1032OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
 
 1037  if (getLhs() == getRhs())
 
 1041  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
 
 1042    if (prev.getRhs() == getRhs())
 
 1043      return prev.getLhs();
 
 1044    if (prev.getLhs() == getRhs())
 
 1045      return prev.getRhs();
 
 1049  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
 
 1050    if (prev.getRhs() == getLhs())
 
 1051      return prev.getLhs();
 
 1052    if (prev.getLhs() == getLhs())
 
 1053      return prev.getRhs();
 
 1057      adaptor.getOperands(),
 
 1058      [](APInt a, 
const APInt &
b) { return std::move(a) ^ b; });
 
 1061void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1062                                                MLIRContext *context) {
 
 1063  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
 
 1070OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
 
 1072  if (
auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
 
 1073    return op.getOperand();
 
 1075                                     [](
const APFloat &a) { return -a; });
 
 1082OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
 
 1088      adaptor.getOperands(),
 
 1089      [](
const APFloat &a, 
const APFloat &
b) { return a + b; });
 
 1096OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
 
 1102      adaptor.getOperands(),
 
 1103      [](
const APFloat &a, 
const APFloat &
b) { return a - b; });
 
 1110OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
 
 1112  if (getLhs() == getRhs())
 
 1120      adaptor.getOperands(),
 
 1121      [](
const APFloat &a, 
const APFloat &
b) { return llvm::maximum(a, b); });
 
 1128OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
 
 1130  if (getLhs() == getRhs())
 
 1144OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
 
 1146  if (getLhs() == getRhs())
 
 1152    if (intValue.isMaxSignedValue())
 
 1155    if (intValue.isMinSignedValue())
 
 1160                                        [](
const APInt &a, 
const APInt &
b) {
 
 1161                                          return llvm::APIntOps::smax(a, b);
 
 1169OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
 
 1171  if (getLhs() == getRhs())
 
 1177    if (intValue.isMaxValue())
 
 1180    if (intValue.isMinValue())
 
 1185                                        [](
const APInt &a, 
const APInt &
b) {
 
 1186                                          return llvm::APIntOps::umax(a, b);
 
 1194OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
 
 1196  if (getLhs() == getRhs())
 
 1204      adaptor.getOperands(),
 
 1205      [](
const APFloat &a, 
const APFloat &
b) { return llvm::minimum(a, b); });
 
 1212OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
 
 1214  if (getLhs() == getRhs())
 
 1222      adaptor.getOperands(),
 
 1223      [](
const APFloat &a, 
const APFloat &
b) { return llvm::minnum(a, b); });
 
 1230OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
 
 1232  if (getLhs() == getRhs())
 
 1238    if (intValue.isMinSignedValue())
 
 1241    if (intValue.isMaxSignedValue())
 
 1246                                        [](
const APInt &a, 
const APInt &
b) {
 
 1247                                          return llvm::APIntOps::smin(a, b);
 
 1255OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
 
 1257  if (getLhs() == getRhs())
 
 1263    if (intValue.isMinValue())
 
 1266    if (intValue.isMaxValue())
 
 1271                                        [](
const APInt &a, 
const APInt &
b) {
 
 1272                                          return llvm::APIntOps::umin(a, b);
 
 1280OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
 
 1285  if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
 
 1286                                                   arith::FastMathFlags::nsz)) {
 
 1293      adaptor.getOperands(),
 
 1294      [](
const APFloat &a, 
const APFloat &
b) { return a * b; });
 
 1297void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1298                                                MLIRContext *context) {
 
 1306OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
 
 1312      adaptor.getOperands(),
 
 1313      [](
const APFloat &a, 
const APFloat &
b) { return a / b; });
 
 1316void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1317                                                MLIRContext *context) {
 
 1325OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
 
 1327                                      [](
const APFloat &a, 
const APFloat &
b) {
 
 1332                                        (void)result.mod(b);
 
 1341template <
typename... Types>
 
 1347template <
typename... ShapedTypes, 
typename... ElementTypes>
 
 1350  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
 
 1354  if (!llvm::isa<ElementTypes...>(underlyingType))
 
 1357  return underlyingType;
 
 
 1361template <
typename... ElementTypes>
 
 1368template <
typename... ElementTypes>
 
 1377  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
 
 1378  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
 
 1379  if (!rankedTensorA || !rankedTensorB)
 
 1381  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
 
 
 1385  if (inputs.size() != 1 || outputs.size() != 1)
 
 
 1397template <
typename ValType, 
typename Op>
 
 1402  if (llvm::cast<ValType>(srcType).getWidth() >=
 
 1403      llvm::cast<ValType>(dstType).getWidth())
 
 1405           << dstType << 
" must be wider than operand type " << srcType;
 
 
 1411template <
typename ValType, 
typename Op>
 
 1416  if (llvm::cast<ValType>(srcType).getWidth() <=
 
 1417      llvm::cast<ValType>(dstType).getWidth())
 
 1419           << dstType << 
" must be shorter than operand type " << srcType;
 
 
 1425template <
template <
typename> 
class WidthComparator, 
typename... ElementTypes>
 
 1430  auto srcType = 
getTypeIfLike<ElementTypes...>(inputs.front());
 
 1431  auto dstType = 
getTypeIfLike<ElementTypes...>(outputs.front());
 
 1432  if (!srcType || !dstType)
 
 1435  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
 
 1436                                     srcType.getIntOrFloatBitWidth());
 
 
 1442    APFloat sourceValue, 
const llvm::fltSemantics &targetSemantics,
 
 1443    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
 
 1444  bool losesInfo = 
false;
 
 1445  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
 
 1446  if (losesInfo || status != APFloat::opOK)
 
 
 1456OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
 
 1457  if (
auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
 
 1458    getInMutable().assign(
lhs.getIn());
 
 1463  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
 
 1465      adaptor.getOperands(), 
getType(),
 
 1466      [bitWidth](
const APInt &a, 
bool &castStatus) {
 
 1467        return a.zext(bitWidth);
 
 1475LogicalResult arith::ExtUIOp::verify() {
 
 1483OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
 
 1484  if (
auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
 
 1485    getInMutable().assign(
lhs.getIn());
 
 1490  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
 
 1492      adaptor.getOperands(), 
getType(),
 
 1493      [bitWidth](
const APInt &a, 
bool &castStatus) {
 
 1494        return a.sext(bitWidth);
 
 1502void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1503                                                 MLIRContext *context) {
 
 1504  patterns.add<ExtSIOfExtUI>(context);
 
 1507LogicalResult arith::ExtSIOp::verify() {
 
 1517OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
 
 1518  if (
auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
 
 1519    if (truncFOp.getOperand().getType() == 
getType()) {
 
 1520      arith::FastMathFlags truncFMF =
 
 1521          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
 
 1522      bool isTruncContract =
 
 1523          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
 
 1524      arith::FastMathFlags extFMF =
 
 1525          getFastmath().value_or(arith::FastMathFlags::none);
 
 1526      bool isExtContract =
 
 1527          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
 
 1528      if (isTruncContract && isExtContract) {
 
 1529        return truncFOp.getOperand();
 
 1535  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
 
 1537      adaptor.getOperands(), 
getType(),
 
 1538      [&targetSemantics](
const APFloat &a, 
bool &castStatus) {
 
 1558bool arith::ScalingExtFOp::areCastCompatible(
TypeRange inputs,
 
 1563LogicalResult arith::ScalingExtFOp::verify() {
 
 1571OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
 
 1574    Value src = getOperand().getDefiningOp()->getOperand(0);
 
 1579    if (llvm::cast<IntegerType>(srcType).getWidth() >
 
 1580        llvm::cast<IntegerType>(dstType).getWidth()) {
 
 1587    if (srcType == dstType)
 
 1593    setOperand(getOperand().getDefiningOp()->getOperand(0));
 
 1598  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
 
 1600      adaptor.getOperands(), 
getType(),
 
 1601      [bitWidth](
const APInt &a, 
bool &castStatus) {
 
 1602        return a.trunc(bitWidth);
 
 1610void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1611                                                  MLIRContext *context) {
 
 1613      .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
 
 1617LogicalResult arith::TruncIOp::verify() {
 
 1627OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
 
 1629  if (
auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
 
 1630    Value src = extOp.getIn();
 
 1632    auto intermediateType =
 
 1635    if (llvm::APFloatBase::isRepresentableBy(
 
 1636            srcType.getFloatSemantics(),
 
 1637            intermediateType.getFloatSemantics())) {
 
 1639      if (srcType.getWidth() > resElemType.getWidth()) {
 
 1645      if (srcType == resElemType)
 
 1650  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
 
 1652      adaptor.getOperands(), 
getType(),
 
 1653      [
this, &targetSemantics](
const APFloat &a, 
bool &castStatus) {
 
 1654        RoundingMode roundingMode =
 
 1655            getRoundingmode().value_or(RoundingMode::to_nearest_even);
 
 1656        llvm::RoundingMode llvmRoundingMode =
 
 1658        FailureOr<APFloat> 
result =
 
 1668void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1669                                                  MLIRContext *context) {
 
 1670  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
 
 1677LogicalResult arith::TruncFOp::verify() {
 
 1685bool arith::ScalingTruncFOp::areCastCompatible(
TypeRange inputs,
 
 1690LogicalResult arith::ScalingTruncFOp::verify() {
 
 1698void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1699                                                MLIRContext *context) {
 
 1700  patterns.add<AndOfExtUI, AndOfExtSI>(context);
 
 1707void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1708                                               MLIRContext *context) {
 
 1709  patterns.add<OrOfExtUI, OrOfExtSI>(context);
 
 1716template <
typename From, 
typename To>
 
 1724  return srcType && dstType;
 
 
 1735OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
 
 1738      adaptor.getOperands(), 
getType(),
 
 1739      [&resEleType](
const APInt &a, 
bool &castStatus) {
 
 1740        FloatType floatTy = llvm::cast<FloatType>(resEleType);
 
 1741        APFloat apf(floatTy.getFloatSemantics(),
 
 1742                    APInt::getZero(floatTy.getWidth()));
 
 1743        apf.convertFromAPInt(a, 
false,
 
 1744                             APFloat::rmNearestTiesToEven);
 
 1757OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
 
 1760      adaptor.getOperands(), 
getType(),
 
 1761      [&resEleType](
const APInt &a, 
bool &castStatus) {
 
 1762        FloatType floatTy = llvm::cast<FloatType>(resEleType);
 
 1763        APFloat apf(floatTy.getFloatSemantics(),
 
 1764                    APInt::getZero(floatTy.getWidth()));
 
 1765        apf.convertFromAPInt(a, 
true,
 
 1766                             APFloat::rmNearestTiesToEven);
 
 1779OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
 
 1781  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
 
 1783      adaptor.getOperands(), 
getType(),
 
 1784      [&bitWidth](
const APFloat &a, 
bool &castStatus) {
 
 1786        APSInt api(bitWidth, 
true);
 
 1787        castStatus = APFloat::opInvalidOp !=
 
 1788                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
 
 1801OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
 
 1803  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
 
 1805      adaptor.getOperands(), 
getType(),
 
 1806      [&bitWidth](
const APFloat &a, 
bool &castStatus) {
 
 1808        APSInt api(bitWidth, 
false);
 
 1809        castStatus = APFloat::opInvalidOp !=
 
 1810                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
 
 1825  if (!srcType || !dstType)
 
 
 1832bool arith::IndexCastOp::areCastCompatible(
TypeRange inputs,
 
 1837OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
 
 1839  unsigned resultBitwidth = 64; 
 
 1841    resultBitwidth = intTy.getWidth();
 
 1844      adaptor.getOperands(), 
getType(),
 
 1845      [resultBitwidth](
const APInt &a, 
bool & ) {
 
 1846        return a.sextOrTrunc(resultBitwidth);
 
 1850void arith::IndexCastOp::getCanonicalizationPatterns(
 
 1851    RewritePatternSet &
patterns, MLIRContext *context) {
 
 1852  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
 
 1859bool arith::IndexCastUIOp::areCastCompatible(
TypeRange inputs,
 
 1864OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
 
 1866  unsigned resultBitwidth = 64; 
 
 1868    resultBitwidth = intTy.getWidth();
 
 1871      adaptor.getOperands(), 
getType(),
 
 1872      [resultBitwidth](
const APInt &a, 
bool & ) {
 
 1873        return a.zextOrTrunc(resultBitwidth);
 
 1877void arith::IndexCastUIOp::getCanonicalizationPatterns(
 
 1878    RewritePatternSet &
patterns, MLIRContext *context) {
 
 1879  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
 
 1892  if (!srcType || !dstType)
 
 1898OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
 
 1900  auto operand = adaptor.getIn();
 
 1905  if (
auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
 
 1906    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).
getElementType());
 
 1908  if (llvm::isa<ShapedType>(resType))
 
 1912  if (llvm::isa<ub::PoisonAttr>(operand))
 
 1916  APInt bits = llvm::isa<FloatAttr>(operand)
 
 1917                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
 
 1918                   : llvm::cast<IntegerAttr>(operand).getValue();
 
 1920         "trying to fold on broken IR: operands have incompatible types");
 
 1922  if (
auto resFloatType = dyn_cast<FloatType>(resType))
 
 1923    return FloatAttr::get(resType,
 
 1924                          APFloat(resFloatType.getFloatSemantics(), bits));
 
 1925  return IntegerAttr::get(resType, bits);
 
 1928void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 1929                                                   MLIRContext *context) {
 
 1930  patterns.add<BitcastOfBitcast>(context);
 
 1940                                    const APInt &
lhs, 
const APInt &
rhs) {
 
 1941  switch (predicate) {
 
 1942  case arith::CmpIPredicate::eq:
 
 1944  case arith::CmpIPredicate::ne:
 
 1946  case arith::CmpIPredicate::slt:
 
 1948  case arith::CmpIPredicate::sle:
 
 1950  case arith::CmpIPredicate::sgt:
 
 1952  case arith::CmpIPredicate::sge:
 
 1954  case arith::CmpIPredicate::ult:
 
 1956  case arith::CmpIPredicate::ule:
 
 1958  case arith::CmpIPredicate::ugt:
 
 1960  case arith::CmpIPredicate::uge:
 
 1963  llvm_unreachable(
"unknown cmpi predicate kind");
 
 
 1968  switch (predicate) {
 
 1969  case arith::CmpIPredicate::eq:
 
 1970  case arith::CmpIPredicate::sle:
 
 1971  case arith::CmpIPredicate::sge:
 
 1972  case arith::CmpIPredicate::ule:
 
 1973  case arith::CmpIPredicate::uge:
 
 1975  case arith::CmpIPredicate::ne:
 
 1976  case arith::CmpIPredicate::slt:
 
 1977  case arith::CmpIPredicate::sgt:
 
 1978  case arith::CmpIPredicate::ult:
 
 1979  case arith::CmpIPredicate::ugt:
 
 1982  llvm_unreachable(
"unknown cmpi predicate kind");
 
 
 1986  if (
auto intType = dyn_cast<IntegerType>(t)) {
 
 1987    return intType.getWidth();
 
 1989  if (
auto vectorIntType = dyn_cast<VectorType>(t)) {
 
 1990    return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
 
 1992  return std::nullopt;
 
 
 1995OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
 
 1997  if (getLhs() == getRhs()) {
 
 2003    if (
auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
 
 2005      std::optional<int64_t> integerWidth =
 
 2007      if (integerWidth && integerWidth.value() == 1 &&
 
 2008          getPredicate() == arith::CmpIPredicate::ne)
 
 2009        return extOp.getOperand();
 
 2011    if (
auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
 
 2013      std::optional<int64_t> integerWidth =
 
 2015      if (integerWidth && integerWidth.value() == 1 &&
 
 2016          getPredicate() == arith::CmpIPredicate::ne)
 
 2017        return extOp.getOperand();
 
 2022        getPredicate() == arith::CmpIPredicate::ne)
 
 2029        getPredicate() == arith::CmpIPredicate::eq)
 
 2034  if (adaptor.getLhs() && !adaptor.getRhs()) {
 
 2036    using Pred = CmpIPredicate;
 
 2037    const std::pair<Pred, Pred> invPreds[] = {
 
 2038        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
 
 2039        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
 
 2040        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
 
 2041        {Pred::ne, Pred::ne},
 
 2043    Pred origPred = getPredicate();
 
 2044    for (
auto pred : invPreds) {
 
 2045      if (origPred == pred.first) {
 
 2046        setPredicate(pred.second);
 
 2047        Value 
lhs = getLhs();
 
 2048        Value 
rhs = getRhs();
 
 2049        getLhsMutable().assign(
rhs);
 
 2050        getRhsMutable().assign(
lhs);
 
 2054    llvm_unreachable(
"unknown cmpi predicate kind");
 
 2059  if (
auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
 
 2062        [pred = getPredicate()](
const APInt &
lhs, 
const APInt &
rhs) {
 
 2071void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 2072                                                MLIRContext *context) {
 
 2073  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
 
 2083                                    const APFloat &
lhs, 
const APFloat &
rhs) {
 
 2084  auto cmpResult = 
lhs.compare(
rhs);
 
 2085  switch (predicate) {
 
 2086  case arith::CmpFPredicate::AlwaysFalse:
 
 2088  case arith::CmpFPredicate::OEQ:
 
 2089    return cmpResult == APFloat::cmpEqual;
 
 2090  case arith::CmpFPredicate::OGT:
 
 2091    return cmpResult == APFloat::cmpGreaterThan;
 
 2092  case arith::CmpFPredicate::OGE:
 
 2093    return cmpResult == APFloat::cmpGreaterThan ||
 
 2094           cmpResult == APFloat::cmpEqual;
 
 2095  case arith::CmpFPredicate::OLT:
 
 2096    return cmpResult == APFloat::cmpLessThan;
 
 2097  case arith::CmpFPredicate::OLE:
 
 2098    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
 
 2099  case arith::CmpFPredicate::ONE:
 
 2100    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
 
 2101  case arith::CmpFPredicate::ORD:
 
 2102    return cmpResult != APFloat::cmpUnordered;
 
 2103  case arith::CmpFPredicate::UEQ:
 
 2104    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
 
 2105  case arith::CmpFPredicate::UGT:
 
 2106    return cmpResult == APFloat::cmpUnordered ||
 
 2107           cmpResult == APFloat::cmpGreaterThan;
 
 2108  case arith::CmpFPredicate::UGE:
 
 2109    return cmpResult == APFloat::cmpUnordered ||
 
 2110           cmpResult == APFloat::cmpGreaterThan ||
 
 2111           cmpResult == APFloat::cmpEqual;
 
 2112  case arith::CmpFPredicate::ULT:
 
 2113    return cmpResult == APFloat::cmpUnordered ||
 
 2114           cmpResult == APFloat::cmpLessThan;
 
 2115  case arith::CmpFPredicate::ULE:
 
 2116    return cmpResult == APFloat::cmpUnordered ||
 
 2117           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
 
 2118  case arith::CmpFPredicate::UNE:
 
 2119    return cmpResult != APFloat::cmpEqual;
 
 2120  case arith::CmpFPredicate::UNO:
 
 2121    return cmpResult == APFloat::cmpUnordered;
 
 2122  case arith::CmpFPredicate::AlwaysTrue:
 
 2125  llvm_unreachable(
"unknown cmpf predicate kind");
 
 
 2129  auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
 
 2130  auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
 
 2133  if (
lhs && 
lhs.getValue().isNaN())
 
 2135  if (
rhs && 
rhs.getValue().isNaN())
 
 2151    using namespace arith;
 
 2153    case CmpFPredicate::UEQ:
 
 2154    case CmpFPredicate::OEQ:
 
 2155      return CmpIPredicate::eq;
 
 2156    case CmpFPredicate::UGT:
 
 2157    case CmpFPredicate::OGT:
 
 2158      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
 
 2159    case CmpFPredicate::UGE:
 
 2160    case CmpFPredicate::OGE:
 
 2161      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
 
 2162    case CmpFPredicate::ULT:
 
 2163    case CmpFPredicate::OLT:
 
 2164      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
 
 2165    case CmpFPredicate::ULE:
 
 2166    case CmpFPredicate::OLE:
 
 2167      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
 
 2168    case CmpFPredicate::UNE:
 
 2169    case CmpFPredicate::ONE:
 
 2170      return CmpIPredicate::ne;
 
 2172      llvm_unreachable(
"Unexpected predicate!");
 
 
 2182    const APFloat &
rhs = flt.getValue();
 
 2190    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
 
 2191    int mantissaWidth = floatTy.getFPMantissaWidth();
 
 2192    if (mantissaWidth <= 0)
 
 2198    if (
auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
 
 2200      intVal = si.getIn();
 
 2201    } 
else if (
auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
 
 2203      intVal = ui.getIn();
 
 2210    auto intTy = llvm::cast<IntegerType>(intVal.
getType());
 
 2211    auto intWidth = intTy.getWidth();
 
 2214    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
 
 2219    if ((
int)intWidth > mantissaWidth) {
 
 2221      int exponent = ilogb(
rhs);
 
 2222      if (exponent == APFloat::IEK_Inf) {
 
 2223        int maxExponent = ilogb(APFloat::getLargest(
rhs.getSemantics()));
 
 2224        if (maxExponent < (
int)valueBits) {
 
 2231        if (mantissaWidth <= exponent && exponent <= (
int)valueBits) {
 
 2240    switch (op.getPredicate()) {
 
 2241    case CmpFPredicate::ORD:
 
 2246    case CmpFPredicate::UNO:
 
 2259      APFloat signedMax(
rhs.getSemantics());
 
 2260      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), 
true,
 
 2261                                 APFloat::rmNearestTiesToEven);
 
 2262      if (signedMax < 
rhs) { 
 
 2263        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
 
 2264            pred == CmpIPredicate::sle)
 
 2275      APFloat unsignedMax(
rhs.getSemantics());
 
 2276      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), 
false,
 
 2277                                   APFloat::rmNearestTiesToEven);
 
 2278      if (unsignedMax < 
rhs) { 
 
 2279        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
 
 2280            pred == CmpIPredicate::ule)
 
 2292      APFloat signedMin(
rhs.getSemantics());
 
 2293      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), 
true,
 
 2294                                 APFloat::rmNearestTiesToEven);
 
 2295      if (signedMin > 
rhs) { 
 
 2296        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
 
 2297            pred == CmpIPredicate::sge)
 
 2307      APFloat unsignedMin(
rhs.getSemantics());
 
 2308      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), 
false,
 
 2309                                   APFloat::rmNearestTiesToEven);
 
 2310      if (unsignedMin > 
rhs) { 
 
 2311        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
 
 2312            pred == CmpIPredicate::uge)
 
 2327    APSInt rhsInt(intWidth, isUnsigned);
 
 2328    if (APFloat::opInvalidOp ==
 
 2329        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
 
 2335    if (!
rhs.isZero()) {
 
 2336      APFloat apf(floatTy.getFloatSemantics(),
 
 2337                  APInt::getZero(floatTy.getWidth()));
 
 2338      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
 
 2340      bool equal = apf == 
rhs;
 
 2346        case CmpIPredicate::ne: 
 
 2350        case CmpIPredicate::eq: 
 
 2354        case CmpIPredicate::ule:
 
 2357          if (
rhs.isNegative()) {
 
 2363        case CmpIPredicate::sle:
 
 2366          if (
rhs.isNegative())
 
 2367            pred = CmpIPredicate::slt;
 
 2369        case CmpIPredicate::ult:
 
 2372          if (
rhs.isNegative()) {
 
 2377          pred = CmpIPredicate::ule;
 
 2379        case CmpIPredicate::slt:
 
 2382          if (!
rhs.isNegative())
 
 2383            pred = CmpIPredicate::sle;
 
 2385        case CmpIPredicate::ugt:
 
 2388          if (
rhs.isNegative()) {
 
 2394        case CmpIPredicate::sgt:
 
 2397          if (
rhs.isNegative())
 
 2398            pred = CmpIPredicate::sge;
 
 2400        case CmpIPredicate::uge:
 
 2403          if (
rhs.isNegative()) {
 
 2408          pred = CmpIPredicate::ugt;
 
 2410        case CmpIPredicate::sge:
 
 2413          if (!
rhs.isNegative())
 
 2414            pred = CmpIPredicate::sgt;
 
 2424        ConstantOp::create(rewriter, op.getLoc(), intVal.
getType(),
 
 
 
 2430void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
 
 2431                                                MLIRContext *context) {
 
 2432  patterns.insert<CmpFIntToFPConst>(context);
 
 2446    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
 
 2462          arith::XOrIOp::create(
 
 2463              rewriter, op.getLoc(), op.getCondition(),
 
 2465                                           op.getCondition().
getType(), 1)));
 
 
 
 2473void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 2474                                                  MLIRContext *context) {
 
 2475  results.
add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
 
 2476              SelectI1ToNot, SelectToExtUI>(context);
 
 2479OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
 
 2480  Value trueVal = getTrueValue();
 
 2481  Value falseVal = getFalseValue();
 
 2482  if (trueVal == falseVal)
 
 2485  Value condition = getCondition();
 
 2496  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
 
 2499  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
 
 2503  if (
getType().isSignlessInteger(1) &&
 
 2509    auto pred = cmp.getPredicate();
 
 2510    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
 
 2511      auto cmpLhs = cmp.getLhs();
 
 2512      auto cmpRhs = cmp.getRhs();
 
 2520      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
 
 2521          (cmpRhs == trueVal && cmpLhs == falseVal))
 
 2522        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
 
 2529          dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
 
 2531            dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
 
 2533              dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
 
 2534        SmallVector<Attribute> results;
 
 2535        results.reserve(
static_cast<size_t>(cond.getNumElements()));
 
 2536        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
 
 2537                                         cond.value_end<BoolAttr>());
 
 2538        auto lhsVals = llvm::make_range(
lhs.value_begin<Attribute>(),
 
 2539                                        lhs.value_end<Attribute>());
 
 2540        auto rhsVals = llvm::make_range(
rhs.value_begin<Attribute>(),
 
 2541                                        rhs.value_end<Attribute>());
 
 2543        for (
auto [condVal, lhsVal, rhsVal] :
 
 2544             llvm::zip_equal(condVals, lhsVals, rhsVals))
 
 2545          results.push_back(condVal.getValue() ? lhsVal : rhsVal);
 
 2555ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 2556  Type conditionType, resultType;
 
 2557  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
 
 2565    conditionType = resultType;
 
 2572  result.addTypes(resultType);
 
 2574                                {conditionType, resultType, resultType},
 
 2578void arith::SelectOp::print(OpAsmPrinter &p) {
 
 2579  p << 
" " << getOperands();
 
 2582  if (ShapedType condType = dyn_cast<ShapedType>(getCondition().
getType()))
 
 2583    p << condType << 
", ";
 
 2587LogicalResult arith::SelectOp::verify() {
 
 2588  Type conditionType = getCondition().getType();
 
 2595  if (!llvm::isa<TensorType, VectorType>(resultType))
 
 2596    return emitOpError() << 
"expected condition to be a signless i1, but got " 
 2599  if (conditionType != shapedConditionType) {
 
 2600    return emitOpError() << 
"expected condition type to have the same shape " 
 2601                            "as the result type, expected " 
 2602                         << shapedConditionType << 
", but got " 
 2611OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
 
 2616  bool bounded = 
false;
 
 2618      adaptor.getOperands(), [&](
const APInt &a, 
const APInt &
b) {
 
 2619        bounded = b.ult(b.getBitWidth());
 
 2622  return bounded ? 
result : Attribute();
 
 2629OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
 
 2634  bool bounded = 
false;
 
 2636      adaptor.getOperands(), [&](
const APInt &a, 
const APInt &
b) {
 
 2637        bounded = b.ult(b.getBitWidth());
 
 2640  return bounded ? 
result : Attribute();
 
 2647OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
 
 2652  bool bounded = 
false;
 
 2654      adaptor.getOperands(), [&](
const APInt &a, 
const APInt &
b) {
 
 2655        bounded = b.ult(b.getBitWidth());
 
 2658  return bounded ? 
result : Attribute();
 
 2668                                            bool useOnlyFiniteValue) {
 
 2670  case AtomicRMWKind::maximumf: {
 
 2671    const llvm::fltSemantics &semantic =
 
 2672        llvm::cast<FloatType>(resultType).getFloatSemantics();
 
 2673    APFloat identity = useOnlyFiniteValue
 
 2674                           ? APFloat::getLargest(semantic, 
true)
 
 2675                           : APFloat::getInf(semantic, 
true);
 
 2678  case AtomicRMWKind::maxnumf: {
 
 2679    const llvm::fltSemantics &semantic =
 
 2680        llvm::cast<FloatType>(resultType).getFloatSemantics();
 
 2681    APFloat identity = APFloat::getNaN(semantic, 
true);
 
 2684  case AtomicRMWKind::addf:
 
 2685  case AtomicRMWKind::addi:
 
 2686  case AtomicRMWKind::maxu:
 
 2687  case AtomicRMWKind::ori:
 
 2688  case AtomicRMWKind::xori:
 
 2690  case AtomicRMWKind::andi:
 
 2693        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
 
 2694  case AtomicRMWKind::maxs:
 
 2696        resultType, APInt::getSignedMinValue(
 
 2697                        llvm::cast<IntegerType>(resultType).getWidth()));
 
 2698  case AtomicRMWKind::minimumf: {
 
 2699    const llvm::fltSemantics &semantic =
 
 2700        llvm::cast<FloatType>(resultType).getFloatSemantics();
 
 2701    APFloat identity = useOnlyFiniteValue
 
 2702                           ? APFloat::getLargest(semantic, 
false)
 
 2703                           : APFloat::getInf(semantic, 
false);
 
 2707  case AtomicRMWKind::minnumf: {
 
 2708    const llvm::fltSemantics &semantic =
 
 2709        llvm::cast<FloatType>(resultType).getFloatSemantics();
 
 2710    APFloat identity = APFloat::getNaN(semantic, 
false);
 
 2713  case AtomicRMWKind::mins:
 
 2715        resultType, APInt::getSignedMaxValue(
 
 2716                        llvm::cast<IntegerType>(resultType).getWidth()));
 
 2717  case AtomicRMWKind::minu:
 
 2720        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
 
 2721  case AtomicRMWKind::muli:
 
 2723  case AtomicRMWKind::mulf:
 
 
 2735  std::optional<AtomicRMWKind> maybeKind =
 
 2738          .Case([](arith::AddFOp op) { 
return AtomicRMWKind::addf; })
 
 2739          .Case([](arith::MulFOp op) { 
return AtomicRMWKind::mulf; })
 
 2740          .Case([](arith::MaximumFOp op) { 
return AtomicRMWKind::maximumf; })
 
 2741          .Case([](arith::MinimumFOp op) { 
return AtomicRMWKind::minimumf; })
 
 2742          .Case([](arith::MaxNumFOp op) { 
return AtomicRMWKind::maxnumf; })
 
 2743          .Case([](arith::MinNumFOp op) { 
return AtomicRMWKind::minnumf; })
 
 2745          .Case([](arith::AddIOp op) { 
return AtomicRMWKind::addi; })
 
 2746          .Case([](arith::OrIOp op) { 
return AtomicRMWKind::ori; })
 
 2747          .Case([](arith::XOrIOp op) { 
return AtomicRMWKind::xori; })
 
 2748          .Case([](arith::AndIOp op) { 
return AtomicRMWKind::andi; })
 
 2749          .Case([](arith::MaxUIOp op) { 
return AtomicRMWKind::maxu; })
 
 2750          .Case([](arith::MinUIOp op) { 
return AtomicRMWKind::minu; })
 
 2751          .Case([](arith::MaxSIOp op) { 
return AtomicRMWKind::maxs; })
 
 2752          .Case([](arith::MinSIOp op) { 
return AtomicRMWKind::mins; })
 
 2753          .Case([](arith::MulIOp op) { 
return AtomicRMWKind::muli; })
 
 2754          .Default(std::nullopt);
 
 2756    return std::nullopt;
 
 2759  bool useOnlyFiniteValue = 
false;
 
 2760  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
 
 2761  if (fmfOpInterface) {
 
 2762    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
 
 2763    useOnlyFiniteValue =
 
 2764        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
 
 2772                              useOnlyFiniteValue);
 
 
 2778                                    bool useOnlyFiniteValue) {
 
 2781  return arith::ConstantOp::create(builder, loc, attr);
 
 
 2789  case AtomicRMWKind::addf:
 
 2790    return arith::AddFOp::create(builder, loc, 
lhs, 
rhs);
 
 2791  case AtomicRMWKind::addi:
 
 2792    return arith::AddIOp::create(builder, loc, 
lhs, 
rhs);
 
 2793  case AtomicRMWKind::mulf:
 
 2794    return arith::MulFOp::create(builder, loc, 
lhs, 
rhs);
 
 2795  case AtomicRMWKind::muli:
 
 2796    return arith::MulIOp::create(builder, loc, 
lhs, 
rhs);
 
 2797  case AtomicRMWKind::maximumf:
 
 2798    return arith::MaximumFOp::create(builder, loc, 
lhs, 
rhs);
 
 2799  case AtomicRMWKind::minimumf:
 
 2800    return arith::MinimumFOp::create(builder, loc, 
lhs, 
rhs);
 
 2801  case AtomicRMWKind::maxnumf:
 
 2802    return arith::MaxNumFOp::create(builder, loc, 
lhs, 
rhs);
 
 2803  case AtomicRMWKind::minnumf:
 
 2804    return arith::MinNumFOp::create(builder, loc, 
lhs, 
rhs);
 
 2805  case AtomicRMWKind::maxs:
 
 2806    return arith::MaxSIOp::create(builder, loc, 
lhs, 
rhs);
 
 2807  case AtomicRMWKind::mins:
 
 2808    return arith::MinSIOp::create(builder, loc, 
lhs, 
rhs);
 
 2809  case AtomicRMWKind::maxu:
 
 2810    return arith::MaxUIOp::create(builder, loc, 
lhs, 
rhs);
 
 2811  case AtomicRMWKind::minu:
 
 2812    return arith::MinUIOp::create(builder, loc, 
lhs, 
rhs);
 
 2813  case AtomicRMWKind::ori:
 
 2814    return arith::OrIOp::create(builder, loc, 
lhs, 
rhs);
 
 2815  case AtomicRMWKind::andi:
 
 2816    return arith::AndIOp::create(builder, loc, 
lhs, 
rhs);
 
 2817  case AtomicRMWKind::xori:
 
 2818    return arith::XOrIOp::create(builder, loc, 
lhs, 
rhs);
 
 
 2831#define GET_OP_CLASSES 
 2832#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" 
 2838#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc" 
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
 
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
 
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
 
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
 
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
 
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
 
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
 
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
 
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
 
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
 
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
 
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
 
static std::optional< int64_t > getIntegerWidth(Type t)
 
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
 
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
 
static Attribute getBoolAttribute(Type type, bool value)
 
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
 
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
 
static LogicalResult verifyExtOp(Op op)
 
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
 
static int64_t getScalarOrElementWidth(Type type)
 
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
 
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
 
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
 
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
 
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
 
std::tuple< Types... > * type_list
 
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
 
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
 
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
 
static LogicalResult verifyTruncateOp(Op op)
 
static Type getElementType(Type type)
Determine the element type of type.
 
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
 
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
 
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
 
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
 
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
 
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
 
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
 
virtual ParseResult parseType(Type &result)=0
Parse a type.
 
Attributes are known-constant values of operations.
 
static BoolAttr get(MLIRContext *context, bool value)
 
IntegerAttr getIndexAttr(int64_t value)
 
IntegerAttr getIntegerAttr(Type type, int64_t value)
 
FloatAttr getFloatAttr(Type type, double value)
 
IntegerType getIntegerType(unsigned width)
 
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
 
TypedAttr getZeroAttr(Type type)
 
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
 
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
 
Location getLoc() const
Accessors for the implied location.
 
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
 
MLIRContext is the top-level object for a collection of MLIR operations.
 
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
 
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
 
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
 
This class helps build Operations.
 
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
 
This class represents a single result from folding an operation.
 
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
 
This provides public APIs that all operations should have.
 
Operation is the basic unit of execution within MLIR.
 
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
 
Location getLoc()
The source location the operation was defined or derived from.
 
MLIRContext * getContext()
Return the context this operation is associated with.
 
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
 
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
 
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
 
This class provides an abstraction over the various different ranges of value types.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
 
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
 
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
 
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
 
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
 
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
 
Type getType() const
Return the type of this value.
 
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
 
Specialization of arith.constant op that returns a floating point value.
 
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
 
static bool classof(Operation *op)
 
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
 
Specialization of arith.constant op that returns an integer of index type.
 
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
 
static bool classof(Operation *op)
 
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
 
Specialization of arith.constant op that returns an integer value.
 
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
 
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
 
static bool classof(Operation *op)
 
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
 
constexpr auto Speculatable
 
constexpr auto NotSpeculatable
 
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
 
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
 
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
 
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
 
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
 
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
 
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type `type`.
 
Include the generated interface declarations.
 
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
 
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
 
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
 
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
 
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
 
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
 
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
 
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
 
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
 
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
 
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
 
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
 
const FrozenRewritePatternSet & patterns
 
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
 
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
 
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
 
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
 
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
 
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
 
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
 
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
 
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
 
llvm::function_ref< Fn > function_ref
 
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
 
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
 
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
 
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
 
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
 
This represents an operation in an abstracted form, suitable for use with the builder APIs.