MLIR  20.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/ADT/APSInt.h"
28 #include "llvm/ADT/FloatingPointMode.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 
37 //===----------------------------------------------------------------------===//
38 // Pattern helpers
39 //===----------------------------------------------------------------------===//
40 
41 static IntegerAttr
43  Attribute rhs,
44  function_ref<APInt(const APInt &, const APInt &)> binFn) {
45  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47  APInt value = binFn(lhsVal, rhsVal);
48  return IntegerAttr::get(res.getType(), value);
49 }
50 
51 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
52  Attribute lhs, Attribute rhs) {
53  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54 }
55 
56 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
57  Attribute lhs, Attribute rhs) {
58  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59 }
60 
61 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
62  Attribute lhs, Attribute rhs) {
63  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64 }
65 
66 // Merge overflow flags from 2 ops, selecting the most conservative combination.
67 static IntegerOverflowFlagsAttr
68 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69  IntegerOverflowFlagsAttr val2) {
70  return IntegerOverflowFlagsAttr::get(val1.getContext(),
71  val1.getValue() & val2.getValue());
72 }
73 
74 /// Invert an integer comparison predicate.
75 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76  switch (pred) {
77  case arith::CmpIPredicate::eq:
78  return arith::CmpIPredicate::ne;
79  case arith::CmpIPredicate::ne:
80  return arith::CmpIPredicate::eq;
81  case arith::CmpIPredicate::slt:
82  return arith::CmpIPredicate::sge;
83  case arith::CmpIPredicate::sle:
84  return arith::CmpIPredicate::sgt;
85  case arith::CmpIPredicate::sgt:
86  return arith::CmpIPredicate::sle;
87  case arith::CmpIPredicate::sge:
88  return arith::CmpIPredicate::slt;
89  case arith::CmpIPredicate::ult:
90  return arith::CmpIPredicate::uge;
91  case arith::CmpIPredicate::ule:
92  return arith::CmpIPredicate::ugt;
93  case arith::CmpIPredicate::ugt:
94  return arith::CmpIPredicate::ule;
95  case arith::CmpIPredicate::uge:
96  return arith::CmpIPredicate::ult;
97  }
98  llvm_unreachable("unknown cmpi predicate kind");
99 }
100 
101 /// Equivalent to
102 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103 ///
104 /// Not possible to implement as chain of calls as this would introduce a
105 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106 /// on the LLVM dialect and on translation to LLVM.
107 static llvm::RoundingMode
108 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109  switch (roundingMode) {
110  case RoundingMode::downward:
111  return llvm::RoundingMode::TowardNegative;
112  case RoundingMode::to_nearest_away:
113  return llvm::RoundingMode::NearestTiesToAway;
114  case RoundingMode::to_nearest_even:
115  return llvm::RoundingMode::NearestTiesToEven;
116  case RoundingMode::toward_zero:
117  return llvm::RoundingMode::TowardZero;
118  case RoundingMode::upward:
119  return llvm::RoundingMode::TowardPositive;
120  }
121  llvm_unreachable("Unhandled rounding mode");
122 }
123 
124 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125  return arith::CmpIPredicateAttr::get(pred.getContext(),
126  invertPredicate(pred.getValue()));
127 }
128 
129 static int64_t getScalarOrElementWidth(Type type) {
130  Type elemTy = getElementTypeOrSelf(type);
131  if (elemTy.isIntOrFloat())
132  return elemTy.getIntOrFloatBitWidth();
133 
134  return -1;
135 }
136 
137 static int64_t getScalarOrElementWidth(Value value) {
138  return getScalarOrElementWidth(value.getType());
139 }
140 
141 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142  APInt value;
143  if (matchPattern(attr, m_ConstantInt(&value)))
144  return value;
145 
146  return failure();
147 }
148 
149 static Attribute getBoolAttribute(Type type, bool value) {
150  auto boolAttr = BoolAttr::get(type.getContext(), value);
151  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
152  if (!shapedType)
153  return boolAttr;
154  return DenseElementsAttr::get(shapedType, boolAttr);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TableGen'd canonicalization patterns
159 //===----------------------------------------------------------------------===//
160 
161 namespace {
162 #include "ArithCanonicalization.inc"
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // Common helpers
167 //===----------------------------------------------------------------------===//
168 
169 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
170 static Type getI1SameShape(Type type) {
171  auto i1Type = IntegerType::get(type.getContext(), 1);
172  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
173  return shapedType.cloneWith(std::nullopt, i1Type);
174  if (llvm::isa<UnrankedTensorType>(type))
175  return UnrankedTensorType::get(i1Type);
176  return i1Type;
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // ConstantOp
181 //===----------------------------------------------------------------------===//
182 
183 void arith::ConstantOp::getAsmResultNames(
184  function_ref<void(Value, StringRef)> setNameFn) {
185  auto type = getType();
186  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
187  auto intType = llvm::dyn_cast<IntegerType>(type);
188 
189  // Sugar i1 constants with 'true' and 'false'.
190  if (intType && intType.getWidth() == 1)
191  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
192 
193  // Otherwise, build a complex name with the value and type.
194  SmallString<32> specialNameBuffer;
195  llvm::raw_svector_ostream specialName(specialNameBuffer);
196  specialName << 'c' << intCst.getValue();
197  if (intType)
198  specialName << '_' << type;
199  setNameFn(getResult(), specialName.str());
200  } else {
201  setNameFn(getResult(), "cst");
202  }
203 }
204 
205 /// TODO: disallow arith.constant to return anything other than signless integer
206 /// or float like.
207 LogicalResult arith::ConstantOp::verify() {
208  auto type = getType();
209  // The value's type must match the return type.
210  if (getValue().getType() != type) {
211  return emitOpError() << "value type " << getValue().getType()
212  << " must match return type: " << type;
213  }
214  // Integer values must be signless.
215  if (llvm::isa<IntegerType>(type) &&
216  !llvm::cast<IntegerType>(type).isSignless())
217  return emitOpError("integer return type must be signless");
218  // Any float or elements attribute are acceptable.
219  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
220  return emitOpError(
221  "value must be an integer, float, or elements attribute");
222  }
223 
224  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
225  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
226  // However, this would most likely require updating the lowerings to LLVM.
227  auto vecType = dyn_cast<VectorType>(type);
228  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
229  return emitOpError(
230  "intializing scalable vectors with elements attribute is not supported"
231  " unless it's a vector splat");
232  return success();
233 }
234 
235 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
236  // The value's type must be the same as the provided type.
237  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238  if (!typedAttr || typedAttr.getType() != type)
239  return false;
240  // Integer values must be signless.
241  if (llvm::isa<IntegerType>(type) &&
242  !llvm::cast<IntegerType>(type).isSignless())
243  return false;
244  // Integer, float, and element attributes are buildable.
245  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246 }
247 
248 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
249  Type type, Location loc) {
250  if (isBuildableWith(value, type))
251  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
252  return nullptr;
253 }
254 
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
256 
258  int64_t value, unsigned width) {
259  auto type = builder.getIntegerType(width);
260  arith::ConstantOp::build(builder, result, type,
261  builder.getIntegerAttr(type, value));
262 }
263 
265  int64_t value, Type type) {
266  assert(type.isSignlessInteger() &&
267  "ConstantIntOp can only have signless integer type values");
268  arith::ConstantOp::build(builder, result, type,
269  builder.getIntegerAttr(type, value));
270 }
271 
273  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274  return constOp.getType().isSignlessInteger();
275  return false;
276 }
277 
279  const APFloat &value, FloatType type) {
280  arith::ConstantOp::build(builder, result, type,
281  builder.getFloatAttr(type, value));
282 }
283 
285  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286  return llvm::isa<FloatType>(constOp.getType());
287  return false;
288 }
289 
291  int64_t value) {
292  arith::ConstantOp::build(builder, result, builder.getIndexType(),
293  builder.getIndexAttr(value));
294 }
295 
297  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298  return constOp.getType().isIndex();
299  return false;
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // AddIOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
307  // addi(x, 0) -> x
308  if (matchPattern(adaptor.getRhs(), m_Zero()))
309  return getLhs();
310 
311  // addi(subi(a, b), b) -> a
312  if (auto sub = getLhs().getDefiningOp<SubIOp>())
313  if (getRhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  // addi(b, subi(a, b)) -> a
317  if (auto sub = getRhs().getDefiningOp<SubIOp>())
318  if (getLhs() == sub.getRhs())
319  return sub.getLhs();
320 
321  return constFoldBinaryOp<IntegerAttr>(
322  adaptor.getOperands(),
323  [](APInt a, const APInt &b) { return std::move(a) + b; });
324 }
325 
326 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
327  MLIRContext *context) {
328  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // AddUIExtendedOp
334 //===----------------------------------------------------------------------===//
335 
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
339  return llvm::to_vector<4>(vt.getShape());
340  return std::nullopt;
341 }
342 
343 // Returns the overflow bit, assuming that `sum` is the result of unsigned
344 // addition of `operand` and another number.
345 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
346  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
347 }
348 
349 LogicalResult
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352  Type overflowTy = getOverflow().getType();
353  // addui_extended(x, 0) -> x, false
354  if (matchPattern(getRhs(), m_Zero())) {
355  Builder builder(getContext());
356  auto falseValue = builder.getZeroAttr(overflowTy);
357 
358  results.push_back(getLhs());
359  results.push_back(falseValue);
360  return success();
361  }
362 
363  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
364  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
365  // operands. If that succeeds, calculate the overflow bit based on the sum
366  // and the first (constant) operand, `lhs`.
367  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368  adaptor.getOperands(),
369  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
370  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371  ArrayRef({sumAttr, adaptor.getLhs()}),
372  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
374  if (!overflowAttr)
375  return failure();
376 
377  results.push_back(sumAttr);
378  results.push_back(overflowAttr);
379  return success();
380  }
381 
382  return failure();
383 }
384 
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
386  RewritePatternSet &patterns, MLIRContext *context) {
387  patterns.add<AddUIExtendedToAddI>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // SubIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
395  // subi(x,x) -> 0
396  if (getOperand(0) == getOperand(1))
397  return Builder(getContext()).getZeroAttr(getType());
398  // subi(x,0) -> x
399  if (matchPattern(adaptor.getRhs(), m_Zero()))
400  return getLhs();
401 
402  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
403  // subi(addi(a, b), b) -> a
404  if (getRhs() == add.getRhs())
405  return add.getLhs();
406  // subi(addi(a, b), a) -> b
407  if (getRhs() == add.getLhs())
408  return add.getRhs();
409  }
410 
411  return constFoldBinaryOp<IntegerAttr>(
412  adaptor.getOperands(),
413  [](APInt a, const APInt &b) { return std::move(a) - b; });
414 }
415 
416 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
417  MLIRContext *context) {
418  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
419  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
420  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // MulIOp
425 //===----------------------------------------------------------------------===//
426 
427 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
428  // muli(x, 0) -> 0
429  if (matchPattern(adaptor.getRhs(), m_Zero()))
430  return getRhs();
431  // muli(x, 1) -> x
432  if (matchPattern(adaptor.getRhs(), m_One()))
433  return getLhs();
434  // TODO: Handle the overflow case.
435 
436  // default folder
437  return constFoldBinaryOp<IntegerAttr>(
438  adaptor.getOperands(),
439  [](const APInt &a, const APInt &b) { return a * b; });
440 }
441 
442 void arith::MulIOp::getAsmResultNames(
443  function_ref<void(Value, StringRef)> setNameFn) {
444  if (!isa<IndexType>(getType()))
445  return;
446 
447  // Match vector.vscale by name to avoid depending on the vector dialect (which
448  // is a circular dependency).
449  auto isVscale = [](Operation *op) {
450  return op && op->getName().getStringRef() == "vector.vscale";
451  };
452 
453  IntegerAttr baseValue;
454  auto isVscaleExpr = [&](Value a, Value b) {
455  return matchPattern(a, m_Constant(&baseValue)) &&
456  isVscale(b.getDefiningOp());
457  };
458 
459  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
460  return;
461 
462  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
463  SmallString<32> specialNameBuffer;
464  llvm::raw_svector_ostream specialName(specialNameBuffer);
465  specialName << 'c' << baseValue.getInt() << "_vscale";
466  setNameFn(getResult(), specialName.str());
467 }
468 
469 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
470  MLIRContext *context) {
471  patterns.add<MulIMulIConstant>(context);
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // MulSIExtendedOp
476 //===----------------------------------------------------------------------===//
477 
478 std::optional<SmallVector<int64_t, 4>>
479 arith::MulSIExtendedOp::getShapeForUnroll() {
480  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
481  return llvm::to_vector<4>(vt.getShape());
482  return std::nullopt;
483 }
484 
485 LogicalResult
486 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
488  // mulsi_extended(x, 0) -> 0, 0
489  if (matchPattern(adaptor.getRhs(), m_Zero())) {
490  Attribute zero = adaptor.getRhs();
491  results.push_back(zero);
492  results.push_back(zero);
493  return success();
494  }
495 
496  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
497  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
498  adaptor.getOperands(),
499  [](const APInt &a, const APInt &b) { return a * b; })) {
500  // Invoke the constant fold helper again to calculate the 'high' result.
501  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
502  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
503  return llvm::APIntOps::mulhs(a, b);
504  });
505  assert(highAttr && "Unexpected constant-folding failure");
506 
507  results.push_back(lowAttr);
508  results.push_back(highAttr);
509  return success();
510  }
511 
512  return failure();
513 }
514 
515 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516  RewritePatternSet &patterns, MLIRContext *context) {
517  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // MulUIExtendedOp
522 //===----------------------------------------------------------------------===//
523 
524 std::optional<SmallVector<int64_t, 4>>
525 arith::MulUIExtendedOp::getShapeForUnroll() {
526  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
527  return llvm::to_vector<4>(vt.getShape());
528  return std::nullopt;
529 }
530 
531 LogicalResult
532 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
534  // mului_extended(x, 0) -> 0, 0
535  if (matchPattern(adaptor.getRhs(), m_Zero())) {
536  Attribute zero = adaptor.getRhs();
537  results.push_back(zero);
538  results.push_back(zero);
539  return success();
540  }
541 
542  // mului_extended(x, 1) -> x, 0
543  if (matchPattern(adaptor.getRhs(), m_One())) {
544  Builder builder(getContext());
545  Attribute zero = builder.getZeroAttr(getLhs().getType());
546  results.push_back(getLhs());
547  results.push_back(zero);
548  return success();
549  }
550 
551  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
552  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
553  adaptor.getOperands(),
554  [](const APInt &a, const APInt &b) { return a * b; })) {
555  // Invoke the constant fold helper again to calculate the 'high' result.
556  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
557  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
558  return llvm::APIntOps::mulhu(a, b);
559  });
560  assert(highAttr && "Unexpected constant-folding failure");
561 
562  results.push_back(lowAttr);
563  results.push_back(highAttr);
564  return success();
565  }
566 
567  return failure();
568 }
569 
570 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571  RewritePatternSet &patterns, MLIRContext *context) {
572  patterns.add<MulUIExtendedToMulI>(context);
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // DivUIOp
577 //===----------------------------------------------------------------------===//
578 
579 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
580  // divui (x, 1) -> x.
581  if (matchPattern(adaptor.getRhs(), m_One()))
582  return getLhs();
583 
584  // Don't fold if it would require a division by zero.
585  bool div0 = false;
586  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
587  [&](APInt a, const APInt &b) {
588  if (div0 || !b) {
589  div0 = true;
590  return a;
591  }
592  return a.udiv(b);
593  });
594 
595  return div0 ? Attribute() : result;
596 }
597 
598 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
599  // X / 0 => UB
600  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // DivSIOp
606 //===----------------------------------------------------------------------===//
607 
608 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
609  // divsi (x, 1) -> x.
610  if (matchPattern(adaptor.getRhs(), m_One()))
611  return getLhs();
612 
613  // Don't fold if it would overflow or if it requires a division by zero.
614  bool overflowOrDiv0 = false;
615  auto result = constFoldBinaryOp<IntegerAttr>(
616  adaptor.getOperands(), [&](APInt a, const APInt &b) {
617  if (overflowOrDiv0 || !b) {
618  overflowOrDiv0 = true;
619  return a;
620  }
621  return a.sdiv_ov(b, overflowOrDiv0);
622  });
623 
624  return overflowOrDiv0 ? Attribute() : result;
625 }
626 
627 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
628  bool mayHaveUB = true;
629 
630  APInt constRHS;
631  // X / 0 => UB
632  // INT_MIN / -1 => UB
633  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
634  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
635 
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Ceil and floor division folding helpers
641 //===----------------------------------------------------------------------===//
642 
643 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
644  bool &overflow) {
645  // Returns (a-1)/b + 1
646  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
647  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
648  return val.sadd_ov(one, overflow);
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // CeilDivUIOp
653 //===----------------------------------------------------------------------===//
654 
655 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
656  // ceildivui (x, 1) -> x.
657  if (matchPattern(adaptor.getRhs(), m_One()))
658  return getLhs();
659 
660  bool overflowOrDiv0 = false;
661  auto result = constFoldBinaryOp<IntegerAttr>(
662  adaptor.getOperands(), [&](APInt a, const APInt &b) {
663  if (overflowOrDiv0 || !b) {
664  overflowOrDiv0 = true;
665  return a;
666  }
667  APInt quotient = a.udiv(b);
668  if (!a.urem(b))
669  return quotient;
670  APInt one(a.getBitWidth(), 1, true);
671  return quotient.uadd_ov(one, overflowOrDiv0);
672  });
673 
674  return overflowOrDiv0 ? Attribute() : result;
675 }
676 
677 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
678  // X / 0 => UB
679  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // CeilDivSIOp
685 //===----------------------------------------------------------------------===//
686 
687 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
688  // ceildivsi (x, 1) -> x.
689  if (matchPattern(adaptor.getRhs(), m_One()))
690  return getLhs();
691 
692  // Don't fold if it would overflow or if it requires a division by zero.
693  // TODO: This hook won't fold operations where a = MININT, because
694  // negating MININT overflows. This can be improved.
695  bool overflowOrDiv0 = false;
696  auto result = constFoldBinaryOp<IntegerAttr>(
697  adaptor.getOperands(), [&](APInt a, const APInt &b) {
698  if (overflowOrDiv0 || !b) {
699  overflowOrDiv0 = true;
700  return a;
701  }
702  if (!a)
703  return a;
704  // After this point we know that neither a or b are zero.
705  unsigned bits = a.getBitWidth();
706  APInt zero = APInt::getZero(bits);
707  bool aGtZero = a.sgt(zero);
708  bool bGtZero = b.sgt(zero);
709  if (aGtZero && bGtZero) {
710  // Both positive, return ceil(a, b).
711  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
712  }
713 
714  // No folding happens if any of the intermediate arithmetic operations
715  // overflows.
716  bool overflowNegA = false;
717  bool overflowNegB = false;
718  bool overflowDiv = false;
719  bool overflowNegRes = false;
720  if (!aGtZero && !bGtZero) {
721  // Both negative, return ceil(-a, -b).
722  APInt posA = zero.ssub_ov(a, overflowNegA);
723  APInt posB = zero.ssub_ov(b, overflowNegB);
724  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
725  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
726  return res;
727  }
728  if (!aGtZero && bGtZero) {
729  // A is negative, b is positive, return - ( -a / b).
730  APInt posA = zero.ssub_ov(a, overflowNegA);
731  APInt div = posA.sdiv_ov(b, overflowDiv);
732  APInt res = zero.ssub_ov(div, overflowNegRes);
733  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
734  return res;
735  }
736  // A is positive, b is negative, return - (a / -b).
737  APInt posB = zero.ssub_ov(b, overflowNegB);
738  APInt div = a.sdiv_ov(posB, overflowDiv);
739  APInt res = zero.ssub_ov(div, overflowNegRes);
740 
741  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
742  return res;
743  });
744 
745  return overflowOrDiv0 ? Attribute() : result;
746 }
747 
748 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
749  bool mayHaveUB = true;
750 
751  APInt constRHS;
752  // X / 0 => UB
753  // INT_MIN / -1 => UB
754  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
755  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
756 
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // FloorDivSIOp
762 //===----------------------------------------------------------------------===//
763 
764 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
765  // floordivsi (x, 1) -> x.
766  if (matchPattern(adaptor.getRhs(), m_One()))
767  return getLhs();
768 
769  // Don't fold if it would overflow or if it requires a division by zero.
770  bool overflowOrDiv = false;
771  auto result = constFoldBinaryOp<IntegerAttr>(
772  adaptor.getOperands(), [&](APInt a, const APInt &b) {
773  if (b.isZero()) {
774  overflowOrDiv = true;
775  return a;
776  }
777  return a.sfloordiv_ov(b, overflowOrDiv);
778  });
779 
780  return overflowOrDiv ? Attribute() : result;
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // RemUIOp
785 //===----------------------------------------------------------------------===//
786 
787 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
788  // remui (x, 1) -> 0.
789  if (matchPattern(adaptor.getRhs(), m_One()))
790  return Builder(getContext()).getZeroAttr(getType());
791 
792  // Don't fold if it would require a division by zero.
793  bool div0 = false;
794  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
795  [&](APInt a, const APInt &b) {
796  if (div0 || b.isZero()) {
797  div0 = true;
798  return a;
799  }
800  return a.urem(b);
801  });
802 
803  return div0 ? Attribute() : result;
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // RemSIOp
808 //===----------------------------------------------------------------------===//
809 
810 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
811  // remsi (x, 1) -> 0.
812  if (matchPattern(adaptor.getRhs(), m_One()))
813  return Builder(getContext()).getZeroAttr(getType());
814 
815  // Don't fold if it would require a division by zero.
816  bool div0 = false;
817  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
818  [&](APInt a, const APInt &b) {
819  if (div0 || b.isZero()) {
820  div0 = true;
821  return a;
822  }
823  return a.srem(b);
824  });
825 
826  return div0 ? Attribute() : result;
827 }
828 
829 //===----------------------------------------------------------------------===//
830 // AndIOp
831 //===----------------------------------------------------------------------===//
832 
833 /// Fold `and(a, and(a, b))` to `and(a, b)`
834 static Value foldAndIofAndI(arith::AndIOp op) {
835  for (bool reversePrev : {false, true}) {
836  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
837  .getDefiningOp<arith::AndIOp>();
838  if (!prev)
839  continue;
840 
841  Value other = (reversePrev ? op.getLhs() : op.getRhs());
842  if (other != prev.getLhs() && other != prev.getRhs())
843  continue;
844 
845  return prev.getResult();
846  }
847  return {};
848 }
849 
850 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
851  /// and(x, 0) -> 0
852  if (matchPattern(adaptor.getRhs(), m_Zero()))
853  return getRhs();
854  /// and(x, allOnes) -> x
855  APInt intValue;
856  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
857  intValue.isAllOnes())
858  return getLhs();
859  /// and(x, not(x)) -> 0
860  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
861  m_ConstantInt(&intValue))) &&
862  intValue.isAllOnes())
863  return Builder(getContext()).getZeroAttr(getType());
864  /// and(not(x), x) -> 0
865  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
866  m_ConstantInt(&intValue))) &&
867  intValue.isAllOnes())
868  return Builder(getContext()).getZeroAttr(getType());
869 
870  /// and(a, and(a, b)) -> and(a, b)
871  if (Value result = foldAndIofAndI(*this))
872  return result;
873 
874  return constFoldBinaryOp<IntegerAttr>(
875  adaptor.getOperands(),
876  [](APInt a, const APInt &b) { return std::move(a) & b; });
877 }
878 
879 //===----------------------------------------------------------------------===//
880 // OrIOp
881 //===----------------------------------------------------------------------===//
882 
883 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
884  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
885  /// or(x, 0) -> x
886  if (rhsVal.isZero())
887  return getLhs();
888  /// or(x, <all ones>) -> <all ones>
889  if (rhsVal.isAllOnes())
890  return adaptor.getRhs();
891  }
892 
893  APInt intValue;
894  /// or(x, xor(x, 1)) -> 1
895  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
896  m_ConstantInt(&intValue))) &&
897  intValue.isAllOnes())
898  return getRhs().getDefiningOp<XOrIOp>().getRhs();
899  /// or(xor(x, 1), x) -> 1
900  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
901  m_ConstantInt(&intValue))) &&
902  intValue.isAllOnes())
903  return getLhs().getDefiningOp<XOrIOp>().getRhs();
904 
905  return constFoldBinaryOp<IntegerAttr>(
906  adaptor.getOperands(),
907  [](APInt a, const APInt &b) { return std::move(a) | b; });
908 }
909 
910 //===----------------------------------------------------------------------===//
911 // XOrIOp
912 //===----------------------------------------------------------------------===//
913 
914 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
915  /// xor(x, 0) -> x
916  if (matchPattern(adaptor.getRhs(), m_Zero()))
917  return getLhs();
918  /// xor(x, x) -> 0
919  if (getLhs() == getRhs())
920  return Builder(getContext()).getZeroAttr(getType());
921  /// xor(xor(x, a), a) -> x
922  /// xor(xor(a, x), a) -> x
923  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
924  if (prev.getRhs() == getRhs())
925  return prev.getLhs();
926  if (prev.getLhs() == getRhs())
927  return prev.getRhs();
928  }
929  /// xor(a, xor(x, a)) -> x
930  /// xor(a, xor(a, x)) -> x
931  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
932  if (prev.getRhs() == getLhs())
933  return prev.getLhs();
934  if (prev.getLhs() == getLhs())
935  return prev.getRhs();
936  }
937 
938  return constFoldBinaryOp<IntegerAttr>(
939  adaptor.getOperands(),
940  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
941 }
942 
943 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
944  MLIRContext *context) {
945  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // NegFOp
950 //===----------------------------------------------------------------------===//
951 
952 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
953  /// negf(negf(x)) -> x
954  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
955  return op.getOperand();
956  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
957  [](const APFloat &a) { return -a; });
958 }
959 
960 //===----------------------------------------------------------------------===//
961 // AddFOp
962 //===----------------------------------------------------------------------===//
963 
964 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
965  // addf(x, -0) -> x
966  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
967  return getLhs();
968 
969  return constFoldBinaryOp<FloatAttr>(
970  adaptor.getOperands(),
971  [](const APFloat &a, const APFloat &b) { return a + b; });
972 }
973 
974 //===----------------------------------------------------------------------===//
975 // SubFOp
976 //===----------------------------------------------------------------------===//
977 
978 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
979  // subf(x, +0) -> x
980  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
981  return getLhs();
982 
983  return constFoldBinaryOp<FloatAttr>(
984  adaptor.getOperands(),
985  [](const APFloat &a, const APFloat &b) { return a - b; });
986 }
987 
988 //===----------------------------------------------------------------------===//
989 // MaximumFOp
990 //===----------------------------------------------------------------------===//
991 
992 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
993  // maximumf(x,x) -> x
994  if (getLhs() == getRhs())
995  return getRhs();
996 
997  // maximumf(x, -inf) -> x
998  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
999  return getLhs();
1000 
1001  return constFoldBinaryOp<FloatAttr>(
1002  adaptor.getOperands(),
1003  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // MaxNumFOp
1008 //===----------------------------------------------------------------------===//
1009 
1010 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1011  // maxnumf(x,x) -> x
1012  if (getLhs() == getRhs())
1013  return getRhs();
1014 
1015  // maxnumf(x, -inf) -> x
1016  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1017  return getLhs();
1018 
1019  return constFoldBinaryOp<FloatAttr>(
1020  adaptor.getOperands(),
1021  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // MaxSIOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1029  // maxsi(x,x) -> x
1030  if (getLhs() == getRhs())
1031  return getRhs();
1032 
1033  if (APInt intValue;
1034  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1035  // maxsi(x,MAX_INT) -> MAX_INT
1036  if (intValue.isMaxSignedValue())
1037  return getRhs();
1038  // maxsi(x, MIN_INT) -> x
1039  if (intValue.isMinSignedValue())
1040  return getLhs();
1041  }
1042 
1043  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1044  [](const APInt &a, const APInt &b) {
1045  return llvm::APIntOps::smax(a, b);
1046  });
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // MaxUIOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1054  // maxui(x,x) -> x
1055  if (getLhs() == getRhs())
1056  return getRhs();
1057 
1058  if (APInt intValue;
1059  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1060  // maxui(x,MAX_INT) -> MAX_INT
1061  if (intValue.isMaxValue())
1062  return getRhs();
1063  // maxui(x, MIN_INT) -> x
1064  if (intValue.isMinValue())
1065  return getLhs();
1066  }
1067 
1068  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1069  [](const APInt &a, const APInt &b) {
1070  return llvm::APIntOps::umax(a, b);
1071  });
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // MinimumFOp
1076 //===----------------------------------------------------------------------===//
1077 
1078 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1079  // minimumf(x,x) -> x
1080  if (getLhs() == getRhs())
1081  return getRhs();
1082 
1083  // minimumf(x, +inf) -> x
1084  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1085  return getLhs();
1086 
1087  return constFoldBinaryOp<FloatAttr>(
1088  adaptor.getOperands(),
1089  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // MinNumFOp
1094 //===----------------------------------------------------------------------===//
1095 
1096 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1097  // minnumf(x,x) -> x
1098  if (getLhs() == getRhs())
1099  return getRhs();
1100 
1101  // minnumf(x, +inf) -> x
1102  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1103  return getLhs();
1104 
1105  return constFoldBinaryOp<FloatAttr>(
1106  adaptor.getOperands(),
1107  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1108 }
1109 
1110 //===----------------------------------------------------------------------===//
1111 // MinSIOp
1112 //===----------------------------------------------------------------------===//
1113 
1114 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1115  // minsi(x,x) -> x
1116  if (getLhs() == getRhs())
1117  return getRhs();
1118 
1119  if (APInt intValue;
1120  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1121  // minsi(x,MIN_INT) -> MIN_INT
1122  if (intValue.isMinSignedValue())
1123  return getRhs();
1124  // minsi(x, MAX_INT) -> x
1125  if (intValue.isMaxSignedValue())
1126  return getLhs();
1127  }
1128 
1129  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1130  [](const APInt &a, const APInt &b) {
1131  return llvm::APIntOps::smin(a, b);
1132  });
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // MinUIOp
1137 //===----------------------------------------------------------------------===//
1138 
1139 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1140  // minui(x,x) -> x
1141  if (getLhs() == getRhs())
1142  return getRhs();
1143 
1144  if (APInt intValue;
1145  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1146  // minui(x,MIN_INT) -> MIN_INT
1147  if (intValue.isMinValue())
1148  return getRhs();
1149  // minui(x, MAX_INT) -> x
1150  if (intValue.isMaxValue())
1151  return getLhs();
1152  }
1153 
1154  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1155  [](const APInt &a, const APInt &b) {
1156  return llvm::APIntOps::umin(a, b);
1157  });
1158 }
1159 
1160 //===----------------------------------------------------------------------===//
1161 // MulFOp
1162 //===----------------------------------------------------------------------===//
1163 
1164 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1165  // mulf(x, 1) -> x
1166  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1167  return getLhs();
1168 
1169  return constFoldBinaryOp<FloatAttr>(
1170  adaptor.getOperands(),
1171  [](const APFloat &a, const APFloat &b) { return a * b; });
1172 }
1173 
1174 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1175  MLIRContext *context) {
1176  patterns.add<MulFOfNegF>(context);
1177 }
1178 
1179 //===----------------------------------------------------------------------===//
1180 // DivFOp
1181 //===----------------------------------------------------------------------===//
1182 
1183 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1184  // divf(x, 1) -> x
1185  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1186  return getLhs();
1187 
1188  return constFoldBinaryOp<FloatAttr>(
1189  adaptor.getOperands(),
1190  [](const APFloat &a, const APFloat &b) { return a / b; });
1191 }
1192 
1193 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1194  MLIRContext *context) {
1195  patterns.add<DivFOfNegF>(context);
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // RemFOp
1200 //===----------------------------------------------------------------------===//
1201 
1202 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1203  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1204  [](const APFloat &a, const APFloat &b) {
1205  APFloat result(a);
1206  // APFloat::mod() offers the remainder
1207  // behavior we want, i.e. the result has
1208  // the sign of LHS operand.
1209  (void)result.mod(b);
1210  return result;
1211  });
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // Utility functions for verifying cast ops
1216 //===----------------------------------------------------------------------===//
1217 
/// Compile-time tag carrying a pack of types. It is an alias for a pointer to
/// `std::tuple<Types...>`, so a (value-initialized, null) instance can be
/// passed as an ordinary function argument; presumably only its type matters,
/// letting a callee deduce the pack.
template <class... Types>
using type_list = std::tuple<Types...> *;
1220 
1221 /// Returns a non-null type only if the provided type is one of the allowed
1222 /// types or one of the allowed shaped types of the allowed types. Returns the
1223 /// element type if a valid shaped type is provided.
1224 template <typename... ShapedTypes, typename... ElementTypes>
1227  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1228  return {};
1229 
1230  auto underlyingType = getElementTypeOrSelf(type);
1231  if (!llvm::isa<ElementTypes...>(underlyingType))
1232  return {};
1233 
1234  return underlyingType;
1235 }
1236 
1237 /// Get allowed underlying types for vectors and tensors.
1238 template <typename... ElementTypes>
1239 static Type getTypeIfLike(Type type) {
1242 }
1243 
1244 /// Get allowed underlying types for vectors, tensors, and memrefs.
1245 template <typename... ElementTypes>
1247  return getUnderlyingType(type,
1250 }
1251 
1252 /// Return false if both types are ranked tensor with mismatching encoding.
1253 static bool hasSameEncoding(Type typeA, Type typeB) {
1254  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1255  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1256  if (!rankedTensorA || !rankedTensorB)
1257  return true;
1258  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1259 }
1260 
1261 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1262  if (inputs.size() != 1 || outputs.size() != 1)
1263  return false;
1264  if (!hasSameEncoding(inputs.front(), outputs.front()))
1265  return false;
1266  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1267 }
1268 
1269 //===----------------------------------------------------------------------===//
1270 // Verifiers for integer and floating point extension/truncation ops
1271 //===----------------------------------------------------------------------===//
1272 
1273 // Extend ops can only extend to a wider type.
1274 template <typename ValType, typename Op>
1275 static LogicalResult verifyExtOp(Op op) {
1276  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1277  Type dstType = getElementTypeOrSelf(op.getType());
1278 
1279  if (llvm::cast<ValType>(srcType).getWidth() >=
1280  llvm::cast<ValType>(dstType).getWidth())
1281  return op.emitError("result type ")
1282  << dstType << " must be wider than operand type " << srcType;
1283 
1284  return success();
1285 }
1286 
1287 // Truncate ops can only truncate to a shorter type.
1288 template <typename ValType, typename Op>
1289 static LogicalResult verifyTruncateOp(Op op) {
1290  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1291  Type dstType = getElementTypeOrSelf(op.getType());
1292 
1293  if (llvm::cast<ValType>(srcType).getWidth() <=
1294  llvm::cast<ValType>(dstType).getWidth())
1295  return op.emitError("result type ")
1296  << dstType << " must be shorter than operand type " << srcType;
1297 
1298  return success();
1299 }
1300 
1301 /// Validate a cast that changes the width of a type.
1302 template <template <typename> class WidthComparator, typename... ElementTypes>
1303 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1304  if (!areValidCastInputsAndOutputs(inputs, outputs))
1305  return false;
1306 
1307  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1308  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1309  if (!srcType || !dstType)
1310  return false;
1311 
1312  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1313  srcType.getIntOrFloatBitWidth());
1314 }
1315 
1316 /// Attempts to convert `sourceValue` to an APFloat value with
1317 /// `targetSemantics` and `roundingMode`, without any information loss.
1318 static FailureOr<APFloat> convertFloatValue(
1319  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1320  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1321  bool losesInfo = false;
1322  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1323  if (losesInfo || status != APFloat::opOK)
1324  return failure();
1325 
1326  return sourceValue;
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // ExtUIOp
1331 //===----------------------------------------------------------------------===//
1332 
1333 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1334  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1335  getInMutable().assign(lhs.getIn());
1336  return getResult();
1337  }
1338 
1339  Type resType = getElementTypeOrSelf(getType());
1340  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1341  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1342  adaptor.getOperands(), getType(),
1343  [bitWidth](const APInt &a, bool &castStatus) {
1344  return a.zext(bitWidth);
1345  });
1346 }
1347 
1348 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1349  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1350 }
1351 
1352 LogicalResult arith::ExtUIOp::verify() {
1353  return verifyExtOp<IntegerType>(*this);
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // ExtSIOp
1358 //===----------------------------------------------------------------------===//
1359 
1360 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1361  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1362  getInMutable().assign(lhs.getIn());
1363  return getResult();
1364  }
1365 
1366  Type resType = getElementTypeOrSelf(getType());
1367  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1368  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1369  adaptor.getOperands(), getType(),
1370  [bitWidth](const APInt &a, bool &castStatus) {
1371  return a.sext(bitWidth);
1372  });
1373 }
1374 
1375 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1376  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1377 }
1378 
1379 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1380  MLIRContext *context) {
1381  patterns.add<ExtSIOfExtUI>(context);
1382 }
1383 
1384 LogicalResult arith::ExtSIOp::verify() {
1385  return verifyExtOp<IntegerType>(*this);
1386 }
1387 
1388 //===----------------------------------------------------------------------===//
1389 // ExtFOp
1390 //===----------------------------------------------------------------------===//
1391 
1392 /// Fold extension of float constants when there is no information loss due the
1393 /// difference in fp semantics.
1394 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1395  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1396  if (truncFOp.getOperand().getType() == getType()) {
1397  arith::FastMathFlags truncFMF =
1398  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1399  bool isTruncContract =
1400  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1401  arith::FastMathFlags extFMF =
1402  getFastmath().value_or(arith::FastMathFlags::none);
1403  bool isExtContract =
1404  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1405  if (isTruncContract && isExtContract) {
1406  return truncFOp.getOperand();
1407  }
1408  }
1409  }
1410 
1411  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1412  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1413  return constFoldCastOp<FloatAttr, FloatAttr>(
1414  adaptor.getOperands(), getType(),
1415  [&targetSemantics](const APFloat &a, bool &castStatus) {
1416  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1417  if (failed(result)) {
1418  castStatus = false;
1419  return a;
1420  }
1421  return *result;
1422  });
1423 }
1424 
1425 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1426  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1427 }
1428 
1429 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1430 
1431 //===----------------------------------------------------------------------===//
1432 // TruncIOp
1433 //===----------------------------------------------------------------------===//
1434 
/// Folds `arith.trunci`.
///  - trunci(extui(a)) / trunci(extsi(a)) where `a` is wider than the result:
///    the operand is rewired to `a` in place and the op's own result is
///    returned (MLIR's in-place fold convention).
///  - trunci(extui(a)) / trunci(extsi(a)) where `a` already has the result
///    element type: folds to `a`.
///  - trunci(trunci(a)): the operand is rewired to `a` in place.
///  - Otherwise, constant operands are truncated to the result width.
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    // The value that was extended (the extension op's single operand).
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    // Valid because the extension only adds bits above `a`'s width, and this
    // truncation keeps fewer bits than `a` has.
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
    // NOTE(review): when `a` is narrower than the result nothing is folded
    // here; the TruncIExt*ToExt* canonicalization patterns registered for
    // this op appear to target that case — confirm.
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  // Constant folding: truncate each integer to the result element width.
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}
1469 
1470 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1471  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1472 }
1473 
1474 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1475  MLIRContext *context) {
1476  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1477  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1478  context);
1479 }
1480 
1481 LogicalResult arith::TruncIOp::verify() {
1482  return verifyTruncateOp<IntegerType>(*this);
1483 }
1484 
1485 //===----------------------------------------------------------------------===//
1486 // TruncFOp
1487 //===----------------------------------------------------------------------===//
1488 
1489 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1490 /// can be represented without precision loss.
1491 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1492  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1493  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1494  return constFoldCastOp<FloatAttr, FloatAttr>(
1495  adaptor.getOperands(), getType(),
1496  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1497  RoundingMode roundingMode =
1498  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1499  llvm::RoundingMode llvmRoundingMode =
1500  convertArithRoundingModeToLLVMIR(roundingMode);
1501  FailureOr<APFloat> result =
1502  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1503  if (failed(result)) {
1504  castStatus = false;
1505  return a;
1506  }
1507  return *result;
1508  });
1509 }
1510 
1511 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1512  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1513 }
1514 
1515 LogicalResult arith::TruncFOp::verify() {
1516  return verifyTruncateOp<FloatType>(*this);
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // AndIOp
1521 //===----------------------------------------------------------------------===//
1522 
1523 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1524  MLIRContext *context) {
1525  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1526 }
1527 
1528 //===----------------------------------------------------------------------===//
1529 // OrIOp
1530 //===----------------------------------------------------------------------===//
1531 
1532 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1533  MLIRContext *context) {
1534  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // Verifiers for casts between integers and floats.
1539 //===----------------------------------------------------------------------===//
1540 
1541 template <typename From, typename To>
1542 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1543  if (!areValidCastInputsAndOutputs(inputs, outputs))
1544  return false;
1545 
1546  auto srcType = getTypeIfLike<From>(inputs.front());
1547  auto dstType = getTypeIfLike<To>(outputs.back());
1548 
1549  return srcType && dstType;
1550 }
1551 
1552 //===----------------------------------------------------------------------===//
1553 // UIToFPOp
1554 //===----------------------------------------------------------------------===//
1555 
1556 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1557  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1558 }
1559 
1560 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1561  Type resEleType = getElementTypeOrSelf(getType());
1562  return constFoldCastOp<IntegerAttr, FloatAttr>(
1563  adaptor.getOperands(), getType(),
1564  [&resEleType](const APInt &a, bool &castStatus) {
1565  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1566  APFloat apf(floatTy.getFloatSemantics(),
1567  APInt::getZero(floatTy.getWidth()));
1568  apf.convertFromAPInt(a, /*IsSigned=*/false,
1569  APFloat::rmNearestTiesToEven);
1570  return apf;
1571  });
1572 }
1573 
1574 //===----------------------------------------------------------------------===//
1575 // SIToFPOp
1576 //===----------------------------------------------------------------------===//
1577 
1578 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1579  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1580 }
1581 
1582 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1583  Type resEleType = getElementTypeOrSelf(getType());
1584  return constFoldCastOp<IntegerAttr, FloatAttr>(
1585  adaptor.getOperands(), getType(),
1586  [&resEleType](const APInt &a, bool &castStatus) {
1587  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1588  APFloat apf(floatTy.getFloatSemantics(),
1589  APInt::getZero(floatTy.getWidth()));
1590  apf.convertFromAPInt(a, /*IsSigned=*/true,
1591  APFloat::rmNearestTiesToEven);
1592  return apf;
1593  });
1594 }
1595 
1596 //===----------------------------------------------------------------------===//
1597 // FPToUIOp
1598 //===----------------------------------------------------------------------===//
1599 
1600 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1601  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1602 }
1603 
1604 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1605  Type resType = getElementTypeOrSelf(getType());
1606  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1607  return constFoldCastOp<FloatAttr, IntegerAttr>(
1608  adaptor.getOperands(), getType(),
1609  [&bitWidth](const APFloat &a, bool &castStatus) {
1610  bool ignored;
1611  APSInt api(bitWidth, /*isUnsigned=*/true);
1612  castStatus = APFloat::opInvalidOp !=
1613  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1614  return api;
1615  });
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // FPToSIOp
1620 //===----------------------------------------------------------------------===//
1621 
1622 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1623  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1624 }
1625 
1626 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1627  Type resType = getElementTypeOrSelf(getType());
1628  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1629  return constFoldCastOp<FloatAttr, IntegerAttr>(
1630  adaptor.getOperands(), getType(),
1631  [&bitWidth](const APFloat &a, bool &castStatus) {
1632  bool ignored;
1633  APSInt api(bitWidth, /*isUnsigned=*/false);
1634  castStatus = APFloat::opInvalidOp !=
1635  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1636  return api;
1637  });
1638 }
1639 
1640 //===----------------------------------------------------------------------===//
1641 // IndexCastOp
1642 //===----------------------------------------------------------------------===//
1643 
1644 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1645  if (!areValidCastInputsAndOutputs(inputs, outputs))
1646  return false;
1647 
1648  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1649  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1650  if (!srcType || !dstType)
1651  return false;
1652 
1653  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1654  (srcType.isSignlessInteger() && dstType.isIndex());
1655 }
1656 
1657 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1658  TypeRange outputs) {
1659  return areIndexCastCompatible(inputs, outputs);
1660 }
1661 
1662 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1663  // index_cast(constant) -> constant
1664  unsigned resultBitwidth = 64; // Default for index integer attributes.
1665  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1666  resultBitwidth = intTy.getWidth();
1667 
1668  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1669  adaptor.getOperands(), getType(),
1670  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1671  return a.sextOrTrunc(resultBitwidth);
1672  });
1673 }
1674 
1675 void arith::IndexCastOp::getCanonicalizationPatterns(
1676  RewritePatternSet &patterns, MLIRContext *context) {
1677  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1678 }
1679 
1680 //===----------------------------------------------------------------------===//
1681 // IndexCastUIOp
1682 //===----------------------------------------------------------------------===//
1683 
1684 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1685  TypeRange outputs) {
1686  return areIndexCastCompatible(inputs, outputs);
1687 }
1688 
1689 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1690  // index_castui(constant) -> constant
1691  unsigned resultBitwidth = 64; // Default for index integer attributes.
1692  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1693  resultBitwidth = intTy.getWidth();
1694 
1695  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1696  adaptor.getOperands(), getType(),
1697  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1698  return a.zextOrTrunc(resultBitwidth);
1699  });
1700 }
1701 
1702 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1703  RewritePatternSet &patterns, MLIRContext *context) {
1704  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1705 }
1706 
1707 //===----------------------------------------------------------------------===//
1708 // BitcastOp
1709 //===----------------------------------------------------------------------===//
1710 
1711 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1712  if (!areValidCastInputsAndOutputs(inputs, outputs))
1713  return false;
1714 
1715  auto srcType =
1716  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1717  auto dstType =
1718  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1719  if (!srcType || !dstType)
1720  return false;
1721 
1722  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1723 }
1724 
1725 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1726  auto resType = getType();
1727  auto operand = adaptor.getIn();
1728  if (!operand)
1729  return {};
1730 
1731  /// Bitcast dense elements.
1732  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1733  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1734  /// Other shaped types unhandled.
1735  if (llvm::isa<ShapedType>(resType))
1736  return {};
1737 
1738  /// Bitcast integer or float to integer or float.
1739  APInt bits = llvm::isa<FloatAttr>(operand)
1740  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1741  : llvm::cast<IntegerAttr>(operand).getValue();
1742 
1743  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1744  return FloatAttr::get(resType,
1745  APFloat(resFloatType.getFloatSemantics(), bits));
1746  return IntegerAttr::get(resType, bits);
1747 }
1748 
1749 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1750  MLIRContext *context) {
1751  patterns.add<BitcastOfBitcast>(context);
1752 }
1753 
1754 //===----------------------------------------------------------------------===//
1755 // CmpIOp
1756 //===----------------------------------------------------------------------===//
1757 
1758 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1759 /// comparison predicates.
1760 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1761  const APInt &lhs, const APInt &rhs) {
1762  switch (predicate) {
1763  case arith::CmpIPredicate::eq:
1764  return lhs.eq(rhs);
1765  case arith::CmpIPredicate::ne:
1766  return lhs.ne(rhs);
1767  case arith::CmpIPredicate::slt:
1768  return lhs.slt(rhs);
1769  case arith::CmpIPredicate::sle:
1770  return lhs.sle(rhs);
1771  case arith::CmpIPredicate::sgt:
1772  return lhs.sgt(rhs);
1773  case arith::CmpIPredicate::sge:
1774  return lhs.sge(rhs);
1775  case arith::CmpIPredicate::ult:
1776  return lhs.ult(rhs);
1777  case arith::CmpIPredicate::ule:
1778  return lhs.ule(rhs);
1779  case arith::CmpIPredicate::ugt:
1780  return lhs.ugt(rhs);
1781  case arith::CmpIPredicate::uge:
1782  return lhs.uge(rhs);
1783  }
1784  llvm_unreachable("unknown cmpi predicate kind");
1785 }
1786 
1787 /// Returns true if the predicate is true for two equal operands.
1788 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1789  switch (predicate) {
1790  case arith::CmpIPredicate::eq:
1791  case arith::CmpIPredicate::sle:
1792  case arith::CmpIPredicate::sge:
1793  case arith::CmpIPredicate::ule:
1794  case arith::CmpIPredicate::uge:
1795  return true;
1796  case arith::CmpIPredicate::ne:
1797  case arith::CmpIPredicate::slt:
1798  case arith::CmpIPredicate::sgt:
1799  case arith::CmpIPredicate::ult:
1800  case arith::CmpIPredicate::ugt:
1801  return false;
1802  }
1803  llvm_unreachable("unknown cmpi predicate kind");
1804 }
1805 
1806 static std::optional<int64_t> getIntegerWidth(Type t) {
1807  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1808  return intType.getWidth();
1809  }
1810  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1811  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1812  }
1813  return std::nullopt;
1814 }
1815 
1816 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1817  // cmpi(pred, x, x)
1818  if (getLhs() == getRhs()) {
1819  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1820  return getBoolAttribute(getType(), val);
1821  }
1822 
1823  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1824  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1825  // extsi(%x : i1 -> iN) != 0 -> %x
1826  std::optional<int64_t> integerWidth =
1827  getIntegerWidth(extOp.getOperand().getType());
1828  if (integerWidth && integerWidth.value() == 1 &&
1829  getPredicate() == arith::CmpIPredicate::ne)
1830  return extOp.getOperand();
1831  }
1832  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1833  // extui(%x : i1 -> iN) != 0 -> %x
1834  std::optional<int64_t> integerWidth =
1835  getIntegerWidth(extOp.getOperand().getType());
1836  if (integerWidth && integerWidth.value() == 1 &&
1837  getPredicate() == arith::CmpIPredicate::ne)
1838  return extOp.getOperand();
1839  }
1840  }
1841 
1842  // Move constant to the right side.
1843  if (adaptor.getLhs() && !adaptor.getRhs()) {
1844  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1845  using Pred = CmpIPredicate;
1846  const std::pair<Pred, Pred> invPreds[] = {
1847  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1848  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1849  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1850  {Pred::ne, Pred::ne},
1851  };
1852  Pred origPred = getPredicate();
1853  for (auto pred : invPreds) {
1854  if (origPred == pred.first) {
1855  setPredicate(pred.second);
1856  Value lhs = getLhs();
1857  Value rhs = getRhs();
1858  getLhsMutable().assign(rhs);
1859  getRhsMutable().assign(lhs);
1860  return getResult();
1861  }
1862  }
1863  llvm_unreachable("unknown cmpi predicate kind");
1864  }
1865 
1866  // We are moving constants to the right side; So if lhs is constant rhs is
1867  // guaranteed to be a constant.
1868  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1869  return constFoldBinaryOp<IntegerAttr>(
1870  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1871  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1872  return APInt(1,
1873  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1874  });
1875  }
1876 
1877  return {};
1878 }
1879 
1880 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1881  MLIRContext *context) {
1882  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // CmpFOp
1887 //===----------------------------------------------------------------------===//
1888 
1889 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1890 /// comparison predicates.
1891 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1892  const APFloat &lhs, const APFloat &rhs) {
1893  auto cmpResult = lhs.compare(rhs);
1894  switch (predicate) {
1895  case arith::CmpFPredicate::AlwaysFalse:
1896  return false;
1897  case arith::CmpFPredicate::OEQ:
1898  return cmpResult == APFloat::cmpEqual;
1899  case arith::CmpFPredicate::OGT:
1900  return cmpResult == APFloat::cmpGreaterThan;
1901  case arith::CmpFPredicate::OGE:
1902  return cmpResult == APFloat::cmpGreaterThan ||
1903  cmpResult == APFloat::cmpEqual;
1904  case arith::CmpFPredicate::OLT:
1905  return cmpResult == APFloat::cmpLessThan;
1906  case arith::CmpFPredicate::OLE:
1907  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1908  case arith::CmpFPredicate::ONE:
1909  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1910  case arith::CmpFPredicate::ORD:
1911  return cmpResult != APFloat::cmpUnordered;
1912  case arith::CmpFPredicate::UEQ:
1913  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1914  case arith::CmpFPredicate::UGT:
1915  return cmpResult == APFloat::cmpUnordered ||
1916  cmpResult == APFloat::cmpGreaterThan;
1917  case arith::CmpFPredicate::UGE:
1918  return cmpResult == APFloat::cmpUnordered ||
1919  cmpResult == APFloat::cmpGreaterThan ||
1920  cmpResult == APFloat::cmpEqual;
1921  case arith::CmpFPredicate::ULT:
1922  return cmpResult == APFloat::cmpUnordered ||
1923  cmpResult == APFloat::cmpLessThan;
1924  case arith::CmpFPredicate::ULE:
1925  return cmpResult == APFloat::cmpUnordered ||
1926  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1927  case arith::CmpFPredicate::UNE:
1928  return cmpResult != APFloat::cmpEqual;
1929  case arith::CmpFPredicate::UNO:
1930  return cmpResult == APFloat::cmpUnordered;
1931  case arith::CmpFPredicate::AlwaysTrue:
1932  return true;
1933  }
1934  llvm_unreachable("unknown cmpf predicate kind");
1935 }
1936 
1937 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1938  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1939  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1940 
1941  // If one operand is NaN, making them both NaN does not change the result.
1942  if (lhs && lhs.getValue().isNaN())
1943  rhs = lhs;
1944  if (rhs && rhs.getValue().isNaN())
1945  lhs = rhs;
1946 
1947  if (!lhs || !rhs)
1948  return {};
1949 
1950  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1951  return BoolAttr::get(getContext(), val);
1952 }
1953 
1954 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1955 public:
1957 
1958  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1959  bool isUnsigned) {
1960  using namespace arith;
1961  switch (pred) {
1962  case CmpFPredicate::UEQ:
1963  case CmpFPredicate::OEQ:
1964  return CmpIPredicate::eq;
1965  case CmpFPredicate::UGT:
1966  case CmpFPredicate::OGT:
1967  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1968  case CmpFPredicate::UGE:
1969  case CmpFPredicate::OGE:
1970  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1971  case CmpFPredicate::ULT:
1972  case CmpFPredicate::OLT:
1973  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1974  case CmpFPredicate::ULE:
1975  case CmpFPredicate::OLE:
1976  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1977  case CmpFPredicate::UNE:
1978  case CmpFPredicate::ONE:
1979  return CmpIPredicate::ne;
1980  default:
1981  llvm_unreachable("Unexpected predicate!");
1982  }
1983  }
1984 
1985  LogicalResult matchAndRewrite(CmpFOp op,
1986  PatternRewriter &rewriter) const override {
1987  FloatAttr flt;
1988  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1989  return failure();
1990 
1991  const APFloat &rhs = flt.getValue();
1992 
1993  // Don't attempt to fold a nan.
1994  if (rhs.isNaN())
1995  return failure();
1996 
1997  // Get the width of the mantissa. We don't want to hack on conversions that
1998  // might lose information from the integer, e.g. "i64 -> float"
1999  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2000  int mantissaWidth = floatTy.getFPMantissaWidth();
2001  if (mantissaWidth <= 0)
2002  return failure();
2003 
2004  bool isUnsigned;
2005  Value intVal;
2006 
2007  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2008  isUnsigned = false;
2009  intVal = si.getIn();
2010  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2011  isUnsigned = true;
2012  intVal = ui.getIn();
2013  } else {
2014  return failure();
2015  }
2016 
2017  // Check to see that the input is converted from an integer type that is
2018  // small enough that preserves all bits.
2019  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2020  auto intWidth = intTy.getWidth();
2021 
2022  // Number of bits representing values, as opposed to the sign
2023  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2024 
2025  // Following test does NOT adjust intWidth downwards for signed inputs,
2026  // because the most negative value still requires all the mantissa bits
2027  // to distinguish it from one less than that value.
2028  if ((int)intWidth > mantissaWidth) {
2029  // Conversion would lose accuracy. Check if loss can impact comparison.
2030  int exponent = ilogb(rhs);
2031  if (exponent == APFloat::IEK_Inf) {
2032  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2033  if (maxExponent < (int)valueBits) {
2034  // Conversion could create infinity.
2035  return failure();
2036  }
2037  } else {
2038  // Note that if rhs is zero or NaN, then Exp is negative
2039  // and first condition is trivially false.
2040  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2041  // Conversion could affect comparison.
2042  return failure();
2043  }
2044  }
2045  }
2046 
2047  // Convert to equivalent cmpi predicate
2048  CmpIPredicate pred;
2049  switch (op.getPredicate()) {
2050  case CmpFPredicate::ORD:
2051  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2052  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2053  /*width=*/1);
2054  return success();
2055  case CmpFPredicate::UNO:
2056  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2057  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2058  /*width=*/1);
2059  return success();
2060  default:
2061  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2062  break;
2063  }
2064 
2065  if (!isUnsigned) {
2066  // If the rhs value is > SignedMax, fold the comparison. This handles
2067  // +INF and large values.
2068  APFloat signedMax(rhs.getSemantics());
2069  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2070  APFloat::rmNearestTiesToEven);
2071  if (signedMax < rhs) { // smax < 13123.0
2072  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2073  pred == CmpIPredicate::sle)
2074  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2075  /*width=*/1);
2076  else
2077  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2078  /*width=*/1);
2079  return success();
2080  }
2081  } else {
2082  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2083  // +INF and large values.
2084  APFloat unsignedMax(rhs.getSemantics());
2085  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2086  APFloat::rmNearestTiesToEven);
2087  if (unsignedMax < rhs) { // umax < 13123.0
2088  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2089  pred == CmpIPredicate::ule)
2090  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2091  /*width=*/1);
2092  else
2093  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2094  /*width=*/1);
2095  return success();
2096  }
2097  }
2098 
2099  if (!isUnsigned) {
2100  // See if the rhs value is < SignedMin.
2101  APFloat signedMin(rhs.getSemantics());
2102  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2103  APFloat::rmNearestTiesToEven);
2104  if (signedMin > rhs) { // smin > 12312.0
2105  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2106  pred == CmpIPredicate::sge)
2107  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2108  /*width=*/1);
2109  else
2110  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2111  /*width=*/1);
2112  return success();
2113  }
2114  } else {
2115  // See if the rhs value is < UnsignedMin.
2116  APFloat unsignedMin(rhs.getSemantics());
2117  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2118  APFloat::rmNearestTiesToEven);
2119  if (unsignedMin > rhs) { // umin > 12312.0
2120  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2121  pred == CmpIPredicate::uge)
2122  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2123  /*width=*/1);
2124  else
2125  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2126  /*width=*/1);
2127  return success();
2128  }
2129  }
2130 
2131  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2132  // [0, UMAX], but it may still be fractional. See if it is fractional by
2133  // casting the FP value to the integer value and back, checking for
2134  // equality. Don't do this for zero, because -0.0 is not fractional.
2135  bool ignored;
2136  APSInt rhsInt(intWidth, isUnsigned);
2137  if (APFloat::opInvalidOp ==
2138  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2139  // Undefined behavior invoked - the destination type can't represent
2140  // the input constant.
2141  return failure();
2142  }
2143 
2144  if (!rhs.isZero()) {
2145  APFloat apf(floatTy.getFloatSemantics(),
2146  APInt::getZero(floatTy.getWidth()));
2147  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2148 
2149  bool equal = apf == rhs;
2150  if (!equal) {
2151  // If we had a comparison against a fractional value, we have to adjust
2152  // the compare predicate and sometimes the value. rhsInt is rounded
2153  // towards zero at this point.
2154  switch (pred) {
2155  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2156  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2157  /*width=*/1);
2158  return success();
2159  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2160  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2161  /*width=*/1);
2162  return success();
2163  case CmpIPredicate::ule:
2164  // (float)int <= 4.4 --> int <= 4
2165  // (float)int <= -4.4 --> false
2166  if (rhs.isNegative()) {
2167  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2168  /*width=*/1);
2169  return success();
2170  }
2171  break;
2172  case CmpIPredicate::sle:
2173  // (float)int <= 4.4 --> int <= 4
2174  // (float)int <= -4.4 --> int < -4
2175  if (rhs.isNegative())
2176  pred = CmpIPredicate::slt;
2177  break;
2178  case CmpIPredicate::ult:
2179  // (float)int < -4.4 --> false
2180  // (float)int < 4.4 --> int <= 4
2181  if (rhs.isNegative()) {
2182  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2183  /*width=*/1);
2184  return success();
2185  }
2186  pred = CmpIPredicate::ule;
2187  break;
2188  case CmpIPredicate::slt:
2189  // (float)int < -4.4 --> int < -4
2190  // (float)int < 4.4 --> int <= 4
2191  if (!rhs.isNegative())
2192  pred = CmpIPredicate::sle;
2193  break;
2194  case CmpIPredicate::ugt:
2195  // (float)int > 4.4 --> int > 4
2196  // (float)int > -4.4 --> true
2197  if (rhs.isNegative()) {
2198  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2199  /*width=*/1);
2200  return success();
2201  }
2202  break;
2203  case CmpIPredicate::sgt:
2204  // (float)int > 4.4 --> int > 4
2205  // (float)int > -4.4 --> int >= -4
2206  if (rhs.isNegative())
2207  pred = CmpIPredicate::sge;
2208  break;
2209  case CmpIPredicate::uge:
2210  // (float)int >= -4.4 --> true
2211  // (float)int >= 4.4 --> int > 4
2212  if (rhs.isNegative()) {
2213  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2214  /*width=*/1);
2215  return success();
2216  }
2217  pred = CmpIPredicate::ugt;
2218  break;
2219  case CmpIPredicate::sge:
2220  // (float)int >= -4.4 --> int >= -4
2221  // (float)int >= 4.4 --> int > 4
2222  if (!rhs.isNegative())
2223  pred = CmpIPredicate::sgt;
2224  break;
2225  }
2226  }
2227  }
2228 
2229  // Lower this FP comparison into an appropriate integer version of the
2230  // comparison.
2231  rewriter.replaceOpWithNewOp<CmpIOp>(
2232  op, pred, intVal,
2233  rewriter.create<ConstantOp>(
2234  op.getLoc(), intVal.getType(),
2235  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2236  return success();
2237  }
2238 };
2239 
2240 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2241  MLIRContext *context) {
2242  patterns.insert<CmpFIntToFPConst>(context);
2243 }
2244 
2245 //===----------------------------------------------------------------------===//
2246 // SelectOp
2247 //===----------------------------------------------------------------------===//
2248 
2249 // select %arg, %c1, %c0 => extui %arg
2250 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2252 
2253  LogicalResult matchAndRewrite(arith::SelectOp op,
2254  PatternRewriter &rewriter) const override {
2255  // Cannot extui i1 to i1, or i1 to f32
2256  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2257  return failure();
2258 
2259  // select %x, c1, %c0 => extui %arg
2260  if (matchPattern(op.getTrueValue(), m_One()) &&
2261  matchPattern(op.getFalseValue(), m_Zero())) {
2262  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2263  op.getCondition());
2264  return success();
2265  }
2266 
2267  // select %x, c0, %c1 => extui (xor %arg, true)
2268  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2269  matchPattern(op.getFalseValue(), m_One())) {
2270  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2271  op, op.getType(),
2272  rewriter.create<arith::XOrIOp>(
2273  op.getLoc(), op.getCondition(),
2274  rewriter.create<arith::ConstantIntOp>(
2275  op.getLoc(), 1, op.getCondition().getType())));
2276  return success();
2277  }
2278 
2279  return failure();
2280  }
2281 };
2282 
2283 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2284  MLIRContext *context) {
2285  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2286  SelectI1ToNot, SelectToExtUI>(context);
2287 }
2288 
2289 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2290  Value trueVal = getTrueValue();
2291  Value falseVal = getFalseValue();
2292  if (trueVal == falseVal)
2293  return trueVal;
2294 
2295  Value condition = getCondition();
2296 
2297  // select true, %0, %1 => %0
2298  if (matchPattern(adaptor.getCondition(), m_One()))
2299  return trueVal;
2300 
2301  // select false, %0, %1 => %1
2302  if (matchPattern(adaptor.getCondition(), m_Zero()))
2303  return falseVal;
2304 
2305  // If either operand is fully poisoned, return the other.
2306  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2307  return falseVal;
2308 
2309  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2310  return trueVal;
2311 
2312  // select %x, true, false => %x
2313  if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
2314  matchPattern(adaptor.getFalseValue(), m_Zero()))
2315  return condition;
2316 
2317  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2318  auto pred = cmp.getPredicate();
2319  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2320  auto cmpLhs = cmp.getLhs();
2321  auto cmpRhs = cmp.getRhs();
2322 
2323  // %0 = arith.cmpi eq, %arg0, %arg1
2324  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2325 
2326  // %0 = arith.cmpi ne, %arg0, %arg1
2327  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2328 
2329  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2330  (cmpRhs == trueVal && cmpLhs == falseVal))
2331  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2332  }
2333  }
2334 
2335  // Constant-fold constant operands over non-splat constant condition.
2336  // select %cst_vec, %cst0, %cst1 => %cst2
2337  if (auto cond =
2338  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2339  if (auto lhs =
2340  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2341  if (auto rhs =
2342  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2343  SmallVector<Attribute> results;
2344  results.reserve(static_cast<size_t>(cond.getNumElements()));
2345  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2346  cond.value_end<BoolAttr>());
2347  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2348  lhs.value_end<Attribute>());
2349  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2350  rhs.value_end<Attribute>());
2351 
2352  for (auto [condVal, lhsVal, rhsVal] :
2353  llvm::zip_equal(condVals, lhsVals, rhsVals))
2354  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2355 
2356  return DenseElementsAttr::get(lhs.getType(), results);
2357  }
2358  }
2359  }
2360 
2361  return nullptr;
2362 }
2363 
2364 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2365  Type conditionType, resultType;
2367  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2368  parser.parseOptionalAttrDict(result.attributes) ||
2369  parser.parseColonType(resultType))
2370  return failure();
2371 
2372  // Check for the explicit condition type if this is a masked tensor or vector.
2373  if (succeeded(parser.parseOptionalComma())) {
2374  conditionType = resultType;
2375  if (parser.parseType(resultType))
2376  return failure();
2377  } else {
2378  conditionType = parser.getBuilder().getI1Type();
2379  }
2380 
2381  result.addTypes(resultType);
2382  return parser.resolveOperands(operands,
2383  {conditionType, resultType, resultType},
2384  parser.getNameLoc(), result.operands);
2385 }
2386 
2388  p << " " << getOperands();
2389  p.printOptionalAttrDict((*this)->getAttrs());
2390  p << " : ";
2391  if (ShapedType condType =
2392  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2393  p << condType << ", ";
2394  p << getType();
2395 }
2396 
2397 LogicalResult arith::SelectOp::verify() {
2398  Type conditionType = getCondition().getType();
2399  if (conditionType.isSignlessInteger(1))
2400  return success();
2401 
2402  // If the result type is a vector or tensor, the type can be a mask with the
2403  // same elements.
2404  Type resultType = getType();
2405  if (!llvm::isa<TensorType, VectorType>(resultType))
2406  return emitOpError() << "expected condition to be a signless i1, but got "
2407  << conditionType;
2408  Type shapedConditionType = getI1SameShape(resultType);
2409  if (conditionType != shapedConditionType) {
2410  return emitOpError() << "expected condition type to have the same shape "
2411  "as the result type, expected "
2412  << shapedConditionType << ", but got "
2413  << conditionType;
2414  }
2415  return success();
2416 }
2417 //===----------------------------------------------------------------------===//
2418 // ShLIOp
2419 //===----------------------------------------------------------------------===//
2420 
2421 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2422  // shli(x, 0) -> x
2423  if (matchPattern(adaptor.getRhs(), m_Zero()))
2424  return getLhs();
2425  // Don't fold if shifting more or equal than the bit width.
2426  bool bounded = false;
2427  auto result = constFoldBinaryOp<IntegerAttr>(
2428  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2429  bounded = b.ult(b.getBitWidth());
2430  return a.shl(b);
2431  });
2432  return bounded ? result : Attribute();
2433 }
2434 
2435 //===----------------------------------------------------------------------===//
2436 // ShRUIOp
2437 //===----------------------------------------------------------------------===//
2438 
2439 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2440  // shrui(x, 0) -> x
2441  if (matchPattern(adaptor.getRhs(), m_Zero()))
2442  return getLhs();
2443  // Don't fold if shifting more or equal than the bit width.
2444  bool bounded = false;
2445  auto result = constFoldBinaryOp<IntegerAttr>(
2446  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2447  bounded = b.ult(b.getBitWidth());
2448  return a.lshr(b);
2449  });
2450  return bounded ? result : Attribute();
2451 }
2452 
2453 //===----------------------------------------------------------------------===//
2454 // ShRSIOp
2455 //===----------------------------------------------------------------------===//
2456 
2457 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2458  // shrsi(x, 0) -> x
2459  if (matchPattern(adaptor.getRhs(), m_Zero()))
2460  return getLhs();
2461  // Don't fold if shifting more or equal than the bit width.
2462  bool bounded = false;
2463  auto result = constFoldBinaryOp<IntegerAttr>(
2464  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2465  bounded = b.ult(b.getBitWidth());
2466  return a.ashr(b);
2467  });
2468  return bounded ? result : Attribute();
2469 }
2470 
2471 //===----------------------------------------------------------------------===//
2472 // Atomic Enum
2473 //===----------------------------------------------------------------------===//
2474 
2475 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2476 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2477  OpBuilder &builder, Location loc,
2478  bool useOnlyFiniteValue) {
2479  switch (kind) {
2480  case AtomicRMWKind::maximumf: {
2481  const llvm::fltSemantics &semantic =
2482  llvm::cast<FloatType>(resultType).getFloatSemantics();
2483  APFloat identity = useOnlyFiniteValue
2484  ? APFloat::getLargest(semantic, /*Negative=*/true)
2485  : APFloat::getInf(semantic, /*Negative=*/true);
2486  return builder.getFloatAttr(resultType, identity);
2487  }
2488  case AtomicRMWKind::maxnumf: {
2489  const llvm::fltSemantics &semantic =
2490  llvm::cast<FloatType>(resultType).getFloatSemantics();
2491  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2492  return builder.getFloatAttr(resultType, identity);
2493  }
2494  case AtomicRMWKind::addf:
2495  case AtomicRMWKind::addi:
2496  case AtomicRMWKind::maxu:
2497  case AtomicRMWKind::ori:
2498  return builder.getZeroAttr(resultType);
2499  case AtomicRMWKind::andi:
2500  return builder.getIntegerAttr(
2501  resultType,
2502  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2503  case AtomicRMWKind::maxs:
2504  return builder.getIntegerAttr(
2505  resultType, APInt::getSignedMinValue(
2506  llvm::cast<IntegerType>(resultType).getWidth()));
2507  case AtomicRMWKind::minimumf: {
2508  const llvm::fltSemantics &semantic =
2509  llvm::cast<FloatType>(resultType).getFloatSemantics();
2510  APFloat identity = useOnlyFiniteValue
2511  ? APFloat::getLargest(semantic, /*Negative=*/false)
2512  : APFloat::getInf(semantic, /*Negative=*/false);
2513 
2514  return builder.getFloatAttr(resultType, identity);
2515  }
2516  case AtomicRMWKind::minnumf: {
2517  const llvm::fltSemantics &semantic =
2518  llvm::cast<FloatType>(resultType).getFloatSemantics();
2519  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2520  return builder.getFloatAttr(resultType, identity);
2521  }
2522  case AtomicRMWKind::mins:
2523  return builder.getIntegerAttr(
2524  resultType, APInt::getSignedMaxValue(
2525  llvm::cast<IntegerType>(resultType).getWidth()));
2526  case AtomicRMWKind::minu:
2527  return builder.getIntegerAttr(
2528  resultType,
2529  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2530  case AtomicRMWKind::muli:
2531  return builder.getIntegerAttr(resultType, 1);
2532  case AtomicRMWKind::mulf:
2533  return builder.getFloatAttr(resultType, 1);
2534  // TODO: Add remaining reduction operations.
2535  default:
2536  (void)emitOptionalError(loc, "Reduction operation type not supported");
2537  break;
2538  }
2539  return nullptr;
2540 }
2541 
2542 /// Return the identity numeric value associated to the give op.
2543 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2544  std::optional<AtomicRMWKind> maybeKind =
2546  // Floating-point operations.
2547  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2548  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2549  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2550  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2551  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2552  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2553  // Integer operations.
2554  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2555  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2556  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2557  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2558  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2559  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2560  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2561  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2562  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2563  .Default([](Operation *op) { return std::nullopt; });
2564  if (!maybeKind) {
2565  return std::nullopt;
2566  }
2567 
2568  bool useOnlyFiniteValue = false;
2569  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2570  if (fmfOpInterface) {
2571  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2572  useOnlyFiniteValue =
2573  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2574  }
2575 
2576  // Builder only used as helper for attribute creation.
2577  OpBuilder b(op->getContext());
2578  Type resultType = op->getResult(0).getType();
2579 
2580  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2581  useOnlyFiniteValue);
2582 }
2583 
2584 /// Returns the identity value associated with an AtomicRMWKind op.
2585 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2586  OpBuilder &builder, Location loc,
2587  bool useOnlyFiniteValue) {
2588  auto attr =
2589  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2590  return builder.create<arith::ConstantOp>(loc, attr);
2591 }
2592 
2593 /// Return the value obtained by applying the reduction operation kind
2594 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2595 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2596  Location loc, Value lhs, Value rhs) {
2597  switch (op) {
2598  case AtomicRMWKind::addf:
2599  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2600  case AtomicRMWKind::addi:
2601  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2602  case AtomicRMWKind::mulf:
2603  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2604  case AtomicRMWKind::muli:
2605  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2606  case AtomicRMWKind::maximumf:
2607  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2608  case AtomicRMWKind::minimumf:
2609  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2610  case AtomicRMWKind::maxnumf:
2611  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2612  case AtomicRMWKind::minnumf:
2613  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2614  case AtomicRMWKind::maxs:
2615  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2616  case AtomicRMWKind::mins:
2617  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2618  case AtomicRMWKind::maxu:
2619  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2620  case AtomicRMWKind::minu:
2621  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2622  case AtomicRMWKind::ori:
2623  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2624  case AtomicRMWKind::andi:
2625  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2626  // TODO: Add remaining reduction operations.
2627  default:
2628  (void)emitOptionalError(loc, "Reduction operation type not supported");
2629  break;
2630  }
2631  return nullptr;
2632 }
2633 
2634 //===----------------------------------------------------------------------===//
2635 // TableGen'd op method definitions
2636 //===----------------------------------------------------------------------===//
2637 
2638 #define GET_OP_CLASSES
2639 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2640 
2641 //===----------------------------------------------------------------------===//
2642 // TableGen'd enum attribute definitions
2643 //===----------------------------------------------------------------------===//
2644 
2645 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1303
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1239
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1788
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1253
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1225
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:643
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:141
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:149
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1318
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1644
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1542
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1275
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:129
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:834
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1246
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:170
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1806
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1261
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1219
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:345
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1289
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1985
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1958
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:128
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
IntegerType getI1Type()
Definition: Builders.cpp:77
IndexType getIndexType()
Definition: Builders.cpp:75
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This provides public APIs that all operations should have.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:68
bool isIndex() const
Definition: Types.cpp:57
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:120
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:278
static bool classof(Operation *op)
Definition: ArithOps.cpp:284
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:290
static bool classof(Operation *op)
Definition: ArithOps.cpp:296
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:53
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:257
static bool classof(Operation *op)
Definition: ArithOps.cpp:272
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2543
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1760
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2476
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2595
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2585
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
auto m_Val(Value v)
Definition: Matchers.h:450
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:496
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:345
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:384
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:371
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:350
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:363
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:355
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2253
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes