1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/CommonFolders.h"
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
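// Apply `binFn` to the integer values of `lhs` and `rhs`, and wrap the result
// in an IntegerAttr with the type of `res`.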
42 static IntegerAttr
43 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
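// Return the bit width of `type` (or of its element type for shaped types),
// or -1 if it is not an integer or float type.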
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
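// Extract the scalar integer value from an integer attribute or an integer
// splat; fail for any other attribute.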
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
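// Build a boolean constant attribute of the given `type`: a plain BoolAttr
// for scalars, or a splat DenseElementsAttr for shaped types.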
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // The value's type must match the return type.
211  if (getValue().getType() != type) {
212  return emitOpError() << "value type " << getValue().getType()
213  << " must match return type: " << type;
214  }
215  // Integer values must be signless.
216  if (llvm::isa<IntegerType>(type) &&
217  !llvm::cast<IntegerType>(type).isSignless())
218  return emitOpError("integer return type must be signless");
219  // Any float or elements attribute is acceptable.
220  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
221  return emitOpError(
222  "value must be an integer, float, or elements attribute");
223  }
224 
225  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
226  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
227  // However, this would most likely require updating the lowerings to LLVM.
228  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
229  return emitOpError(
230  "initializing scalable vectors with elements attribute is not supported"
231  " unless it's a vector splat");
232  return success();
233 }
234 
235 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
236  // The value's type must be the same as the provided type.
237  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238  if (!typedAttr || typedAttr.getType() != type)
239  return false;
240  // Integer values must be signless.
241  if (llvm::isa<IntegerType>(type) &&
242  !llvm::cast<IntegerType>(type).isSignless())
243  return false;
244  // Integer, float, and element attributes are buildable.
245  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246 }
247 
248 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
249  Type type, Location loc) {
250  if (isBuildableWith(value, type))
251  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
252  return nullptr;
253 }
254 
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
256 
257 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
258  int64_t value, unsigned width) {
259  auto type = builder.getIntegerType(width);
260  arith::ConstantOp::build(builder, result, type,
261  builder.getIntegerAttr(type, value));
262 }
263 
264 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
265  int64_t value, Type type) {
266  assert(type.isSignlessInteger() &&
267  "ConstantIntOp can only have signless integer type values");
268  arith::ConstantOp::build(builder, result, type,
269  builder.getIntegerAttr(type, value));
270 }
271 
272 bool arith::ConstantIntOp::classof(Operation *op) {
273  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274  return constOp.getType().isSignlessInteger();
275  return false;
276 }
277 
278 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
279  const APFloat &value, FloatType type) {
280  arith::ConstantOp::build(builder, result, type,
281  builder.getFloatAttr(type, value));
282 }
283 
284 bool arith::ConstantFloatOp::classof(Operation *op) {
285  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286  return llvm::isa<FloatType>(constOp.getType());
287  return false;
288 }
289 
290 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
291  int64_t value) {
292  arith::ConstantOp::build(builder, result, builder.getIndexType(),
293  builder.getIndexAttr(value));
294 }
295 
296 bool arith::ConstantIndexOp::classof(Operation *op) {
297  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298  return constOp.getType().isIndex();
299  return false;
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // AddIOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
307  // addi(x, 0) -> x
308  if (matchPattern(adaptor.getRhs(), m_Zero()))
309  return getLhs();
310 
311  // addi(subi(a, b), b) -> a
312  if (auto sub = getLhs().getDefiningOp<SubIOp>())
313  if (getRhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  // addi(b, subi(a, b)) -> a
317  if (auto sub = getRhs().getDefiningOp<SubIOp>())
318  if (getLhs() == sub.getRhs())
319  return sub.getLhs();
320 
321  return constFoldBinaryOp<IntegerAttr>(
322  adaptor.getOperands(),
323  [](APInt a, const APInt &b) { return std::move(a) + b; });
324 }
325 
326 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
327  MLIRContext *context) {
328  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // AddUIExtendedOp
334 //===----------------------------------------------------------------------===//
335 
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
339  return llvm::to_vector<4>(vt.getShape());
340  return std::nullopt;
341 }
342 
343 // Returns the overflow bit, assuming that `sum` is the result of unsigned
344 // addition of `operand` and another number.
345 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
346  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
347 }
348 
349 LogicalResult
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
351  SmallVectorImpl<OpFoldResult> &results) {
352  Type overflowTy = getOverflow().getType();
353  // addui_extended(x, 0) -> x, false
354  if (matchPattern(getRhs(), m_Zero())) {
355  Builder builder(getContext());
356  auto falseValue = builder.getZeroAttr(overflowTy);
357 
358  results.push_back(getLhs());
359  results.push_back(falseValue);
360  return success();
361  }
362 
363  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
364  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
365  // operands. If that succeeds, calculate the overflow bit based on the sum
366  // and the first (constant) operand, `lhs`.
367  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368  adaptor.getOperands(),
369  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
370  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371  ArrayRef({sumAttr, adaptor.getLhs()}),
372  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
373  calculateUnsignedOverflow);
374  if (!overflowAttr)
375  return failure();
376 
377  results.push_back(sumAttr);
378  results.push_back(overflowAttr);
379  return success();
380  }
381 
382  return failure();
383 }
384 
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
386  RewritePatternSet &patterns, MLIRContext *context) {
387  patterns.add<AddUIExtendedToAddI>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // SubIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
395  // subi(x,x) -> 0
396  if (getOperand(0) == getOperand(1)) {
397  auto shapedType = dyn_cast<ShapedType>(getType());
398  // We can't generate a constant with a dynamic shaped tensor.
399  if (!shapedType || shapedType.hasStaticShape())
400  return Builder(getContext()).getZeroAttr(getType());
401  }
402  // subi(x,0) -> x
403  if (matchPattern(adaptor.getRhs(), m_Zero()))
404  return getLhs();
405 
406  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
407  // subi(addi(a, b), b) -> a
408  if (getRhs() == add.getRhs())
409  return add.getLhs();
410  // subi(addi(a, b), a) -> b
411  if (getRhs() == add.getLhs())
412  return add.getRhs();
413  }
414 
415  return constFoldBinaryOp<IntegerAttr>(
416  adaptor.getOperands(),
417  [](APInt a, const APInt &b) { return std::move(a) - b; });
418 }
419 
420 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
421  MLIRContext *context) {
422  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
423  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
424  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // MulIOp
429 //===----------------------------------------------------------------------===//
430 
431 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
432  // muli(x, 0) -> 0
433  if (matchPattern(adaptor.getRhs(), m_Zero()))
434  return getRhs();
435  // muli(x, 1) -> x
436  if (matchPattern(adaptor.getRhs(), m_One()))
437  return getLhs();
438  // TODO: Handle the overflow case.
439 
440  // default folder
441  return constFoldBinaryOp<IntegerAttr>(
442  adaptor.getOperands(),
443  [](const APInt &a, const APInt &b) { return a * b; });
444 }
445 
446 void arith::MulIOp::getAsmResultNames(
447  function_ref<void(Value, StringRef)> setNameFn) {
448  if (!isa<IndexType>(getType()))
449  return;
450 
451  // Match vector.vscale by name to avoid depending on the vector dialect (which
452  // is a circular dependency).
453  auto isVscale = [](Operation *op) {
454  return op && op->getName().getStringRef() == "vector.vscale";
455  };
456 
457  IntegerAttr baseValue;
458  auto isVscaleExpr = [&](Value a, Value b) {
459  return matchPattern(a, m_Constant(&baseValue)) &&
460  isVscale(b.getDefiningOp());
461  };
462 
463  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
464  return;
465 
466  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
467  SmallString<32> specialNameBuffer;
468  llvm::raw_svector_ostream specialName(specialNameBuffer);
469  specialName << 'c' << baseValue.getInt() << "_vscale";
470  setNameFn(getResult(), specialName.str());
471 }
472 
473 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
474  MLIRContext *context) {
475  patterns.add<MulIMulIConstant>(context);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // MulSIExtendedOp
480 //===----------------------------------------------------------------------===//
481 
482 std::optional<SmallVector<int64_t, 4>>
483 arith::MulSIExtendedOp::getShapeForUnroll() {
484  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
485  return llvm::to_vector<4>(vt.getShape());
486  return std::nullopt;
487 }
488 
489 LogicalResult
490 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
491  SmallVectorImpl<OpFoldResult> &results) {
492  // mulsi_extended(x, 0) -> 0, 0
493  if (matchPattern(adaptor.getRhs(), m_Zero())) {
494  Attribute zero = adaptor.getRhs();
495  results.push_back(zero);
496  results.push_back(zero);
497  return success();
498  }
499 
500  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
501  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
502  adaptor.getOperands(),
503  [](const APInt &a, const APInt &b) { return a * b; })) {
504  // Invoke the constant fold helper again to calculate the 'high' result.
505  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
506  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
507  return llvm::APIntOps::mulhs(a, b);
508  });
509  assert(highAttr && "Unexpected constant-folding failure");
510 
511  results.push_back(lowAttr);
512  results.push_back(highAttr);
513  return success();
514  }
515 
516  return failure();
517 }
518 
519 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
520  RewritePatternSet &patterns, MLIRContext *context) {
521  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // MulUIExtendedOp
526 //===----------------------------------------------------------------------===//
527 
528 std::optional<SmallVector<int64_t, 4>>
529 arith::MulUIExtendedOp::getShapeForUnroll() {
530  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
531  return llvm::to_vector<4>(vt.getShape());
532  return std::nullopt;
533 }
534 
535 LogicalResult
536 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
537  SmallVectorImpl<OpFoldResult> &results) {
538  // mului_extended(x, 0) -> 0, 0
539  if (matchPattern(adaptor.getRhs(), m_Zero())) {
540  Attribute zero = adaptor.getRhs();
541  results.push_back(zero);
542  results.push_back(zero);
543  return success();
544  }
545 
546  // mului_extended(x, 1) -> x, 0
547  if (matchPattern(adaptor.getRhs(), m_One())) {
548  Builder builder(getContext());
549  Attribute zero = builder.getZeroAttr(getLhs().getType());
550  results.push_back(getLhs());
551  results.push_back(zero);
552  return success();
553  }
554 
555  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
556  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
557  adaptor.getOperands(),
558  [](const APInt &a, const APInt &b) { return a * b; })) {
559  // Invoke the constant fold helper again to calculate the 'high' result.
560  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
561  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
562  return llvm::APIntOps::mulhu(a, b);
563  });
564  assert(highAttr && "Unexpected constant-folding failure");
565 
566  results.push_back(lowAttr);
567  results.push_back(highAttr);
568  return success();
569  }
570 
571  return failure();
572 }
573 
574 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
575  RewritePatternSet &patterns, MLIRContext *context) {
576  patterns.add<MulUIExtendedToMulI>(context);
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // DivUIOp
581 //===----------------------------------------------------------------------===//
582 
583 /// Fold `(a * b) / b -> a`
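/// The multiplication must carry the given overflow flags (`nuw` for unsigned
/// division, `nsw` for signed division) so the product is known not to wrap.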
584 static Value foldDivMul(Value lhs, Value rhs,
585  arith::IntegerOverflowFlags ovfFlags) {
586  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
587  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
588  return {};
589 
590  if (mul.getLhs() == rhs)
591  return mul.getRhs();
592 
593  if (mul.getRhs() == rhs)
594  return mul.getLhs();
595 
596  return {};
597 }
598 
599 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
600  // divui (x, 1) -> x.
601  if (matchPattern(adaptor.getRhs(), m_One()))
602  return getLhs();
603 
604  // (a * b) / b -> a
605  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
606  return val;
607 
608  // Don't fold if it would require a division by zero.
609  bool div0 = false;
610  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
611  [&](APInt a, const APInt &b) {
612  if (div0 || !b) {
613  div0 = true;
614  return a;
615  }
616  return a.udiv(b);
617  });
618 
619  return div0 ? Attribute() : result;
620 }
621 
622 /// Returns whether an unsigned division by `divisor` is speculatable.
623 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
624  // X / 0 => UB
625  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
626  return Speculation::Speculatable;
627 
628  return Speculation::NotSpeculatable;
629 }
630 
631 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
632  return getDivUISpeculatability(getRhs());
633 }
634 
635 //===----------------------------------------------------------------------===//
636 // DivSIOp
637 //===----------------------------------------------------------------------===//
638 
639 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
640  // divsi (x, 1) -> x.
641  if (matchPattern(adaptor.getRhs(), m_One()))
642  return getLhs();
643 
644  // (a * b) / b -> a
645  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
646  return val;
647 
648  // Don't fold if it would overflow or if it requires a division by zero.
649  bool overflowOrDiv0 = false;
650  auto result = constFoldBinaryOp<IntegerAttr>(
651  adaptor.getOperands(), [&](APInt a, const APInt &b) {
652  if (overflowOrDiv0 || !b) {
653  overflowOrDiv0 = true;
654  return a;
655  }
656  return a.sdiv_ov(b, overflowOrDiv0);
657  });
658 
659  return overflowOrDiv0 ? Attribute() : result;
660 }
661 
662 /// Returns whether a signed division by `divisor` is speculatable. This
663 /// function conservatively assumes that all signed division by -1 are not
664 /// speculatable.
665 static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
666  // X / 0 => UB
667  // INT_MIN / -1 => UB
668  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
669  matchPattern(divisor, m_IntRangeWithoutNegOneS()))
670  return Speculation::Speculatable;
671 
672  return Speculation::NotSpeculatable;
673 }
674 
675 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
676  return getDivSISpeculatability(getRhs());
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // Ceil and floor division folding helpers
681 //===----------------------------------------------------------------------===//
682 
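// Compute ceil(a / b) as (a - 1) / b + 1 for strictly positive `a` and `b`,
// setting `overflow` if any intermediate operation overflows.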
683 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
684  bool &overflow) {
685  // Returns (a-1)/b + 1
686  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
687  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
688  return val.sadd_ov(one, overflow);
689 }
690 
691 //===----------------------------------------------------------------------===//
692 // CeilDivUIOp
693 //===----------------------------------------------------------------------===//
694 
695 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
696  // ceildivui (x, 1) -> x.
697  if (matchPattern(adaptor.getRhs(), m_One()))
698  return getLhs();
699 
700  bool overflowOrDiv0 = false;
701  auto result = constFoldBinaryOp<IntegerAttr>(
702  adaptor.getOperands(), [&](APInt a, const APInt &b) {
703  if (overflowOrDiv0 || !b) {
704  overflowOrDiv0 = true;
705  return a;
706  }
707  APInt quotient = a.udiv(b);
708  if (!a.urem(b))
709  return quotient;
710  APInt one(a.getBitWidth(), 1, true);
711  return quotient.uadd_ov(one, overflowOrDiv0);
712  });
713 
714  return overflowOrDiv0 ? Attribute() : result;
715 }
716 
717 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
718  return getDivUISpeculatability(getRhs());
719 }
720 
721 //===----------------------------------------------------------------------===//
722 // CeilDivSIOp
723 //===----------------------------------------------------------------------===//
724 
725 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
726  // ceildivsi (x, 1) -> x.
727  if (matchPattern(adaptor.getRhs(), m_One()))
728  return getLhs();
729 
730  // Don't fold if it would overflow or if it requires a division by zero.
731  // TODO: This hook won't fold operations where a = MININT, because
732  // negating MININT overflows. This can be improved.
733  bool overflowOrDiv0 = false;
734  auto result = constFoldBinaryOp<IntegerAttr>(
735  adaptor.getOperands(), [&](APInt a, const APInt &b) {
736  if (overflowOrDiv0 || !b) {
737  overflowOrDiv0 = true;
738  return a;
739  }
740  if (!a)
741  return a;
742  // After this point we know that neither a nor b is zero.
743  unsigned bits = a.getBitWidth();
744  APInt zero = APInt::getZero(bits);
745  bool aGtZero = a.sgt(zero);
746  bool bGtZero = b.sgt(zero);
747  if (aGtZero && bGtZero) {
748  // Both positive, return ceil(a, b).
749  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
750  }
751 
752  // No folding happens if any of the intermediate arithmetic operations
753  // overflows.
754  bool overflowNegA = false;
755  bool overflowNegB = false;
756  bool overflowDiv = false;
757  bool overflowNegRes = false;
758  if (!aGtZero && !bGtZero) {
759  // Both negative, return ceil(-a, -b).
760  APInt posA = zero.ssub_ov(a, overflowNegA);
761  APInt posB = zero.ssub_ov(b, overflowNegB);
762  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
763  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
764  return res;
765  }
766  if (!aGtZero && bGtZero) {
767  // A is negative, b is positive, return - ( -a / b).
768  APInt posA = zero.ssub_ov(a, overflowNegA);
769  APInt div = posA.sdiv_ov(b, overflowDiv);
770  APInt res = zero.ssub_ov(div, overflowNegRes);
771  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
772  return res;
773  }
774  // A is positive, b is negative, return - (a / -b).
775  APInt posB = zero.ssub_ov(b, overflowNegB);
776  APInt div = a.sdiv_ov(posB, overflowDiv);
777  APInt res = zero.ssub_ov(div, overflowNegRes);
778 
779  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
780  return res;
781  });
782 
783  return overflowOrDiv0 ? Attribute() : result;
784 }
785 
786 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
787  return getDivSISpeculatability(getRhs());
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // FloorDivSIOp
792 //===----------------------------------------------------------------------===//
793 
794 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
795  // floordivsi (x, 1) -> x.
796  if (matchPattern(adaptor.getRhs(), m_One()))
797  return getLhs();
798 
799  // Don't fold if it would overflow or if it requires a division by zero.
800  bool overflowOrDiv = false;
801  auto result = constFoldBinaryOp<IntegerAttr>(
802  adaptor.getOperands(), [&](APInt a, const APInt &b) {
803  if (b.isZero()) {
804  overflowOrDiv = true;
805  return a;
806  }
807  return a.sfloordiv_ov(b, overflowOrDiv);
808  });
809 
810  return overflowOrDiv ? Attribute() : result;
811 }
812 
813 //===----------------------------------------------------------------------===//
814 // RemUIOp
815 //===----------------------------------------------------------------------===//
816 
817 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
818  // remui (x, 1) -> 0.
819  if (matchPattern(adaptor.getRhs(), m_One()))
820  return Builder(getContext()).getZeroAttr(getType());
821 
822  // Don't fold if it would require a division by zero.
823  bool div0 = false;
824  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
825  [&](APInt a, const APInt &b) {
826  if (div0 || b.isZero()) {
827  div0 = true;
828  return a;
829  }
830  return a.urem(b);
831  });
832 
833  return div0 ? Attribute() : result;
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // RemSIOp
838 //===----------------------------------------------------------------------===//
839 
840 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
841  // remsi (x, 1) -> 0.
842  if (matchPattern(adaptor.getRhs(), m_One()))
843  return Builder(getContext()).getZeroAttr(getType());
844 
845  // Don't fold if it would require a division by zero.
846  bool div0 = false;
847  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
848  [&](APInt a, const APInt &b) {
849  if (div0 || b.isZero()) {
850  div0 = true;
851  return a;
852  }
853  return a.srem(b);
854  });
855 
856  return div0 ? Attribute() : result;
857 }
858 
859 //===----------------------------------------------------------------------===//
860 // AndIOp
861 //===----------------------------------------------------------------------===//
862 
863 /// Fold `and(a, and(a, b))` to `and(a, b)`
864 static Value foldAndIofAndI(arith::AndIOp op) {
865  for (bool reversePrev : {false, true}) {
866  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
867  .getDefiningOp<arith::AndIOp>();
868  if (!prev)
869  continue;
870 
871  Value other = (reversePrev ? op.getLhs() : op.getRhs());
872  if (other != prev.getLhs() && other != prev.getRhs())
873  continue;
874 
875  return prev.getResult();
876  }
877  return {};
878 }
879 
880 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
881  /// and(x, 0) -> 0
882  if (matchPattern(adaptor.getRhs(), m_Zero()))
883  return getRhs();
884  /// and(x, allOnes) -> x
885  APInt intValue;
886  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
887  intValue.isAllOnes())
888  return getLhs();
889  /// and(x, not(x)) -> 0
890  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
891  m_ConstantInt(&intValue))) &&
892  intValue.isAllOnes())
893  return Builder(getContext()).getZeroAttr(getType());
894  /// and(not(x), x) -> 0
895  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
896  m_ConstantInt(&intValue))) &&
897  intValue.isAllOnes())
898  return Builder(getContext()).getZeroAttr(getType());
899 
900  /// and(a, and(a, b)) -> and(a, b)
901  if (Value result = foldAndIofAndI(*this))
902  return result;
903 
904  return constFoldBinaryOp<IntegerAttr>(
905  adaptor.getOperands(),
906  [](APInt a, const APInt &b) { return std::move(a) & b; });
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // OrIOp
911 //===----------------------------------------------------------------------===//
912 
913 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
914  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
915  /// or(x, 0) -> x
916  if (rhsVal.isZero())
917  return getLhs();
918  /// or(x, <all ones>) -> <all ones>
919  if (rhsVal.isAllOnes())
920  return adaptor.getRhs();
921  }
922 
923  APInt intValue;
924  /// or(x, xor(x, 1)) -> 1
925  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
926  m_ConstantInt(&intValue))) &&
927  intValue.isAllOnes())
928  return getRhs().getDefiningOp<XOrIOp>().getRhs();
929  /// or(xor(x, 1), x) -> 1
930  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
931  m_ConstantInt(&intValue))) &&
932  intValue.isAllOnes())
933  return getLhs().getDefiningOp<XOrIOp>().getRhs();
934 
935  return constFoldBinaryOp<IntegerAttr>(
936  adaptor.getOperands(),
937  [](APInt a, const APInt &b) { return std::move(a) | b; });
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // XOrIOp
942 //===----------------------------------------------------------------------===//
943 
944 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
945  /// xor(x, 0) -> x
946  if (matchPattern(adaptor.getRhs(), m_Zero()))
947  return getLhs();
948  /// xor(x, x) -> 0
949  if (getLhs() == getRhs())
950  return Builder(getContext()).getZeroAttr(getType());
951  /// xor(xor(x, a), a) -> x
952  /// xor(xor(a, x), a) -> x
953  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
954  if (prev.getRhs() == getRhs())
955  return prev.getLhs();
956  if (prev.getLhs() == getRhs())
957  return prev.getRhs();
958  }
959  /// xor(a, xor(x, a)) -> x
960  /// xor(a, xor(a, x)) -> x
961  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
962  if (prev.getRhs() == getLhs())
963  return prev.getLhs();
964  if (prev.getLhs() == getLhs())
965  return prev.getRhs();
966  }
967 
968  return constFoldBinaryOp<IntegerAttr>(
969  adaptor.getOperands(),
970  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
971 }
972 
973 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
974  MLIRContext *context) {
975  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
976 }
977 
978 //===----------------------------------------------------------------------===//
979 // NegFOp
980 //===----------------------------------------------------------------------===//
981 
982 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
983  /// negf(negf(x)) -> x
984  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
985  return op.getOperand();
986  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
987  [](const APFloat &a) { return -a; });
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // AddFOp
992 //===----------------------------------------------------------------------===//
993 
994 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
995  // addf(x, -0) -> x
996  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
997  return getLhs();
998 
999  return constFoldBinaryOp<FloatAttr>(
1000  adaptor.getOperands(),
1001  [](const APFloat &a, const APFloat &b) { return a + b; });
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // SubFOp
1006 //===----------------------------------------------------------------------===//
1007 
1008 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1009  // subf(x, +0) -> x
1010  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1011  return getLhs();
1012 
1013  return constFoldBinaryOp<FloatAttr>(
1014  adaptor.getOperands(),
1015  [](const APFloat &a, const APFloat &b) { return a - b; });
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // MaximumFOp
1020 //===----------------------------------------------------------------------===//
1021 
1022 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1023  // maximumf(x,x) -> x
1024  if (getLhs() == getRhs())
1025  return getRhs();
1026 
1027  // maximumf(x, -inf) -> x
1028  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1029  return getLhs();
1030 
1031  return constFoldBinaryOp<FloatAttr>(
1032  adaptor.getOperands(),
1033  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // MaxNumFOp
1038 //===----------------------------------------------------------------------===//
1039 
1040 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1041  // maxnumf(x,x) -> x
1042  if (getLhs() == getRhs())
1043  return getRhs();
1044 
1045  // maxnumf(x, NaN) -> x
1046  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1047  return getLhs();
1048 
1049  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // MaxSIOp
1054 //===----------------------------------------------------------------------===//
1055 
1056 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1057  // maxsi(x,x) -> x
1058  if (getLhs() == getRhs())
1059  return getRhs();
1060 
1061  if (APInt intValue;
1062  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1063  // maxsi(x,MAX_INT) -> MAX_INT
1064  if (intValue.isMaxSignedValue())
1065  return getRhs();
1066  // maxsi(x, MIN_INT) -> x
1067  if (intValue.isMinSignedValue())
1068  return getLhs();
1069  }
1070 
1071  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1072  [](const APInt &a, const APInt &b) {
1073  return llvm::APIntOps::smax(a, b);
1074  });
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // MaxUIOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1082  // maxui(x,x) -> x
1083  if (getLhs() == getRhs())
1084  return getRhs();
1085 
1086  if (APInt intValue;
1087  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1088  // maxui(x,MAX_INT) -> MAX_INT
1089  if (intValue.isMaxValue())
1090  return getRhs();
1091  // maxui(x, MIN_INT) -> x
1092  if (intValue.isMinValue())
1093  return getLhs();
1094  }
1095 
1096  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1097  [](const APInt &a, const APInt &b) {
1098  return llvm::APIntOps::umax(a, b);
1099  });
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // MinimumFOp
1104 //===----------------------------------------------------------------------===//
1105 
1106 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1107  // minimumf(x,x) -> x
1108  if (getLhs() == getRhs())
1109  return getRhs();
1110 
1111  // minimumf(x, +inf) -> x
1112  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1113  return getLhs();
1114 
1115  return constFoldBinaryOp<FloatAttr>(
1116  adaptor.getOperands(),
1117  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // MinNumFOp
1122 //===----------------------------------------------------------------------===//
1123 
1124 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1125  // minnumf(x,x) -> x
1126  if (getLhs() == getRhs())
1127  return getRhs();
1128 
1129  // minnumf(x, NaN) -> x
1130  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1131  return getLhs();
1132 
1133  return constFoldBinaryOp<FloatAttr>(
1134  adaptor.getOperands(),
1135  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // MinSIOp
1140 //===----------------------------------------------------------------------===//
1141 
1142 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1143  // minsi(x,x) -> x
1144  if (getLhs() == getRhs())
1145  return getRhs();
1146 
1147  if (APInt intValue;
1148  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1149  // minsi(x,MIN_INT) -> MIN_INT
1150  if (intValue.isMinSignedValue())
1151  return getRhs();
1152  // minsi(x, MAX_INT) -> x
1153  if (intValue.isMaxSignedValue())
1154  return getLhs();
1155  }
1156 
1157  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1158  [](const APInt &a, const APInt &b) {
1159  return llvm::APIntOps::smin(a, b);
1160  });
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // MinUIOp
1165 //===----------------------------------------------------------------------===//
1166 
1167 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1168  // minui(x,x) -> x
1169  if (getLhs() == getRhs())
1170  return getRhs();
1171 
1172  if (APInt intValue;
1173  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1174  // minui(x,MIN_INT) -> MIN_INT
1175  if (intValue.isMinValue())
1176  return getRhs();
1177  // minui(x, MAX_INT) -> x
1178  if (intValue.isMaxValue())
1179  return getLhs();
1180  }
1181 
1182  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1183  [](const APInt &a, const APInt &b) {
1184  return llvm::APIntOps::umin(a, b);
1185  });
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // MulFOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1193  // mulf(x, 1) -> x
1194  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1195  return getLhs();
1196 
1197  return constFoldBinaryOp<FloatAttr>(
1198  adaptor.getOperands(),
1199  [](const APFloat &a, const APFloat &b) { return a * b; });
1200 }
1201 
1202 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1203  MLIRContext *context) {
1204  patterns.add<MulFOfNegF>(context);
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // DivFOp
1209 //===----------------------------------------------------------------------===//
1210 
1211 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1212  // divf(x, 1) -> x
1213  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1214  return getLhs();
1215 
1216  return constFoldBinaryOp<FloatAttr>(
1217  adaptor.getOperands(),
1218  [](const APFloat &a, const APFloat &b) { return a / b; });
1219 }
1220 
1221 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1222  MLIRContext *context) {
1223  patterns.add<DivFOfNegF>(context);
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // RemFOp
1228 //===----------------------------------------------------------------------===//
1229 
1230 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1231  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1232  [](const APFloat &a, const APFloat &b) {
1233  APFloat result(a);
1234  // APFloat::mod() offers the remainder
1235  // behavior we want, i.e. the result has
1236  // the sign of LHS operand.
1237  (void)result.mod(b);
1238  return result;
1239  });
1240 }
1241 
1242 //===----------------------------------------------------------------------===//
1243 // Utility functions for verifying cast ops
1244 //===----------------------------------------------------------------------===//
1245 
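// Tag type used to pass a pack of candidate types as a function argument.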
1246 template <typename... Types>
1247 using type_list = std::tuple<Types...> *;
1248 
1249 /// Returns a non-null type only if the provided type is one of the allowed
1250 /// types or one of the allowed shaped types of the allowed types. Returns the
1251 /// element type if a valid shaped type is provided.
1252 template <typename... ShapedTypes, typename... ElementTypes>
1253 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1254  type_list<ElementTypes...>) {
1255  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1256  return {};
1257 
1258  auto underlyingType = getElementTypeOrSelf(type);
1259  if (!llvm::isa<ElementTypes...>(underlyingType))
1260  return {};
1261 
1262  return underlyingType;
1263 }
1264 
1265 /// Get allowed underlying types for vectors and tensors.
1266 template <typename... ElementTypes>
1267 static Type getTypeIfLike(Type type) {
1268  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1269  type_list<ElementTypes...>());
1270 }
1271 
1272 /// Get allowed underlying types for vectors, tensors, and memrefs.
1273 template <typename... ElementTypes>
1274 static Type getTypeIfLikeOrMemRef(Type type) {
1275  return getUnderlyingType(type,
1276  type_list<VectorType, TensorType, MemRefType>(),
1277  type_list<ElementTypes...>());
1278 }
1279 
1280 /// Return false if both types are ranked tensor with mismatching encoding.
1281 static bool hasSameEncoding(Type typeA, Type typeB) {
1282  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1283  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1284  if (!rankedTensorA || !rankedTensorB)
1285  return true;
1286  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1287 }
1288 
1289 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1290  if (inputs.size() != 1 || outputs.size() != 1)
1291  return false;
1292  if (!hasSameEncoding(inputs.front(), outputs.front()))
1293  return false;
1294  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1295 }
1296 
1297 //===----------------------------------------------------------------------===//
1298 // Verifiers for integer and floating point extension/truncation ops
1299 //===----------------------------------------------------------------------===//
1300 
1301 // Extend ops can only extend to a wider type.
1302 template <typename ValType, typename Op>
1303 static LogicalResult verifyExtOp(Op op) {
1304  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1305  Type dstType = getElementTypeOrSelf(op.getType());
1306 
1307  if (llvm::cast<ValType>(srcType).getWidth() >=
1308  llvm::cast<ValType>(dstType).getWidth())
1309  return op.emitError("result type ")
1310  << dstType << " must be wider than operand type " << srcType;
1311 
1312  return success();
1313 }
1314 
1315 // Truncate ops can only truncate to a shorter type.
1316 template <typename ValType, typename Op>
1317 static LogicalResult verifyTruncateOp(Op op) {
1318  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1319  Type dstType = getElementTypeOrSelf(op.getType());
1320 
1321  if (llvm::cast<ValType>(srcType).getWidth() <=
1322  llvm::cast<ValType>(dstType).getWidth())
1323  return op.emitError("result type ")
1324  << dstType << " must be shorter than operand type " << srcType;
1325 
1326  return success();
1327 }
1328 
1329 /// Validate a cast that changes the width of a type.
1330 template <template <typename> class WidthComparator, typename... ElementTypes>
1331 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1332  if (!areValidCastInputsAndOutputs(inputs, outputs))
1333  return false;
1334 
1335  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1336  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1337  if (!srcType || !dstType)
1338  return false;
1339 
1340  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1341  srcType.getIntOrFloatBitWidth());
1342 }
1343 
1344 /// Attempts to convert `sourceValue` to an APFloat value with
1345 /// `targetSemantics` and `roundingMode`, without any information loss.
1346 static FailureOr<APFloat> convertFloatValue(
1347  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1348  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1349  bool losesInfo = false;
1350  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1351  if (losesInfo || status != APFloat::opOK)
1352  return failure();
1353 
1354  return sourceValue;
1355 }
1356 
1357 //===----------------------------------------------------------------------===//
1358 // ExtUIOp
1359 //===----------------------------------------------------------------------===//
1360 
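// Fold chained zero-extensions (extui(extui(x)) -> extui(x)) and
// zero-extension of constant operands.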
1361 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1362  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1363  getInMutable().assign(lhs.getIn());
1364  return getResult();
1365  }
1366 
1367  Type resType = getElementTypeOrSelf(getType());
1368  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1369  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1370  adaptor.getOperands(), getType(),
1371  [bitWidth](const APInt &a, bool &castStatus) {
1372  return a.zext(bitWidth);
1373  });
1374 }
1375 
1376 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1377  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1378 }
1379 
1380 LogicalResult arith::ExtUIOp::verify() {
1381  return verifyExtOp<IntegerType>(*this);
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // ExtSIOp
1386 //===----------------------------------------------------------------------===//
1387 
1388 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1389  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1390  getInMutable().assign(lhs.getIn());
1391  return getResult();
1392  }
1393 
1394  Type resType = getElementTypeOrSelf(getType());
1395  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1396  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1397  adaptor.getOperands(), getType(),
1398  [bitWidth](const APInt &a, bool &castStatus) {
1399  return a.sext(bitWidth);
1400  });
1401 }
1402 
1403 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1404  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1405 }
1406 
1407 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1408  MLIRContext *context) {
1409  patterns.add<ExtSIOfExtUI>(context);
1410 }
1411 
1412 LogicalResult arith::ExtSIOp::verify() {
1413  return verifyExtOp<IntegerType>(*this);
1414 }
1415 
1416 //===----------------------------------------------------------------------===//
1417 // ExtFOp
1418 //===----------------------------------------------------------------------===//
1419 
1420 /// Fold extension of float constants when there is no information loss due to the
1421 /// difference in fp semantics.
1422 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1423  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1424  if (truncFOp.getOperand().getType() == getType()) {
1425  arith::FastMathFlags truncFMF =
1426  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1427  bool isTruncContract =
1428  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1429  arith::FastMathFlags extFMF =
1430  getFastmath().value_or(arith::FastMathFlags::none);
1431  bool isExtContract =
1432  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1433  if (isTruncContract && isExtContract) {
1434  return truncFOp.getOperand();
1435  }
1436  }
1437  }
1438 
1439  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1440  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1441  return constFoldCastOp<FloatAttr, FloatAttr>(
1442  adaptor.getOperands(), getType(),
1443  [&targetSemantics](const APFloat &a, bool &castStatus) {
1444  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1445  if (failed(result)) {
1446  castStatus = false;
1447  return a;
1448  }
1449  return *result;
1450  });
1451 }
1452 
1453 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1454  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1455 }
1456 
1457 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1458 
1459 //===----------------------------------------------------------------------===//
1460 // TruncIOp
1461 //===----------------------------------------------------------------------===//
1462 
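// Fold truncation of extension/truncation chains and truncation of constant
// operands.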
1463 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1464  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1465  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1466  Value src = getOperand().getDefiningOp()->getOperand(0);
1467  Type srcType = getElementTypeOrSelf(src.getType());
1468  Type dstType = getElementTypeOrSelf(getType());
1469  // trunci(zexti(a)) -> trunci(a)
1470  // trunci(sexti(a)) -> trunci(a)
1471  if (llvm::cast<IntegerType>(srcType).getWidth() >
1472  llvm::cast<IntegerType>(dstType).getWidth()) {
1473  setOperand(src);
1474  return getResult();
1475  }
1476 
1477  // trunci(zexti(a)) -> a
1478  // trunci(sexti(a)) -> a
1479  if (srcType == dstType)
1480  return src;
1481  }
1482 
1483  // trunci(trunci(a)) -> trunci(a)
1484  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1485  setOperand(getOperand().getDefiningOp()->getOperand(0));
1486  return getResult();
1487  }
1488 
1489  Type resType = getElementTypeOrSelf(getType());
1490  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1491  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1492  adaptor.getOperands(), getType(),
1493  [bitWidth](const APInt &a, bool &castStatus) {
1494  return a.trunc(bitWidth);
1495  });
1496 }
1497 
1498 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1499  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1500 }
1501 
1502 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1503  MLIRContext *context) {
1504  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1505  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1506  context);
1507 }
1508 
1509 LogicalResult arith::TruncIOp::verify() {
1510  return verifyTruncateOp<IntegerType>(*this);
1511 }
1512 
1513 //===----------------------------------------------------------------------===//
1514 // TruncFOp
1515 //===----------------------------------------------------------------------===//
1516 
1517 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1518 /// can be represented without precision loss.
1519 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1520  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1521  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1522  Value src = extOp.getIn();
1523  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1524  auto intermediateType =
1525  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1526  // Check if the srcType is representable in the intermediateType.
1527  if (llvm::APFloatBase::isRepresentableBy(
1528  srcType.getFloatSemantics(),
1529  intermediateType.getFloatSemantics())) {
1530  // truncf(extf(a)) -> truncf(a)
1531  if (srcType.getWidth() > resElemType.getWidth()) {
1532  setOperand(src);
1533  return getResult();
1534  }
1535 
1536  // truncf(extf(a)) -> a
1537  if (srcType == resElemType)
1538  return src;
1539  }
1540  }
1541 
1542  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1543  return constFoldCastOp<FloatAttr, FloatAttr>(
1544  adaptor.getOperands(), getType(),
1545  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1546  RoundingMode roundingMode =
1547  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1548  llvm::RoundingMode llvmRoundingMode =
1549  convertArithRoundingModeToLLVMIR(roundingMode);
1550  FailureOr<APFloat> result =
1551  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1552  if (failed(result)) {
1553  castStatus = false;
1554  return a;
1555  }
1556  return *result;
1557  });
1558 }
1559 
1560 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1561  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1562 }
1563 
1564 LogicalResult arith::TruncFOp::verify() {
1565  return verifyTruncateOp<FloatType>(*this);
1566 }
1567 
1568 //===----------------------------------------------------------------------===//
1569 // AndIOp
1570 //===----------------------------------------------------------------------===//
1571 
1572 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1573  MLIRContext *context) {
1574  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1575 }
1576 
1577 //===----------------------------------------------------------------------===//
1578 // OrIOp
1579 //===----------------------------------------------------------------------===//
1580 
1581 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1582  MLIRContext *context) {
1583  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1584 }
1585 
1586 //===----------------------------------------------------------------------===//
1587 // Verifiers for casts between integers and floats.
1588 //===----------------------------------------------------------------------===//
1589 
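// Returns true if this is a valid cast from a `From`-like type to a `To`-like
// type (scalar, vector, or tensor) with compatible shapes.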
1590 template <typename From, typename To>
1591 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1592  if (!areValidCastInputsAndOutputs(inputs, outputs))
1593  return false;
1594 
1595  auto srcType = getTypeIfLike<From>(inputs.front());
1596  auto dstType = getTypeIfLike<To>(outputs.back());
1597 
1598  return srcType && dstType;
1599 }
1600 
1601 //===----------------------------------------------------------------------===//
1602 // UIToFPOp
1603 //===----------------------------------------------------------------------===//
1604 
1605 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1606  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1607 }
1608 
1609 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1610  Type resEleType = getElementTypeOrSelf(getType());
1611  return constFoldCastOp<IntegerAttr, FloatAttr>(
1612  adaptor.getOperands(), getType(),
1613  [&resEleType](const APInt &a, bool &castStatus) {
1614  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1615  APFloat apf(floatTy.getFloatSemantics(),
1616  APInt::getZero(floatTy.getWidth()));
1617  apf.convertFromAPInt(a, /*IsSigned=*/false,
1618  APFloat::rmNearestTiesToEven);
1619  return apf;
1620  });
1621 }
1622 
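// Worked example for the fold above, assuming an i8 operand: a constant whose
// bits are 0xC8 is converted as the unsigned value 200, so the op folds to an
// f32 constant 2.0e+02 (the conversion uses round-to-nearest-even).
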
1623 //===----------------------------------------------------------------------===//
1624 // SIToFPOp
1625 //===----------------------------------------------------------------------===//
1626 
1627 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1628  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1629 }
1630 
1631 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1632  Type resEleType = getElementTypeOrSelf(getType());
1633  return constFoldCastOp<IntegerAttr, FloatAttr>(
1634  adaptor.getOperands(), getType(),
1635  [&resEleType](const APInt &a, bool &castStatus) {
1636  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1637  APFloat apf(floatTy.getFloatSemantics(),
1638  APInt::getZero(floatTy.getWidth()));
1639  apf.convertFromAPInt(a, /*IsSigned=*/true,
1640  APFloat::rmNearestTiesToEven);
1641  return apf;
1642  });
1643 }
1644 
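// In contrast to uitofp above, the conversion here is signed: the same i8 bit
// pattern 0xC8 is read as -56 and folds to an f32 constant -5.6e+01.
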
1645 //===----------------------------------------------------------------------===//
1646 // FPToUIOp
1647 //===----------------------------------------------------------------------===//
1648 
1649 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1650  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1651 }
1652 
1653 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1654  Type resType = getElementTypeOrSelf(getType());
1655  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1656  return constFoldCastOp<FloatAttr, IntegerAttr>(
1657  adaptor.getOperands(), getType(),
1658  [&bitWidth](const APFloat &a, bool &castStatus) {
1659  bool ignored;
1660  APSInt api(bitWidth, /*isUnsigned=*/true);
1661  castStatus = APFloat::opInvalidOp !=
1662  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1663  return api;
1664  });
1665 }
1666 
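// Sketch of the fold above, assuming an f32 operand and an i8 result: 42.7
// folds to 42 (conversion rounds toward zero), while -1.0 or 300.0 raise
// APFloat::opInvalidOp, so castStatus is set to false and no constant is
// produced for them.
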
1667 //===----------------------------------------------------------------------===//
1668 // FPToSIOp
1669 //===----------------------------------------------------------------------===//
1670 
1671 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1672  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1673 }
1674 
1675 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1676  Type resType = getElementTypeOrSelf(getType());
1677  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1678  return constFoldCastOp<FloatAttr, IntegerAttr>(
1679  adaptor.getOperands(), getType(),
1680  [&bitWidth](const APFloat &a, bool &castStatus) {
1681  bool ignored;
1682  APSInt api(bitWidth, /*isUnsigned=*/false);
1683  castStatus = APFloat::opInvalidOp !=
1684  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1685  return api;
1686  });
1687 }
1688 
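// Same shape as the fptoui fold above but signed, assuming an f32 operand and
// an i8 result: 4.9 folds to 4 and -4.9 folds to -4 (round toward zero), while
// 200.0 does not fit in i8 and is left unfolded.
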
1689 //===----------------------------------------------------------------------===//
1690 // IndexCastOp
1691 //===----------------------------------------------------------------------===//
1692 
1693 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1694  if (!areValidCastInputsAndOutputs(inputs, outputs))
1695  return false;
1696 
1697  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1698  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1699  if (!srcType || !dstType)
1700  return false;
1701 
1702  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1703  (srcType.isSignlessInteger() && dstType.isIndex());
1704 }
1705 
1706 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1707  TypeRange outputs) {
1708  return areIndexCastCompatible(inputs, outputs);
1709 }
1710 
1711 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1712  // index_cast(constant) -> constant
1713  unsigned resultBitwidth = 64; // Default for index integer attributes.
1714  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1715  resultBitwidth = intTy.getWidth();
1716 
1717  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1718  adaptor.getOperands(), getType(),
1719  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1720  return a.sextOrTrunc(resultBitwidth);
1721  });
1722 }
1723 
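// Minimal sketch of the fold above: a constant is sign-extended or truncated
// to the result width, e.g.
//   %c = arith.constant 42 : index
//   %i = arith.index_cast %c : index to i32    ==>  arith.constant 42 : i32
// The 64-bit default matches how index-typed integer attributes are stored.
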
1724 void arith::IndexCastOp::getCanonicalizationPatterns(
1725  RewritePatternSet &patterns, MLIRContext *context) {
1726  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1727 }
1728 
1729 //===----------------------------------------------------------------------===//
1730 // IndexCastUIOp
1731 //===----------------------------------------------------------------------===//
1732 
1733 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1734  TypeRange outputs) {
1735  return areIndexCastCompatible(inputs, outputs);
1736 }
1737 
1738 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1739  // index_castui(constant) -> constant
1740  unsigned resultBitwidth = 64; // Default for index integer attributes.
1741  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1742  resultBitwidth = intTy.getWidth();
1743 
1744  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1745  adaptor.getOperands(), getType(),
1746  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1747  return a.zextOrTrunc(resultBitwidth);
1748  });
1749 }
1750 
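// Same as index_cast above except narrowing/widening uses zextOrTrunc, so an
// i32 constant -1 cast to index folds to 4294967295 rather than to -1.
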
1751 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1752  RewritePatternSet &patterns, MLIRContext *context) {
1753  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // BitcastOp
1758 //===----------------------------------------------------------------------===//
1759 
1760 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1761  if (!areValidCastInputsAndOutputs(inputs, outputs))
1762  return false;
1763 
1764  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1765  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1766  if (!srcType || !dstType)
1767  return false;
1768 
1769  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1770 }
1771 
1772 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1773  auto resType = getType();
1774  auto operand = adaptor.getIn();
1775  if (!operand)
1776  return {};
1777 
1778  /// Bitcast dense elements.
1779  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1780  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1781  /// Other shaped types unhandled.
1782  if (llvm::isa<ShapedType>(resType))
1783  return {};
1784 
1785  /// Bitcast poison.
1786  if (llvm::isa<ub::PoisonAttr>(operand))
1787  return ub::PoisonAttr::get(getContext());
1788 
1789  /// Bitcast integer or float to integer or float.
1790  APInt bits = llvm::isa<FloatAttr>(operand)
1791  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1792  : llvm::cast<IntegerAttr>(operand).getValue();
1793  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1794  "trying to fold on broken IR: operands have incompatible types");
1795 
1796  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1797  return FloatAttr::get(resType,
1798  APFloat(resFloatType.getFloatSemantics(), bits));
1799  return IntegerAttr::get(resType, bits);
1800 }
1801 
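// Worked example for the scalar case above, assuming f32 <-> i32:
//   %f = arith.constant 1.000000e+00 : f32
//   %i = arith.bitcast %f : f32 to i32    ==>  arith.constant 1065353216 : i32
// 1065353216 is the bit pattern 0x3F800000 of 1.0f; the reverse bitcast
// reconstructs exactly 1.0 because no bits change.
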
1802 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1803  MLIRContext *context) {
1804  patterns.add<BitcastOfBitcast>(context);
1805 }
1806 
1807 //===----------------------------------------------------------------------===//
1808 // CmpIOp
1809 //===----------------------------------------------------------------------===//
1810 
1811 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1812 /// comparison predicates.
1813 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1814  const APInt &lhs, const APInt &rhs) {
1815  switch (predicate) {
1816  case arith::CmpIPredicate::eq:
1817  return lhs.eq(rhs);
1818  case arith::CmpIPredicate::ne:
1819  return lhs.ne(rhs);
1820  case arith::CmpIPredicate::slt:
1821  return lhs.slt(rhs);
1822  case arith::CmpIPredicate::sle:
1823  return lhs.sle(rhs);
1824  case arith::CmpIPredicate::sgt:
1825  return lhs.sgt(rhs);
1826  case arith::CmpIPredicate::sge:
1827  return lhs.sge(rhs);
1828  case arith::CmpIPredicate::ult:
1829  return lhs.ult(rhs);
1830  case arith::CmpIPredicate::ule:
1831  return lhs.ule(rhs);
1832  case arith::CmpIPredicate::ugt:
1833  return lhs.ugt(rhs);
1834  case arith::CmpIPredicate::uge:
1835  return lhs.uge(rhs);
1836  }
1837  llvm_unreachable("unknown cmpi predicate kind");
1838 }
1839 
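// Usage sketch (hypothetical caller, not part of this file):
//   applyCmpPredicate(CmpIPredicate::slt, APInt(8, 0xFF), APInt(8, 1)) // true
//   applyCmpPredicate(CmpIPredicate::ult, APInt(8, 0xFF), APInt(8, 1)) // false
// The predicate alone decides whether the bits are read as signed (-1) or
// unsigned (255); APInt itself carries no signedness.
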
1840 /// Returns true if the predicate is true for two equal operands.
1841 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1842  switch (predicate) {
1843  case arith::CmpIPredicate::eq:
1844  case arith::CmpIPredicate::sle:
1845  case arith::CmpIPredicate::sge:
1846  case arith::CmpIPredicate::ule:
1847  case arith::CmpIPredicate::uge:
1848  return true;
1849  case arith::CmpIPredicate::ne:
1850  case arith::CmpIPredicate::slt:
1851  case arith::CmpIPredicate::sgt:
1852  case arith::CmpIPredicate::ult:
1853  case arith::CmpIPredicate::ugt:
1854  return false;
1855  }
1856  llvm_unreachable("unknown cmpi predicate kind");
1857 }
1858 
1859 static std::optional<int64_t> getIntegerWidth(Type t) {
1860  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1861  return intType.getWidth();
1862  }
1863  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1864  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1865  }
1866  return std::nullopt;
1867 }
1868 
1869 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1870  // cmpi(pred, x, x)
1871  if (getLhs() == getRhs()) {
1872  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1873  return getBoolAttribute(getType(), val);
1874  }
1875 
1876  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1877  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1878  // extsi(%x : i1 -> iN) != 0 -> %x
1879  std::optional<int64_t> integerWidth =
1880  getIntegerWidth(extOp.getOperand().getType());
1881  if (integerWidth && integerWidth.value() == 1 &&
1882  getPredicate() == arith::CmpIPredicate::ne)
1883  return extOp.getOperand();
1884  }
1885  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1886  // extui(%x : i1 -> iN) != 0 -> %x
1887  std::optional<int64_t> integerWidth =
1888  getIntegerWidth(extOp.getOperand().getType());
1889  if (integerWidth && integerWidth.value() == 1 &&
1890  getPredicate() == arith::CmpIPredicate::ne)
1891  return extOp.getOperand();
1892  }
1893 
1894  // arith.cmpi ne, %val, %zero : i1 -> %val
1895  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1896  getPredicate() == arith::CmpIPredicate::ne)
1897  return getLhs();
1898  }
1899 
1900  if (matchPattern(adaptor.getRhs(), m_One())) {
1901  // arith.cmpi eq, %val, %one : i1 -> %val
1902  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1903  getPredicate() == arith::CmpIPredicate::eq)
1904  return getLhs();
1905  }
1906 
1907  // Move constant to the right side.
1908  if (adaptor.getLhs() && !adaptor.getRhs()) {
1909  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1910  using Pred = CmpIPredicate;
1911  const std::pair<Pred, Pred> invPreds[] = {
1912  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1913  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1914  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1915  {Pred::ne, Pred::ne},
1916  };
1917  Pred origPred = getPredicate();
1918  for (auto pred : invPreds) {
1919  if (origPred == pred.first) {
1920  setPredicate(pred.second);
1921  Value lhs = getLhs();
1922  Value rhs = getRhs();
1923  getLhsMutable().assign(rhs);
1924  getRhsMutable().assign(lhs);
1925  return getResult();
1926  }
1927  }
1928  llvm_unreachable("unknown cmpi predicate kind");
1929  }
1930 
1931  // We are moving constants to the right side; so if the lhs is a constant,
1932  // the rhs is guaranteed to be a constant as well.
1933  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1934  return constFoldBinaryOp<IntegerAttr>(
1935  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1936  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1937  return APInt(1,
1938  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1939  });
1940  }
1941 
1942  return {};
1943 }
1944 
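// Two representative rewrites performed by the fold above (sketch):
//   arith.cmpi sle, %x, %x     ==>  true constant
//   arith.cmpi slt, %c5, %x    ==>  arith.cmpi sgt, %x, %c5
// The second case only swaps the operands and mirrors the predicate so that
// later folds can assume the constant sits on the right-hand side.
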
1945 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1946  MLIRContext *context) {
1947  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1948 }
1949 
1950 //===----------------------------------------------------------------------===//
1951 // CmpFOp
1952 //===----------------------------------------------------------------------===//
1953 
1954 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1955 /// comparison predicates.
1956 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1957  const APFloat &lhs, const APFloat &rhs) {
1958  auto cmpResult = lhs.compare(rhs);
1959  switch (predicate) {
1960  case arith::CmpFPredicate::AlwaysFalse:
1961  return false;
1962  case arith::CmpFPredicate::OEQ:
1963  return cmpResult == APFloat::cmpEqual;
1964  case arith::CmpFPredicate::OGT:
1965  return cmpResult == APFloat::cmpGreaterThan;
1966  case arith::CmpFPredicate::OGE:
1967  return cmpResult == APFloat::cmpGreaterThan ||
1968  cmpResult == APFloat::cmpEqual;
1969  case arith::CmpFPredicate::OLT:
1970  return cmpResult == APFloat::cmpLessThan;
1971  case arith::CmpFPredicate::OLE:
1972  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1973  case arith::CmpFPredicate::ONE:
1974  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1975  case arith::CmpFPredicate::ORD:
1976  return cmpResult != APFloat::cmpUnordered;
1977  case arith::CmpFPredicate::UEQ:
1978  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1979  case arith::CmpFPredicate::UGT:
1980  return cmpResult == APFloat::cmpUnordered ||
1981  cmpResult == APFloat::cmpGreaterThan;
1982  case arith::CmpFPredicate::UGE:
1983  return cmpResult == APFloat::cmpUnordered ||
1984  cmpResult == APFloat::cmpGreaterThan ||
1985  cmpResult == APFloat::cmpEqual;
1986  case arith::CmpFPredicate::ULT:
1987  return cmpResult == APFloat::cmpUnordered ||
1988  cmpResult == APFloat::cmpLessThan;
1989  case arith::CmpFPredicate::ULE:
1990  return cmpResult == APFloat::cmpUnordered ||
1991  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1992  case arith::CmpFPredicate::UNE:
1993  return cmpResult != APFloat::cmpEqual;
1994  case arith::CmpFPredicate::UNO:
1995  return cmpResult == APFloat::cmpUnordered;
1996  case arith::CmpFPredicate::AlwaysTrue:
1997  return true;
1998  }
1999  llvm_unreachable("unknown cmpf predicate kind");
2000 }
2001 
2002 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2003  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2004  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2005 
2006  // If one operand is NaN, making them both NaN does not change the result.
2007  if (lhs && lhs.getValue().isNaN())
2008  rhs = lhs;
2009  if (rhs && rhs.getValue().isNaN())
2010  lhs = rhs;
2011 
2012  if (!lhs || !rhs)
2013  return {};
2014 
2015  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2016  return BoolAttr::get(getContext(), val);
2017 }
2018 
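// Sketch of the NaN shortcut above: with one NaN constant operand, an ordered
// predicate such as olt folds to false and an unordered one such as une or
// uno folds to true, even if the other operand is not constant, because
// copying the NaN into both slots cannot change the comparison's outcome.
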
2019 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2020 public:
2021  using OpRewritePattern<CmpFOp>::OpRewritePattern;
2022 
2023  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2024  bool isUnsigned) {
2025  using namespace arith;
2026  switch (pred) {
2027  case CmpFPredicate::UEQ:
2028  case CmpFPredicate::OEQ:
2029  return CmpIPredicate::eq;
2030  case CmpFPredicate::UGT:
2031  case CmpFPredicate::OGT:
2032  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2033  case CmpFPredicate::UGE:
2034  case CmpFPredicate::OGE:
2035  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2036  case CmpFPredicate::ULT:
2037  case CmpFPredicate::OLT:
2038  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2039  case CmpFPredicate::ULE:
2040  case CmpFPredicate::OLE:
2041  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2042  case CmpFPredicate::UNE:
2043  case CmpFPredicate::ONE:
2044  return CmpIPredicate::ne;
2045  default:
2046  llvm_unreachable("Unexpected predicate!");
2047  }
2048  }
2049 
2050  LogicalResult matchAndRewrite(CmpFOp op,
2051  PatternRewriter &rewriter) const override {
2052  FloatAttr flt;
2053  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2054  return failure();
2055 
2056  const APFloat &rhs = flt.getValue();
2057 
2058  // Don't attempt to fold a nan.
2059  if (rhs.isNaN())
2060  return failure();
2061 
2062  // Get the width of the mantissa. We don't want to touch conversions that
2063  // might lose information from the integer, e.g. "i64 -> float".
2064  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2065  int mantissaWidth = floatTy.getFPMantissaWidth();
2066  if (mantissaWidth <= 0)
2067  return failure();
2068 
2069  bool isUnsigned;
2070  Value intVal;
2071 
2072  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2073  isUnsigned = false;
2074  intVal = si.getIn();
2075  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2076  isUnsigned = true;
2077  intVal = ui.getIn();
2078  } else {
2079  return failure();
2080  }
2081 
2082  // Check that the input is converted from an integer type small enough
2083  // that the conversion preserves all of its bits.
2084  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2085  auto intWidth = intTy.getWidth();
2086 
2087  // Number of bits representing values, as opposed to the sign
2088  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2089 
2090  // Following test does NOT adjust intWidth downwards for signed inputs,
2091  // because the most negative value still requires all the mantissa bits
2092  // to distinguish it from one less than that value.
2093  if ((int)intWidth > mantissaWidth) {
2094  // Conversion would lose accuracy. Check if loss can impact comparison.
2095  int exponent = ilogb(rhs);
2096  if (exponent == APFloat::IEK_Inf) {
2097  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2098  if (maxExponent < (int)valueBits) {
2099  // Conversion could create infinity.
2100  return failure();
2101  }
2102  } else {
2103  // Note that if rhs is zero or NaN, then the exponent is negative
2104  // and the first condition is trivially false.
2105  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2106  // Conversion could affect comparison.
2107  return failure();
2108  }
2109  }
2110  }
2111 
2112  // Convert to equivalent cmpi predicate
2113  CmpIPredicate pred;
2114  switch (op.getPredicate()) {
2115  case CmpFPredicate::ORD:
2116  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2117  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2118  /*width=*/1);
2119  return success();
2120  case CmpFPredicate::UNO:
2121  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2122  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2123  /*width=*/1);
2124  return success();
2125  default:
2126  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2127  break;
2128  }
2129 
2130  if (!isUnsigned) {
2131  // If the rhs value is > SignedMax, fold the comparison. This handles
2132  // +INF and large values.
2133  APFloat signedMax(rhs.getSemantics());
2134  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2135  APFloat::rmNearestTiesToEven);
2136  if (signedMax < rhs) { // smax < 13123.0
2137  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2138  pred == CmpIPredicate::sle)
2139  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2140  /*width=*/1);
2141  else
2142  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2143  /*width=*/1);
2144  return success();
2145  }
2146  } else {
2147  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2148  // +INF and large values.
2149  APFloat unsignedMax(rhs.getSemantics());
2150  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2151  APFloat::rmNearestTiesToEven);
2152  if (unsignedMax < rhs) { // umax < 13123.0
2153  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2154  pred == CmpIPredicate::ule)
2155  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2156  /*width=*/1);
2157  else
2158  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2159  /*width=*/1);
2160  return success();
2161  }
2162  }
2163 
2164  if (!isUnsigned) {
2165  // See if the rhs value is < SignedMin.
2166  APFloat signedMin(rhs.getSemantics());
2167  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2168  APFloat::rmNearestTiesToEven);
2169  if (signedMin > rhs) { // smin > 12312.0
2170  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2171  pred == CmpIPredicate::sge)
2172  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2173  /*width=*/1);
2174  else
2175  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2176  /*width=*/1);
2177  return success();
2178  }
2179  } else {
2180  // See if the rhs value is < UnsignedMin.
2181  APFloat unsignedMin(rhs.getSemantics());
2182  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2183  APFloat::rmNearestTiesToEven);
2184  if (unsignedMin > rhs) { // umin > 12312.0
2185  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2186  pred == CmpIPredicate::uge)
2187  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2188  /*width=*/1);
2189  else
2190  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2191  /*width=*/1);
2192  return success();
2193  }
2194  }
2195 
2196  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2197  // [0, UMAX], but it may still be fractional. See if it is fractional by
2198  // casting the FP value to the integer value and back, checking for
2199  // equality. Don't do this for zero, because -0.0 is not fractional.
2200  bool ignored;
2201  APSInt rhsInt(intWidth, isUnsigned);
2202  if (APFloat::opInvalidOp ==
2203  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2204  // Undefined behavior invoked - the destination type can't represent
2205  // the input constant.
2206  return failure();
2207  }
2208 
2209  if (!rhs.isZero()) {
2210  APFloat apf(floatTy.getFloatSemantics(),
2211  APInt::getZero(floatTy.getWidth()));
2212  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2213 
2214  bool equal = apf == rhs;
2215  if (!equal) {
2216  // If we had a comparison against a fractional value, we have to adjust
2217  // the compare predicate and sometimes the value. rhsInt is rounded
2218  // towards zero at this point.
2219  switch (pred) {
2220  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2221  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2222  /*width=*/1);
2223  return success();
2224  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2225  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2226  /*width=*/1);
2227  return success();
2228  case CmpIPredicate::ule:
2229  // (float)int <= 4.4 --> int <= 4
2230  // (float)int <= -4.4 --> false
2231  if (rhs.isNegative()) {
2232  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2233  /*width=*/1);
2234  return success();
2235  }
2236  break;
2237  case CmpIPredicate::sle:
2238  // (float)int <= 4.4 --> int <= 4
2239  // (float)int <= -4.4 --> int < -4
2240  if (rhs.isNegative())
2241  pred = CmpIPredicate::slt;
2242  break;
2243  case CmpIPredicate::ult:
2244  // (float)int < -4.4 --> false
2245  // (float)int < 4.4 --> int <= 4
2246  if (rhs.isNegative()) {
2247  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2248  /*width=*/1);
2249  return success();
2250  }
2251  pred = CmpIPredicate::ule;
2252  break;
2253  case CmpIPredicate::slt:
2254  // (float)int < -4.4 --> int < -4
2255  // (float)int < 4.4 --> int <= 4
2256  if (!rhs.isNegative())
2257  pred = CmpIPredicate::sle;
2258  break;
2259  case CmpIPredicate::ugt:
2260  // (float)int > 4.4 --> int > 4
2261  // (float)int > -4.4 --> true
2262  if (rhs.isNegative()) {
2263  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2264  /*width=*/1);
2265  return success();
2266  }
2267  break;
2268  case CmpIPredicate::sgt:
2269  // (float)int > 4.4 --> int > 4
2270  // (float)int > -4.4 --> int >= -4
2271  if (rhs.isNegative())
2272  pred = CmpIPredicate::sge;
2273  break;
2274  case CmpIPredicate::uge:
2275  // (float)int >= -4.4 --> true
2276  // (float)int >= 4.4 --> int > 4
2277  if (rhs.isNegative()) {
2278  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2279  /*width=*/1);
2280  return success();
2281  }
2282  pred = CmpIPredicate::ugt;
2283  break;
2284  case CmpIPredicate::sge:
2285  // (float)int >= -4.4 --> int >= -4
2286  // (float)int >= 4.4 --> int > 4
2287  if (!rhs.isNegative())
2288  pred = CmpIPredicate::sgt;
2289  break;
2290  }
2291  }
2292  }
2293 
2294  // Lower this FP comparison into an appropriate integer version of the
2295  // comparison.
2296  rewriter.replaceOpWithNewOp<CmpIOp>(
2297  op, pred, intVal,
2298  rewriter.create<ConstantOp>(
2299  op.getLoc(), intVal.getType(),
2300  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2301  return success();
2302  }
2303 };
2304 
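// End-to-end sketch of the pattern above (hypothetical IR, i8 source value):
//   %f = arith.sitofp %x : i8 to f32
//   %b = arith.cmpf olt, %f, %cst     // %cst = 4.4 : f32
// becomes
//   %b = arith.cmpi sle, %x, %c4      // 4.4 rounds toward zero to 4 : i8
// This is valid because every i8 value is exactly representable in f32, so the
// comparison against the fractional bound can be restated on the integer side.
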
2305 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2306  MLIRContext *context) {
2307  patterns.insert<CmpFIntToFPConst>(context);
2308 }
2309 
2310 //===----------------------------------------------------------------------===//
2311 // SelectOp
2312 //===----------------------------------------------------------------------===//
2313 
2314 // select %arg, %c1, %c0 => extui %arg
2315 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2316  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2317 
2318  LogicalResult matchAndRewrite(arith::SelectOp op,
2319  PatternRewriter &rewriter) const override {
2320  // Cannot extui i1 to i1, or i1 to f32
2321  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2322  return failure();
2323 
2324  // select %x, %c1, %c0 => extui %x
2325  if (matchPattern(op.getTrueValue(), m_One()) &&
2326  matchPattern(op.getFalseValue(), m_Zero())) {
2327  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2328  op.getCondition());
2329  return success();
2330  }
2331 
2332  // select %x, %c0, %c1 => extui (xori %x, true)
2333  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2334  matchPattern(op.getFalseValue(), m_One())) {
2335  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2336  op, op.getType(),
2337  rewriter.create<arith::XOrIOp>(
2338  op.getLoc(), op.getCondition(),
2339  rewriter.create<arith::ConstantIntOp>(
2340  op.getLoc(), 1, op.getCondition().getType())));
2341  return success();
2342  }
2343 
2344  return failure();
2345  }
2346 };
2347 
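// Sketch of the two rewrites above, assuming an i1 condition and i32 result:
//   %r = arith.select %c, %c1_i32, %c0_i32   ==>  %r = arith.extui %c : i1 to i32
//   %r = arith.select %c, %c0_i32, %c1_i32   ==>  extui of (xori %c, true)
// Only the exact one/zero constant pair triggers; other values are left alone.
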
2348 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2349  MLIRContext *context) {
2350  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2351  SelectI1ToNot, SelectToExtUI>(context);
2352 }
2353 
2354 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2355  Value trueVal = getTrueValue();
2356  Value falseVal = getFalseValue();
2357  if (trueVal == falseVal)
2358  return trueVal;
2359 
2360  Value condition = getCondition();
2361 
2362  // select true, %0, %1 => %0
2363  if (matchPattern(adaptor.getCondition(), m_One()))
2364  return trueVal;
2365 
2366  // select false, %0, %1 => %1
2367  if (matchPattern(adaptor.getCondition(), m_Zero()))
2368  return falseVal;
2369 
2370  // If either operand is fully poisoned, return the other.
2371  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2372  return falseVal;
2373 
2374  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2375  return trueVal;
2376 
2377  // select %x, true, false => %x
2378  if (getType().isSignlessInteger(1) &&
2379  matchPattern(adaptor.getTrueValue(), m_One()) &&
2380  matchPattern(adaptor.getFalseValue(), m_Zero()))
2381  return condition;
2382 
2383  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2384  auto pred = cmp.getPredicate();
2385  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2386  auto cmpLhs = cmp.getLhs();
2387  auto cmpRhs = cmp.getRhs();
2388 
2389  // %0 = arith.cmpi eq, %arg0, %arg1
2390  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2391 
2392  // %0 = arith.cmpi ne, %arg0, %arg1
2393  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2394 
2395  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2396  (cmpRhs == trueVal && cmpLhs == falseVal))
2397  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2398  }
2399  }
2400 
2401  // Constant-fold constant operands over non-splat constant condition.
2402  // select %cst_vec, %cst0, %cst1 => %cst2
2403  if (auto cond =
2404  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2405  if (auto lhs =
2406  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2407  if (auto rhs =
2408  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2409  SmallVector<Attribute> results;
2410  results.reserve(static_cast<size_t>(cond.getNumElements()));
2411  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2412  cond.value_end<BoolAttr>());
2413  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2414  lhs.value_end<Attribute>());
2415  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2416  rhs.value_end<Attribute>());
2417 
2418  for (auto [condVal, lhsVal, rhsVal] :
2419  llvm::zip_equal(condVals, lhsVals, rhsVals))
2420  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2421 
2422  return DenseElementsAttr::get(lhs.getType(), results);
2423  }
2424  }
2425  }
2426 
2427  return nullptr;
2428 }
2429 
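// A few representative cases handled by the fold above (sketch):
//   arith.select %c, %x, %x                      ==>  %x
//   arith.select %true, %x, %y                   ==>  %x
//   arith.select %c, %poison, %y                 ==>  %y   (%poison = ub.poison)
//   arith.select (arith.cmpi eq, %a, %b), %a, %b ==>  %b
// The last case is sound because whenever the condition is true the two
// operands are equal, so %b is the correct result on both branches.
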
2430 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2431  Type conditionType, resultType;
2432  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2433  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2434  parser.parseOptionalAttrDict(result.attributes) ||
2435  parser.parseColonType(resultType))
2436  return failure();
2437 
2438  // Check for the explicit condition type if this is a masked tensor or vector.
2439  if (succeeded(parser.parseOptionalComma())) {
2440  conditionType = resultType;
2441  if (parser.parseType(resultType))
2442  return failure();
2443  } else {
2444  conditionType = parser.getBuilder().getI1Type();
2445  }
2446 
2447  result.addTypes(resultType);
2448  return parser.resolveOperands(operands,
2449  {conditionType, resultType, resultType},
2450  parser.getNameLoc(), result.operands);
2451 }
2452 
2453 void arith::SelectOp::print(OpAsmPrinter &p) {
2454  p << " " << getOperands();
2455  p.printOptionalAttrDict((*this)->getAttrs());
2456  p << " : ";
2457  if (ShapedType condType =
2458  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2459  p << condType << ", ";
2460  p << getType();
2461 }
2462 
2463 LogicalResult arith::SelectOp::verify() {
2464  Type conditionType = getCondition().getType();
2465  if (conditionType.isSignlessInteger(1))
2466  return success();
2467 
2468  // If the result type is a vector or tensor, the condition can be an i1
2469  // mask with the same shape.
2470  Type resultType = getType();
2471  if (!llvm::isa<TensorType, VectorType>(resultType))
2472  return emitOpError() << "expected condition to be a signless i1, but got "
2473  << conditionType;
2474  Type shapedConditionType = getI1SameShape(resultType);
2475  if (conditionType != shapedConditionType) {
2476  return emitOpError() << "expected condition type to have the same shape "
2477  "as the result type, expected "
2478  << shapedConditionType << ", but got "
2479  << conditionType;
2480  }
2481  return success();
2482 }
2483 //===----------------------------------------------------------------------===//
2484 // ShLIOp
2485 //===----------------------------------------------------------------------===//
2486 
2487 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2488  // shli(x, 0) -> x
2489  if (matchPattern(adaptor.getRhs(), m_Zero()))
2490  return getLhs();
2491  // Don't fold if shifting by the bit width or more.
2492  bool bounded = false;
2493  auto result = constFoldBinaryOp<IntegerAttr>(
2494  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2495  bounded = b.ult(b.getBitWidth());
2496  return a.shl(b);
2497  });
2498  return bounded ? result : Attribute();
2499 }
2500 
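// Worked example of the guarded fold above, assuming i8 operands:
//   shli(19, 3)  folds to 152 (0b0001'0011 << 3 == 0b1001'1000),
//   shli(%x, 0)  folds to %x,
//   shli(%x, 8)  is not folded, since the shift amount equals the bit width.
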
2501 //===----------------------------------------------------------------------===//
2502 // ShRUIOp
2503 //===----------------------------------------------------------------------===//
2504 
2505 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2506  // shrui(x, 0) -> x
2507  if (matchPattern(adaptor.getRhs(), m_Zero()))
2508  return getLhs();
2509  // Don't fold if shifting by the bit width or more.
2510  bool bounded = false;
2511  auto result = constFoldBinaryOp<IntegerAttr>(
2512  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2513  bounded = b.ult(b.getBitWidth());
2514  return a.lshr(b);
2515  });
2516  return bounded ? result : Attribute();
2517 }
2518 
2519 //===----------------------------------------------------------------------===//
2520 // ShRSIOp
2521 //===----------------------------------------------------------------------===//
2522 
2523 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2524  // shrsi(x, 0) -> x
2525  if (matchPattern(adaptor.getRhs(), m_Zero()))
2526  return getLhs();
2527  // Don't fold if shifting by the bit width or more.
2528  bool bounded = false;
2529  auto result = constFoldBinaryOp<IntegerAttr>(
2530  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2531  bounded = b.ult(b.getBitWidth());
2532  return a.ashr(b);
2533  });
2534  return bounded ? result : Attribute();
2535 }
2536 
2537 //===----------------------------------------------------------------------===//
2538 // Atomic Enum
2539 //===----------------------------------------------------------------------===//
2540 
2541 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2542 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2543  OpBuilder &builder, Location loc,
2544  bool useOnlyFiniteValue) {
2545  switch (kind) {
2546  case AtomicRMWKind::maximumf: {
2547  const llvm::fltSemantics &semantic =
2548  llvm::cast<FloatType>(resultType).getFloatSemantics();
2549  APFloat identity = useOnlyFiniteValue
2550  ? APFloat::getLargest(semantic, /*Negative=*/true)
2551  : APFloat::getInf(semantic, /*Negative=*/true);
2552  return builder.getFloatAttr(resultType, identity);
2553  }
2554  case AtomicRMWKind::maxnumf: {
2555  const llvm::fltSemantics &semantic =
2556  llvm::cast<FloatType>(resultType).getFloatSemantics();
2557  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2558  return builder.getFloatAttr(resultType, identity);
2559  }
2560  case AtomicRMWKind::addf:
2561  case AtomicRMWKind::addi:
2562  case AtomicRMWKind::maxu:
2563  case AtomicRMWKind::ori:
2564  return builder.getZeroAttr(resultType);
2565  case AtomicRMWKind::andi:
2566  return builder.getIntegerAttr(
2567  resultType,
2568  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2569  case AtomicRMWKind::maxs:
2570  return builder.getIntegerAttr(
2571  resultType, APInt::getSignedMinValue(
2572  llvm::cast<IntegerType>(resultType).getWidth()));
2573  case AtomicRMWKind::minimumf: {
2574  const llvm::fltSemantics &semantic =
2575  llvm::cast<FloatType>(resultType).getFloatSemantics();
2576  APFloat identity = useOnlyFiniteValue
2577  ? APFloat::getLargest(semantic, /*Negative=*/false)
2578  : APFloat::getInf(semantic, /*Negative=*/false);
2579 
2580  return builder.getFloatAttr(resultType, identity);
2581  }
2582  case AtomicRMWKind::minnumf: {
2583  const llvm::fltSemantics &semantic =
2584  llvm::cast<FloatType>(resultType).getFloatSemantics();
2585  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2586  return builder.getFloatAttr(resultType, identity);
2587  }
2588  case AtomicRMWKind::mins:
2589  return builder.getIntegerAttr(
2590  resultType, APInt::getSignedMaxValue(
2591  llvm::cast<IntegerType>(resultType).getWidth()));
2592  case AtomicRMWKind::minu:
2593  return builder.getIntegerAttr(
2594  resultType,
2595  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2596  case AtomicRMWKind::muli:
2597  return builder.getIntegerAttr(resultType, 1);
2598  case AtomicRMWKind::mulf:
2599  return builder.getFloatAttr(resultType, 1);
2600  // TODO: Add remaining reduction operations.
2601  default:
2602  (void)emitOptionalError(loc, "Reduction operation type not supported");
2603  break;
2604  }
2605  return nullptr;
2606 }
2607 
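// Usage sketch (hypothetical caller; f32Ty, builder, and loc are assumed):
//   TypedAttr id = getIdentityValueAttr(AtomicRMWKind::maximumf, f32Ty,
//                                       builder, loc,
//                                       /*useOnlyFiniteValue=*/true);
// returns the largest negative finite f32; with the flag false it returns
// negative infinity, the true identity of maximumf.
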
2608 /// Return the identity numeric value associated with the given op.
2609 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2610  std::optional<AtomicRMWKind> maybeKind =
2611  llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2612  // Floating-point operations.
2613  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2614  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2615  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2616  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2617  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2618  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2619  // Integer operations.
2620  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2621  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2622  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2623  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2624  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2625  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2626  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2627  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2628  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2629  .Default([](Operation *op) { return std::nullopt; });
2630  if (!maybeKind) {
2631  return std::nullopt;
2632  }
2633 
2634  bool useOnlyFiniteValue = false;
2635  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2636  if (fmfOpInterface) {
2637  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2638  useOnlyFiniteValue =
2639  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2640  }
2641 
2642  // Builder only used as helper for attribute creation.
2643  OpBuilder b(op->getContext());
2644  Type resultType = op->getResult(0).getType();
2645 
2646  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2647  useOnlyFiniteValue);
2648 }
2649 
2650 /// Returns the identity value associated with an AtomicRMWKind op.
2651 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2652  OpBuilder &builder, Location loc,
2653  bool useOnlyFiniteValue) {
2654  auto attr =
2655  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2656  return builder.create<arith::ConstantOp>(loc, attr);
2657 }
2658 
2659 /// Return the value obtained by applying the reduction operation kind
2660 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2661 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2662  Location loc, Value lhs, Value rhs) {
2663  switch (op) {
2664  case AtomicRMWKind::addf:
2665  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2666  case AtomicRMWKind::addi:
2667  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2668  case AtomicRMWKind::mulf:
2669  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2670  case AtomicRMWKind::muli:
2671  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2672  case AtomicRMWKind::maximumf:
2673  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2674  case AtomicRMWKind::minimumf:
2675  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2676  case AtomicRMWKind::maxnumf:
2677  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2678  case AtomicRMWKind::minnumf:
2679  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2680  case AtomicRMWKind::maxs:
2681  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2682  case AtomicRMWKind::mins:
2683  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2684  case AtomicRMWKind::maxu:
2685  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2686  case AtomicRMWKind::minu:
2687  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2688  case AtomicRMWKind::ori:
2689  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2690  case AtomicRMWKind::andi:
2691  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2692  // TODO: Add remaining reduction operations.
2693  default:
2694  (void)emitOptionalError(loc, "Reduction operation type not supported");
2695  break;
2696  }
2697  return nullptr;
2698 }
2699 
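// Usage sketch (hypothetical caller; builder, loc, lhs, and rhs are assumed):
//   Value acc = getReductionOp(AtomicRMWKind::maxs, builder, loc, lhs, rhs);
// creates an arith.maxsi combining the two values; unsupported kinds emit an
// error and return a null Value, so callers should check the result.
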
2700 //===----------------------------------------------------------------------===//
2701 // TableGen'd op method definitions
2702 //===----------------------------------------------------------------------===//
2703 
2704 #define GET_OP_CLASSES
2705 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2706 
2707 //===----------------------------------------------------------------------===//
2708 // TableGen'd enum attribute definitions
2709 //===----------------------------------------------------------------------===//
2710 
2711 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"