MLIR  21.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
// Shared helper for addIntegerAttrs/subIntegerAttrs/mulIntegerAttrs: unwraps
// `lhs` and `rhs` as IntegerAttr (llvm::cast, so any other attribute kind
// trips an assertion), combines their APInt values with `binFn`, and wraps the
// result in an IntegerAttr whose type is `res.getType()`.
// NOTE(review): the line carrying the function name and its leading
// parameters is missing from this view; callers invoke it as
// applyToIntegerAttrs(builder, res, lhs, rhs, binFn).
42 static IntegerAttr
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // Integer values must be signless.
211  if (llvm::isa<IntegerType>(type) &&
212  !llvm::cast<IntegerType>(type).isSignless())
213  return emitOpError("integer return type must be signless");
214  // Any float or elements attribute are acceptable.
215  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216  return emitOpError(
217  "value must be an integer, float, or elements attribute");
218  }
219 
220  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
221  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
222  // However, this would most likely require updating the lowerings to LLVM.
223  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224  return emitOpError(
225  "intializing scalable vectors with elements attribute is not supported"
226  " unless it's a vector splat");
227  return success();
228 }
229 
230 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
231  // The value's type must be the same as the provided type.
232  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233  if (!typedAttr || typedAttr.getType() != type)
234  return false;
235  // Integer values must be signless.
236  if (llvm::isa<IntegerType>(type) &&
237  !llvm::cast<IntegerType>(type).isSignless())
238  return false;
239  // Integer, float, and element attributes are buildable.
240  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
241 }
242 
243 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
244  Type type, Location loc) {
245  if (isBuildableWith(value, type))
246  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
247  return nullptr;
248 }
249 
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
251 
// Fragment of a builder whose header line is missing from this view: it
// creates an IntegerType of `width` bits via the builder and delegates to
// arith::ConstantOp::build with an IntegerAttr of that type holding `value`.
253  int64_t value, unsigned width) {
254  auto type = builder.getIntegerType(width);
255  arith::ConstantOp::build(builder, result, type,
256  builder.getIntegerAttr(type, value));
257 }
258 
// Fragment of a builder whose header line is missing from this view: asserts
// that `type` is a signless integer type, then delegates to
// arith::ConstantOp::build with an IntegerAttr of `type` holding `value`.
260  int64_t value, Type type) {
261  assert(type.isSignlessInteger() &&
262  "ConstantIntOp can only have signless integer type values");
263  arith::ConstantOp::build(builder, result, type,
264  builder.getIntegerAttr(type, value));
265 }
266 
// Fragment of a type predicate whose header line is missing from this view:
// accepts `op` only if it is a non-null arith::ConstantOp whose result type is
// a signless integer.
268  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
269  return constOp.getType().isSignlessInteger();
270  return false;
271 }
272 
// Fragment of a builder whose header line is missing from this view:
// delegates to arith::ConstantOp::build with a FloatAttr of `type` holding
// `value`.
274  const APFloat &value, FloatType type) {
275  arith::ConstantOp::build(builder, result, type,
276  builder.getFloatAttr(type, value));
277 }
278 
// Fragment of a type predicate whose header line is missing from this view:
// accepts `op` only if it is a non-null arith::ConstantOp whose result type is
// a FloatType.
280  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
281  return llvm::isa<FloatType>(constOp.getType());
282  return false;
283 }
284 
// Fragment of a builder whose header line is missing from this view:
// delegates to arith::ConstantOp::build with the index type and an index
// attribute holding `value`.
286  int64_t value) {
287  arith::ConstantOp::build(builder, result, builder.getIndexType(),
288  builder.getIndexAttr(value));
289 }
290 
// Fragment of a type predicate whose header line is missing from this view:
// accepts `op` only if it is a non-null arith::ConstantOp of index type.
292  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
293  return constOp.getType().isIndex();
294  return false;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // AddIOp
299 //===----------------------------------------------------------------------===//
300 
301 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
302  // addi(x, 0) -> x
303  if (matchPattern(adaptor.getRhs(), m_Zero()))
304  return getLhs();
305 
306  // addi(subi(a, b), b) -> a
307  if (auto sub = getLhs().getDefiningOp<SubIOp>())
308  if (getRhs() == sub.getRhs())
309  return sub.getLhs();
310 
311  // addi(b, subi(a, b)) -> a
312  if (auto sub = getRhs().getDefiningOp<SubIOp>())
313  if (getLhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  return constFoldBinaryOp<IntegerAttr>(
317  adaptor.getOperands(),
318  [](APInt a, const APInt &b) { return std::move(a) + b; });
319 }
320 
321 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
322  MLIRContext *context) {
323  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
324  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // AddUIExtendedOp
329 //===----------------------------------------------------------------------===//
330 
331 std::optional<SmallVector<int64_t, 4>>
332 arith::AddUIExtendedOp::getShapeForUnroll() {
333  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
334  return llvm::to_vector<4>(vt.getShape());
335  return std::nullopt;
336 }
337 
338 // Returns the overflow bit, assuming that `sum` is the result of unsigned
339 // addition of `operand` and another number.
340 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
341  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
342 }
343 
// Folds addui_extended. On success, two results are pushed in order: the
// (wrapping) sum, then the overflow/carry flag.
// NOTE(review): the line carrying the `results` parameter (orig. 346) and the
// last argument of the overflow fold (orig. 368) are missing from this view.
344 LogicalResult
345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
347  Type overflowTy = getOverflow().getType();
348  // addui_extended(x, 0) -> x, false
// Matches on the rhs SSA value (getRhs()), not on adaptor.getRhs().
349  if (matchPattern(getRhs(), m_Zero())) {
350  Builder builder(getContext());
351  auto falseValue = builder.getZeroAttr(overflowTy);
352 
353  results.push_back(getLhs());
354  results.push_back(falseValue);
355  return success();
356  }
357 
358  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
359  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
360  // operands. If that succeeds, calculate the overflow bit based on the sum
361  // and the first (constant) operand, `lhs`.
362  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
363  adaptor.getOperands(),
364  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
// The overflow result gets the i1 equivalent of the sum's shape; its callback
// is on the missing line (presumably calculateUnsignedOverflow -- confirm).
365  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
366  ArrayRef({sumAttr, adaptor.getLhs()}),
367  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
369  if (!overflowAttr)
370  return failure();
371 
372  results.push_back(sumAttr);
373  results.push_back(overflowAttr);
374  return success();
375  }
376 
377  return failure();
378 }
379 
// Registers the AddUIExtendedToAddI canonicalization pattern (declared outside
// this view). NOTE(review): the parameter line (orig. 381) is missing here.
380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
382  patterns.add<AddUIExtendedToAddI>(context);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // SubIOp
387 //===----------------------------------------------------------------------===//
388 
389 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
390  // subi(x,x) -> 0
391  if (getOperand(0) == getOperand(1)) {
392  auto shapedType = dyn_cast<ShapedType>(getType());
393  // We can't generate a constant with a dynamic shaped tensor.
394  if (!shapedType || shapedType.hasStaticShape())
395  return Builder(getContext()).getZeroAttr(getType());
396  }
397  // subi(x,0) -> x
398  if (matchPattern(adaptor.getRhs(), m_Zero()))
399  return getLhs();
400 
401  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
402  // subi(addi(a, b), b) -> a
403  if (getRhs() == add.getRhs())
404  return add.getLhs();
405  // subi(addi(a, b), a) -> b
406  if (getRhs() == add.getLhs())
407  return add.getRhs();
408  }
409 
410  return constFoldBinaryOp<IntegerAttr>(
411  adaptor.getOperands(),
412  [](APInt a, const APInt &b) { return std::move(a) - b; });
413 }
414 
415 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
416  MLIRContext *context) {
417  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
418  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
419  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // MulIOp
424 //===----------------------------------------------------------------------===//
425 
426 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
427  // muli(x, 0) -> 0
428  if (matchPattern(adaptor.getRhs(), m_Zero()))
429  return getRhs();
430  // muli(x, 1) -> x
431  if (matchPattern(adaptor.getRhs(), m_One()))
432  return getLhs();
433  // TODO: Handle the overflow case.
434 
435  // default folder
436  return constFoldBinaryOp<IntegerAttr>(
437  adaptor.getOperands(),
438  [](const APInt &a, const APInt &b) { return a * b; });
439 }
440 
441 void arith::MulIOp::getAsmResultNames(
442  function_ref<void(Value, StringRef)> setNameFn) {
443  if (!isa<IndexType>(getType()))
444  return;
445 
446  // Match vector.vscale by name to avoid depending on the vector dialect (which
447  // is a circular dependency).
448  auto isVscale = [](Operation *op) {
449  return op && op->getName().getStringRef() == "vector.vscale";
450  };
451 
452  IntegerAttr baseValue;
453  auto isVscaleExpr = [&](Value a, Value b) {
454  return matchPattern(a, m_Constant(&baseValue)) &&
455  isVscale(b.getDefiningOp());
456  };
457 
458  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
459  return;
460 
461  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
462  SmallString<32> specialNameBuffer;
463  llvm::raw_svector_ostream specialName(specialNameBuffer);
464  specialName << 'c' << baseValue.getInt() << "_vscale";
465  setNameFn(getResult(), specialName.str());
466 }
467 
468 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
469  MLIRContext *context) {
470  patterns.add<MulIMulIConstant>(context);
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // MulSIExtendedOp
475 //===----------------------------------------------------------------------===//
476 
477 std::optional<SmallVector<int64_t, 4>>
478 arith::MulSIExtendedOp::getShapeForUnroll() {
479  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
480  return llvm::to_vector<4>(vt.getShape());
481  return std::nullopt;
482 }
483 
// Folds mulsi_extended. On success, two results are pushed: the low half of
// the product, then the signed high half (llvm::APIntOps::mulhs). Unlike
// mului_extended there is no x * 1 fold here: the signed high half of x * 1
// is not zero in general.
// NOTE(review): the line carrying the `results` parameter (orig. 486) is
// missing from this view.
484 LogicalResult
485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
487  // mulsi_extended(x, 0) -> 0, 0
488  if (matchPattern(adaptor.getRhs(), m_Zero())) {
// The zero rhs attribute is reused for both the low and high results.
489  Attribute zero = adaptor.getRhs();
490  results.push_back(zero);
491  results.push_back(zero);
492  return success();
493  }
494 
495  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
496  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
497  adaptor.getOperands(),
498  [](const APInt &a, const APInt &b) { return a * b; })) {
499  // Invoke the constant fold helper again to calculate the 'high' result.
// Same operands as the successful low fold, so this is expected to succeed.
500  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
501  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
502  return llvm::APIntOps::mulhs(a, b);
503  });
504  assert(highAttr && "Unexpected constant-folding failure");
505 
506  results.push_back(lowAttr);
507  results.push_back(highAttr);
508  return success();
509  }
510 
511  return failure();
512 }
513 
// Registers MulSIExtendedToMulI and MulSIExtendedRHSOne (declared outside this
// view). NOTE(review): the parameter line (orig. 515) is missing here.
514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // MulUIExtendedOp
521 //===----------------------------------------------------------------------===//
522 
523 std::optional<SmallVector<int64_t, 4>>
524 arith::MulUIExtendedOp::getShapeForUnroll() {
525  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
526  return llvm::to_vector<4>(vt.getShape());
527  return std::nullopt;
528 }
529 
// Folds mului_extended. On success, two results are pushed: the low half of
// the product, then the unsigned high half (llvm::APIntOps::mulhu).
// NOTE(review): the line carrying the `results` parameter (orig. 532) is
// missing from this view.
530 LogicalResult
531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
533  // mului_extended(x, 0) -> 0, 0
534  if (matchPattern(adaptor.getRhs(), m_Zero())) {
// The zero rhs attribute is reused for both the low and high results.
535  Attribute zero = adaptor.getRhs();
536  results.push_back(zero);
537  results.push_back(zero);
538  return success();
539  }
540 
541  // mului_extended(x, 1) -> x, 0
// Unsigned x * 1 never reaches the high half, so that half is all zeros.
542  if (matchPattern(adaptor.getRhs(), m_One())) {
543  Builder builder(getContext());
544  Attribute zero = builder.getZeroAttr(getLhs().getType());
545  results.push_back(getLhs());
546  results.push_back(zero);
547  return success();
548  }
549 
550  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
551  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
552  adaptor.getOperands(),
553  [](const APInt &a, const APInt &b) { return a * b; })) {
554  // Invoke the constant fold helper again to calculate the 'high' result.
// Same operands as the successful low fold, so this is expected to succeed.
555  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
556  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
557  return llvm::APIntOps::mulhu(a, b);
558  });
559  assert(highAttr && "Unexpected constant-folding failure");
560 
561  results.push_back(lowAttr);
562  results.push_back(highAttr);
563  return success();
564  }
565 
566  return failure();
567 }
568 
// Registers MulUIExtendedToMulI (declared outside this view).
// NOTE(review): the parameter line (orig. 570) is missing here.
569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571  patterns.add<MulUIExtendedToMulI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // DivUIOp
576 //===----------------------------------------------------------------------===//
577 
578 /// Fold `(a * b) / b -> a`
579 static Value foldDivMul(Value lhs, Value rhs,
580  arith::IntegerOverflowFlags ovfFlags) {
581  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
582  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
583  return {};
584 
585  if (mul.getLhs() == rhs)
586  return mul.getRhs();
587 
588  if (mul.getRhs() == rhs)
589  return mul.getLhs();
590 
591  return {};
592 }
593 
594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
595  // divui (x, 1) -> x.
596  if (matchPattern(adaptor.getRhs(), m_One()))
597  return getLhs();
598 
599  // (a * b) / b -> a
600  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
601  return val;
602 
603  // Don't fold if it would require a division by zero.
604  bool div0 = false;
605  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
606  [&](APInt a, const APInt &b) {
607  if (div0 || !b) {
608  div0 = true;
609  return a;
610  }
611  return a.udiv(b);
612  });
613 
614  return div0 ? Attribute() : result;
615 }
616 
617 /// Returns whether an unsigned division by `divisor` is speculatable.
/// Division by zero is UB, so the answer hinges on whether `divisor` matches
/// m_IntRangeWithoutZeroU (by its name, an unsigned value range excluding 0).
/// NOTE(review): the signature and both return statements are missing from
/// this view; presumably Speculatable when the match succeeds and
/// NotSpeculatable otherwise -- confirm against the full source.
619  // X / 0 => UB
620  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
622 
624 }
625 
626 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
627  return getDivUISpeculatability(getRhs());
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // DivSIOp
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
635  // divsi (x, 1) -> x.
636  if (matchPattern(adaptor.getRhs(), m_One()))
637  return getLhs();
638 
639  // (a * b) / b -> a
640  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
641  return val;
642 
643  // Don't fold if it would overflow or if it requires a division by zero.
644  bool overflowOrDiv0 = false;
645  auto result = constFoldBinaryOp<IntegerAttr>(
646  adaptor.getOperands(), [&](APInt a, const APInt &b) {
647  if (overflowOrDiv0 || !b) {
648  overflowOrDiv0 = true;
649  return a;
650  }
651  return a.sdiv_ov(b, overflowOrDiv0);
652  });
653 
654  return overflowOrDiv0 ? Attribute() : result;
655 }
656 
657 /// Returns whether a signed division by `divisor` is speculatable. This
658 /// function conservatively assumes that all signed divisions by -1 are not
659 /// speculatable.
/// The visible condition starts with `divisor` matching m_IntRangeWithoutZeroS
/// (by its name, a signed value range excluding 0).
/// NOTE(review): the signature, the rest of the condition (presumably the -1
/// exclusion) and both return statements are missing from this view -- confirm
/// against the full source.
661  // X / 0 => UB
662  // INT_MIN / -1 => UB
663  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
666 
668 }
669 
670 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
671  return getDivSISpeculatability(getRhs());
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // Ceil and floor division folding helpers
676 //===----------------------------------------------------------------------===//
677 
678 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
679  bool &overflow) {
680  // Returns (a-1)/b + 1
681  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
682  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
683  return val.sadd_ov(one, overflow);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // CeilDivUIOp
688 //===----------------------------------------------------------------------===//
689 
690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
691  // ceildivui (x, 1) -> x.
692  if (matchPattern(adaptor.getRhs(), m_One()))
693  return getLhs();
694 
695  bool overflowOrDiv0 = false;
696  auto result = constFoldBinaryOp<IntegerAttr>(
697  adaptor.getOperands(), [&](APInt a, const APInt &b) {
698  if (overflowOrDiv0 || !b) {
699  overflowOrDiv0 = true;
700  return a;
701  }
702  APInt quotient = a.udiv(b);
703  if (!a.urem(b))
704  return quotient;
705  APInt one(a.getBitWidth(), 1, true);
706  return quotient.uadd_ov(one, overflowOrDiv0);
707  });
708 
709  return overflowOrDiv0 ? Attribute() : result;
710 }
711 
712 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
713  return getDivUISpeculatability(getRhs());
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // CeilDivSIOp
718 //===----------------------------------------------------------------------===//
719 
// Folds arith.ceildivsi: ceildivsi(x, 1) -> x, and constant operands fold
// elementwise to the signed ceiling of a / b. The whole fold is abandoned
// (null Attribute) if any element has a zero divisor or if any intermediate
// negation, division or addition overflows.
720 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
721  // ceildivsi (x, 1) -> x.
722  if (matchPattern(adaptor.getRhs(), m_One()))
723  return getLhs();
724 
725  // Don't fold if it would overflow or if it requires a division by zero.
726  // TODO: This hook won't fold operations where a = MININT, because
727  // negating MININT overflows. This can be improved.
728  bool overflowOrDiv0 = false;
729  auto result = constFoldBinaryOp<IntegerAttr>(
730  adaptor.getOperands(), [&](APInt a, const APInt &b) {
// Once an earlier element failed, skip the work; the flag stays set.
731  if (overflowOrDiv0 || !b) {
732  overflowOrDiv0 = true;
733  return a;
734  }
// ceil(0 / b) == 0, so a zero dividend is returned unchanged.
735  if (!a)
736  return a;
737  // After this point we know that neither a or b are zero.
738  unsigned bits = a.getBitWidth();
739  APInt zero = APInt::getZero(bits);
740  bool aGtZero = a.sgt(zero);
741  bool bGtZero = b.sgt(zero);
742  if (aGtZero && bGtZero) {
743  // Both positive, return ceil(a, b).
744  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
745  }
746 
747  // No folding happens if any of the intermediate arithmetic operations
748  // overflows.
// Each step has its own flag; they are combined into overflowOrDiv0 below.
749  bool overflowNegA = false;
750  bool overflowNegB = false;
751  bool overflowDiv = false;
752  bool overflowNegRes = false;
753  if (!aGtZero && !bGtZero) {
754  // Both negative, return ceil(-a, -b).
755  APInt posA = zero.ssub_ov(a, overflowNegA);
756  APInt posB = zero.ssub_ov(b, overflowNegB);
757  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
758  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
759  return res;
760  }
// With opposite signs the exact quotient is negative, so its ceiling is the
// negated truncating quotient of the magnitudes.
761  if (!aGtZero && bGtZero) {
762  // A is negative, b is positive, return - ( -a / b).
763  APInt posA = zero.ssub_ov(a, overflowNegA);
764  APInt div = posA.sdiv_ov(b, overflowDiv);
765  APInt res = zero.ssub_ov(div, overflowNegRes);
766  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
767  return res;
768  }
769  // A is positive, b is negative, return - (a / -b).
770  APInt posB = zero.ssub_ov(b, overflowNegB);
771  APInt div = a.sdiv_ov(posB, overflowDiv);
772  APInt res = zero.ssub_ov(div, overflowNegRes);
773 
774  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
775  return res;
776  });
777 
778  return overflowOrDiv0 ? Attribute() : result;
779 }
780 
781 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
782  return getDivSISpeculatability(getRhs());
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // FloorDivSIOp
787 //===----------------------------------------------------------------------===//
788 
789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
790  // floordivsi (x, 1) -> x.
791  if (matchPattern(adaptor.getRhs(), m_One()))
792  return getLhs();
793 
794  // Don't fold if it would overflow or if it requires a division by zero.
795  bool overflowOrDiv = false;
796  auto result = constFoldBinaryOp<IntegerAttr>(
797  adaptor.getOperands(), [&](APInt a, const APInt &b) {
798  if (b.isZero()) {
799  overflowOrDiv = true;
800  return a;
801  }
802  return a.sfloordiv_ov(b, overflowOrDiv);
803  });
804 
805  return overflowOrDiv ? Attribute() : result;
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // RemUIOp
810 //===----------------------------------------------------------------------===//
811 
812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
813  // remui (x, 1) -> 0.
814  if (matchPattern(adaptor.getRhs(), m_One()))
815  return Builder(getContext()).getZeroAttr(getType());
816 
817  // Don't fold if it would require a division by zero.
818  bool div0 = false;
819  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820  [&](APInt a, const APInt &b) {
821  if (div0 || b.isZero()) {
822  div0 = true;
823  return a;
824  }
825  return a.urem(b);
826  });
827 
828  return div0 ? Attribute() : result;
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // RemSIOp
833 //===----------------------------------------------------------------------===//
834 
835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
836  // remsi (x, 1) -> 0.
837  if (matchPattern(adaptor.getRhs(), m_One()))
838  return Builder(getContext()).getZeroAttr(getType());
839 
840  // Don't fold if it would require a division by zero.
841  bool div0 = false;
842  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
843  [&](APInt a, const APInt &b) {
844  if (div0 || b.isZero()) {
845  div0 = true;
846  return a;
847  }
848  return a.srem(b);
849  });
850 
851  return div0 ? Attribute() : result;
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // AndIOp
856 //===----------------------------------------------------------------------===//
857 
858 /// Fold `and(a, and(a, b))` to `and(a, b)`
859 static Value foldAndIofAndI(arith::AndIOp op) {
860  for (bool reversePrev : {false, true}) {
861  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862  .getDefiningOp<arith::AndIOp>();
863  if (!prev)
864  continue;
865 
866  Value other = (reversePrev ? op.getLhs() : op.getRhs());
867  if (other != prev.getLhs() && other != prev.getRhs())
868  continue;
869 
870  return prev.getResult();
871  }
872  return {};
873 }
874 
875 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
876  /// and(x, 0) -> 0
877  if (matchPattern(adaptor.getRhs(), m_Zero()))
878  return getRhs();
879  /// and(x, allOnes) -> x
880  APInt intValue;
881  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
882  intValue.isAllOnes())
883  return getLhs();
884  /// and(x, not(x)) -> 0
885  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
886  m_ConstantInt(&intValue))) &&
887  intValue.isAllOnes())
888  return Builder(getContext()).getZeroAttr(getType());
889  /// and(not(x), x) -> 0
890  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
891  m_ConstantInt(&intValue))) &&
892  intValue.isAllOnes())
893  return Builder(getContext()).getZeroAttr(getType());
894 
895  /// and(a, and(a, b)) -> and(a, b)
896  if (Value result = foldAndIofAndI(*this))
897  return result;
898 
899  return constFoldBinaryOp<IntegerAttr>(
900  adaptor.getOperands(),
901  [](APInt a, const APInt &b) { return std::move(a) & b; });
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // OrIOp
906 //===----------------------------------------------------------------------===//
907 
908 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
909  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
910  /// or(x, 0) -> x
911  if (rhsVal.isZero())
912  return getLhs();
913  /// or(x, <all ones>) -> <all ones>
914  if (rhsVal.isAllOnes())
915  return adaptor.getRhs();
916  }
917 
918  APInt intValue;
919  /// or(x, xor(x, 1)) -> 1
920  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
921  m_ConstantInt(&intValue))) &&
922  intValue.isAllOnes())
923  return getRhs().getDefiningOp<XOrIOp>().getRhs();
924  /// or(xor(x, 1), x) -> 1
925  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
926  m_ConstantInt(&intValue))) &&
927  intValue.isAllOnes())
928  return getLhs().getDefiningOp<XOrIOp>().getRhs();
929 
930  return constFoldBinaryOp<IntegerAttr>(
931  adaptor.getOperands(),
932  [](APInt a, const APInt &b) { return std::move(a) | b; });
933 }
934 
935 //===----------------------------------------------------------------------===//
936 // XOrIOp
937 //===----------------------------------------------------------------------===//
938 
939 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
940  /// xor(x, 0) -> x
941  if (matchPattern(adaptor.getRhs(), m_Zero()))
942  return getLhs();
943  /// xor(x, x) -> 0
944  if (getLhs() == getRhs())
945  return Builder(getContext()).getZeroAttr(getType());
946  /// xor(xor(x, a), a) -> x
947  /// xor(xor(a, x), a) -> x
948  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
949  if (prev.getRhs() == getRhs())
950  return prev.getLhs();
951  if (prev.getLhs() == getRhs())
952  return prev.getRhs();
953  }
954  /// xor(a, xor(x, a)) -> x
955  /// xor(a, xor(a, x)) -> x
956  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
957  if (prev.getRhs() == getLhs())
958  return prev.getLhs();
959  if (prev.getLhs() == getLhs())
960  return prev.getRhs();
961  }
962 
963  return constFoldBinaryOp<IntegerAttr>(
964  adaptor.getOperands(),
965  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
966 }
967 
968 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
969  MLIRContext *context) {
970  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // NegFOp
975 //===----------------------------------------------------------------------===//
976 
977 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
978  /// negf(negf(x)) -> x
979  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
980  return op.getOperand();
981  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
982  [](const APFloat &a) { return -a; });
983 }
984 
985 //===----------------------------------------------------------------------===//
986 // AddFOp
987 //===----------------------------------------------------------------------===//
988 
989 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
990  // addf(x, -0) -> x
991  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
992  return getLhs();
993 
994  return constFoldBinaryOp<FloatAttr>(
995  adaptor.getOperands(),
996  [](const APFloat &a, const APFloat &b) { return a + b; });
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // SubFOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1004  // subf(x, +0) -> x
1005  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1006  return getLhs();
1007 
1008  return constFoldBinaryOp<FloatAttr>(
1009  adaptor.getOperands(),
1010  [](const APFloat &a, const APFloat &b) { return a - b; });
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // MaximumFOp
1015 //===----------------------------------------------------------------------===//
1016 
1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1018  // maximumf(x,x) -> x
1019  if (getLhs() == getRhs())
1020  return getRhs();
1021 
1022  // maximumf(x, -inf) -> x
1023  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1024  return getLhs();
1025 
1026  return constFoldBinaryOp<FloatAttr>(
1027  adaptor.getOperands(),
1028  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // MaxNumFOp
1033 //===----------------------------------------------------------------------===//
1034 
1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1036  // maxnumf(x,x) -> x
1037  if (getLhs() == getRhs())
1038  return getRhs();
1039 
1040  // maxnumf(x, NaN) -> x
1041  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1042  return getLhs();
1043 
1044  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // MaxSIOp
1049 //===----------------------------------------------------------------------===//
1050 
1051 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1052  // maxsi(x,x) -> x
1053  if (getLhs() == getRhs())
1054  return getRhs();
1055 
1056  if (APInt intValue;
1057  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1058  // maxsi(x,MAX_INT) -> MAX_INT
1059  if (intValue.isMaxSignedValue())
1060  return getRhs();
1061  // maxsi(x, MIN_INT) -> x
1062  if (intValue.isMinSignedValue())
1063  return getLhs();
1064  }
1065 
1066  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1067  [](const APInt &a, const APInt &b) {
1068  return llvm::APIntOps::smax(a, b);
1069  });
1070 }
1071 
1072 //===----------------------------------------------------------------------===//
1073 // MaxUIOp
1074 //===----------------------------------------------------------------------===//
1075 
1076 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1077  // maxui(x,x) -> x
1078  if (getLhs() == getRhs())
1079  return getRhs();
1080 
1081  if (APInt intValue;
1082  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1083  // maxui(x,MAX_INT) -> MAX_INT
1084  if (intValue.isMaxValue())
1085  return getRhs();
1086  // maxui(x, MIN_INT) -> x
1087  if (intValue.isMinValue())
1088  return getLhs();
1089  }
1090 
1091  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1092  [](const APInt &a, const APInt &b) {
1093  return llvm::APIntOps::umax(a, b);
1094  });
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // MinimumFOp
1099 //===----------------------------------------------------------------------===//
1100 
1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1102  // minimumf(x,x) -> x
1103  if (getLhs() == getRhs())
1104  return getRhs();
1105 
1106  // minimumf(x, +inf) -> x
1107  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1108  return getLhs();
1109 
1110  return constFoldBinaryOp<FloatAttr>(
1111  adaptor.getOperands(),
1112  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1113 }
1114 
1115 //===----------------------------------------------------------------------===//
1116 // MinNumFOp
1117 //===----------------------------------------------------------------------===//
1118 
1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1120  // minnumf(x,x) -> x
1121  if (getLhs() == getRhs())
1122  return getRhs();
1123 
1124  // minnumf(x, NaN) -> x
1125  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1126  return getLhs();
1127 
1128  return constFoldBinaryOp<FloatAttr>(
1129  adaptor.getOperands(),
1130  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // MinSIOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1138  // minsi(x,x) -> x
1139  if (getLhs() == getRhs())
1140  return getRhs();
1141 
1142  if (APInt intValue;
1143  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1144  // minsi(x,MIN_INT) -> MIN_INT
1145  if (intValue.isMinSignedValue())
1146  return getRhs();
1147  // minsi(x, MAX_INT) -> x
1148  if (intValue.isMaxSignedValue())
1149  return getLhs();
1150  }
1151 
1152  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1153  [](const APInt &a, const APInt &b) {
1154  return llvm::APIntOps::smin(a, b);
1155  });
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // MinUIOp
1160 //===----------------------------------------------------------------------===//
1161 
1162 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1163  // minui(x,x) -> x
1164  if (getLhs() == getRhs())
1165  return getRhs();
1166 
1167  if (APInt intValue;
1168  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1169  // minui(x,MIN_INT) -> MIN_INT
1170  if (intValue.isMinValue())
1171  return getRhs();
1172  // minui(x, MAX_INT) -> x
1173  if (intValue.isMaxValue())
1174  return getLhs();
1175  }
1176 
1177  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1178  [](const APInt &a, const APInt &b) {
1179  return llvm::APIntOps::umin(a, b);
1180  });
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // MulFOp
1185 //===----------------------------------------------------------------------===//
1186 
1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1188  // mulf(x, 1) -> x
1189  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1190  return getLhs();
1191 
1192  return constFoldBinaryOp<FloatAttr>(
1193  adaptor.getOperands(),
1194  [](const APFloat &a, const APFloat &b) { return a * b; });
1195 }
1196 
1197 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1198  MLIRContext *context) {
1199  patterns.add<MulFOfNegF>(context);
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // DivFOp
1204 //===----------------------------------------------------------------------===//
1205 
1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1207  // divf(x, 1) -> x
1208  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1209  return getLhs();
1210 
1211  return constFoldBinaryOp<FloatAttr>(
1212  adaptor.getOperands(),
1213  [](const APFloat &a, const APFloat &b) { return a / b; });
1214 }
1215 
1216 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1217  MLIRContext *context) {
1218  patterns.add<DivFOfNegF>(context);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // RemFOp
1223 //===----------------------------------------------------------------------===//
1224 
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1226  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1227  [](const APFloat &a, const APFloat &b) {
1228  APFloat result(a);
1229  // APFloat::mod() offers the remainder
1230  // behavior we want, i.e. the result has
1231  // the sign of LHS operand.
1232  (void)result.mod(b);
1233  return result;
1234  });
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // Utility functions for verifying cast ops
1239 //===----------------------------------------------------------------------===//
1240 
1241 template <typename... Types>
1242 using type_list = std::tuple<Types...> *;
1243 
1244 /// Returns a non-null type only if the provided type is one of the allowed
1245 /// types or one of the allowed shaped types of the allowed types. Returns the
1246 /// element type if a valid shaped type is provided.
1247 template <typename... ShapedTypes, typename... ElementTypes>
1250  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1251  return {};
1252 
1253  auto underlyingType = getElementTypeOrSelf(type);
1254  if (!llvm::isa<ElementTypes...>(underlyingType))
1255  return {};
1256 
1257  return underlyingType;
1258 }
1259 
1260 /// Get allowed underlying types for vectors and tensors.
1261 template <typename... ElementTypes>
1262 static Type getTypeIfLike(Type type) {
1265 }
1266 
1267 /// Get allowed underlying types for vectors, tensors, and memrefs.
1268 template <typename... ElementTypes>
1270  return getUnderlyingType(type,
1273 }
1274 
1275 /// Return false if both types are ranked tensor with mismatching encoding.
1276 static bool hasSameEncoding(Type typeA, Type typeB) {
1277  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1278  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1279  if (!rankedTensorA || !rankedTensorB)
1280  return true;
1281  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1282 }
1283 
1284 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1285  if (inputs.size() != 1 || outputs.size() != 1)
1286  return false;
1287  if (!hasSameEncoding(inputs.front(), outputs.front()))
1288  return false;
1289  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // Verifiers for integer and floating point extension/truncation ops
1294 //===----------------------------------------------------------------------===//
1295 
1296 // Extend ops can only extend to a wider type.
1297 template <typename ValType, typename Op>
1298 static LogicalResult verifyExtOp(Op op) {
1299  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1300  Type dstType = getElementTypeOrSelf(op.getType());
1301 
1302  if (llvm::cast<ValType>(srcType).getWidth() >=
1303  llvm::cast<ValType>(dstType).getWidth())
1304  return op.emitError("result type ")
1305  << dstType << " must be wider than operand type " << srcType;
1306 
1307  return success();
1308 }
1309 
1310 // Truncate ops can only truncate to a shorter type.
1311 template <typename ValType, typename Op>
1312 static LogicalResult verifyTruncateOp(Op op) {
1313  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1314  Type dstType = getElementTypeOrSelf(op.getType());
1315 
1316  if (llvm::cast<ValType>(srcType).getWidth() <=
1317  llvm::cast<ValType>(dstType).getWidth())
1318  return op.emitError("result type ")
1319  << dstType << " must be shorter than operand type " << srcType;
1320 
1321  return success();
1322 }
1323 
1324 /// Validate a cast that changes the width of a type.
1325 template <template <typename> class WidthComparator, typename... ElementTypes>
1326 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1327  if (!areValidCastInputsAndOutputs(inputs, outputs))
1328  return false;
1329 
1330  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1331  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1332  if (!srcType || !dstType)
1333  return false;
1334 
1335  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1336  srcType.getIntOrFloatBitWidth());
1337 }
1338 
1339 /// Attempts to convert `sourceValue` to an APFloat value with
1340 /// `targetSemantics` and `roundingMode`, without any information loss.
1341 static FailureOr<APFloat> convertFloatValue(
1342  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1343  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344  bool losesInfo = false;
1345  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346  if (losesInfo || status != APFloat::opOK)
1347  return failure();
1348 
1349  return sourceValue;
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // ExtUIOp
1354 //===----------------------------------------------------------------------===//
1355 
1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1357  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1358  getInMutable().assign(lhs.getIn());
1359  return getResult();
1360  }
1361 
1362  Type resType = getElementTypeOrSelf(getType());
1363  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1364  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1365  adaptor.getOperands(), getType(),
1366  [bitWidth](const APInt &a, bool &castStatus) {
1367  return a.zext(bitWidth);
1368  });
1369 }
1370 
1371 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1372  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1373 }
1374 
1375 LogicalResult arith::ExtUIOp::verify() {
1376  return verifyExtOp<IntegerType>(*this);
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // ExtSIOp
1381 //===----------------------------------------------------------------------===//
1382 
1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1384  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1385  getInMutable().assign(lhs.getIn());
1386  return getResult();
1387  }
1388 
1389  Type resType = getElementTypeOrSelf(getType());
1390  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1391  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1392  adaptor.getOperands(), getType(),
1393  [bitWidth](const APInt &a, bool &castStatus) {
1394  return a.sext(bitWidth);
1395  });
1396 }
1397 
1398 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1399  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1400 }
1401 
1402 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1403  MLIRContext *context) {
1404  patterns.add<ExtSIOfExtUI>(context);
1405 }
1406 
1407 LogicalResult arith::ExtSIOp::verify() {
1408  return verifyExtOp<IntegerType>(*this);
1409 }
1410 
1411 //===----------------------------------------------------------------------===//
1412 // ExtFOp
1413 //===----------------------------------------------------------------------===//
1414 
1415 /// Fold extension of float constants when there is no information loss due the
1416 /// difference in fp semantics.
1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1418  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1419  if (truncFOp.getOperand().getType() == getType()) {
1420  arith::FastMathFlags truncFMF =
1421  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1422  bool isTruncContract =
1423  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1424  arith::FastMathFlags extFMF =
1425  getFastmath().value_or(arith::FastMathFlags::none);
1426  bool isExtContract =
1427  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1428  if (isTruncContract && isExtContract) {
1429  return truncFOp.getOperand();
1430  }
1431  }
1432  }
1433 
1434  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1435  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1436  return constFoldCastOp<FloatAttr, FloatAttr>(
1437  adaptor.getOperands(), getType(),
1438  [&targetSemantics](const APFloat &a, bool &castStatus) {
1439  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1440  if (failed(result)) {
1441  castStatus = false;
1442  return a;
1443  }
1444  return *result;
1445  });
1446 }
1447 
1448 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1449  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1450 }
1451 
1452 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1453 
1454 //===----------------------------------------------------------------------===//
1455 // TruncIOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1459  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1460  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1461  Value src = getOperand().getDefiningOp()->getOperand(0);
1462  Type srcType = getElementTypeOrSelf(src.getType());
1463  Type dstType = getElementTypeOrSelf(getType());
1464  // trunci(zexti(a)) -> trunci(a)
1465  // trunci(sexti(a)) -> trunci(a)
1466  if (llvm::cast<IntegerType>(srcType).getWidth() >
1467  llvm::cast<IntegerType>(dstType).getWidth()) {
1468  setOperand(src);
1469  return getResult();
1470  }
1471 
1472  // trunci(zexti(a)) -> a
1473  // trunci(sexti(a)) -> a
1474  if (srcType == dstType)
1475  return src;
1476  }
1477 
1478  // trunci(trunci(a)) -> trunci(a))
1479  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1480  setOperand(getOperand().getDefiningOp()->getOperand(0));
1481  return getResult();
1482  }
1483 
1484  Type resType = getElementTypeOrSelf(getType());
1485  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1486  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1487  adaptor.getOperands(), getType(),
1488  [bitWidth](const APInt &a, bool &castStatus) {
1489  return a.trunc(bitWidth);
1490  });
1491 }
1492 
1493 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1494  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1495 }
1496 
1497 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1498  MLIRContext *context) {
1499  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1500  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1501  context);
1502 }
1503 
1504 LogicalResult arith::TruncIOp::verify() {
1505  return verifyTruncateOp<IntegerType>(*this);
1506 }
1507 
1508 //===----------------------------------------------------------------------===//
1509 // TruncFOp
1510 //===----------------------------------------------------------------------===//
1511 
1512 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1513 /// can be represented without precision loss.
1514 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1515  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1516  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1517  Value src = extOp.getIn();
1518  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1519  auto intermediateType =
1520  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1521  // Check if the srcType is representable in the intermediateType.
1522  if (llvm::APFloatBase::isRepresentableBy(
1523  srcType.getFloatSemantics(),
1524  intermediateType.getFloatSemantics())) {
1525  // truncf(extf(a)) -> truncf(a)
1526  if (srcType.getWidth() > resElemType.getWidth()) {
1527  setOperand(src);
1528  return getResult();
1529  }
1530 
1531  // truncf(extf(a)) -> a
1532  if (srcType == resElemType)
1533  return src;
1534  }
1535  }
1536 
1537  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1538  return constFoldCastOp<FloatAttr, FloatAttr>(
1539  adaptor.getOperands(), getType(),
1540  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1541  RoundingMode roundingMode =
1542  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1543  llvm::RoundingMode llvmRoundingMode =
1544  convertArithRoundingModeToLLVMIR(roundingMode);
1545  FailureOr<APFloat> result =
1546  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1547  if (failed(result)) {
1548  castStatus = false;
1549  return a;
1550  }
1551  return *result;
1552  });
1553 }
1554 
1555 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1556  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1557 }
1558 
1559 LogicalResult arith::TruncFOp::verify() {
1560  return verifyTruncateOp<FloatType>(*this);
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // AndIOp
1565 //===----------------------------------------------------------------------===//
1566 
1567 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1568  MLIRContext *context) {
1569  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1570 }
1571 
1572 //===----------------------------------------------------------------------===//
1573 // OrIOp
1574 //===----------------------------------------------------------------------===//
1575 
1576 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1577  MLIRContext *context) {
1578  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1579 }
1580 
1581 //===----------------------------------------------------------------------===//
1582 // Verifiers for casts between integers and floats.
1583 //===----------------------------------------------------------------------===//
1584 
1585 template <typename From, typename To>
1586 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1587  if (!areValidCastInputsAndOutputs(inputs, outputs))
1588  return false;
1589 
1590  auto srcType = getTypeIfLike<From>(inputs.front());
1591  auto dstType = getTypeIfLike<To>(outputs.back());
1592 
1593  return srcType && dstType;
1594 }
1595 
1596 //===----------------------------------------------------------------------===//
1597 // UIToFPOp
1598 //===----------------------------------------------------------------------===//
1599 
1600 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1601  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1602 }
1603 
1604 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1605  Type resEleType = getElementTypeOrSelf(getType());
1606  return constFoldCastOp<IntegerAttr, FloatAttr>(
1607  adaptor.getOperands(), getType(),
1608  [&resEleType](const APInt &a, bool &castStatus) {
1609  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1610  APFloat apf(floatTy.getFloatSemantics(),
1611  APInt::getZero(floatTy.getWidth()));
1612  apf.convertFromAPInt(a, /*IsSigned=*/false,
1613  APFloat::rmNearestTiesToEven);
1614  return apf;
1615  });
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // SIToFPOp
1620 //===----------------------------------------------------------------------===//
1621 
1622 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1623  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1624 }
1625 
1626 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1627  Type resEleType = getElementTypeOrSelf(getType());
1628  return constFoldCastOp<IntegerAttr, FloatAttr>(
1629  adaptor.getOperands(), getType(),
1630  [&resEleType](const APInt &a, bool &castStatus) {
1631  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1632  APFloat apf(floatTy.getFloatSemantics(),
1633  APInt::getZero(floatTy.getWidth()));
1634  apf.convertFromAPInt(a, /*IsSigned=*/true,
1635  APFloat::rmNearestTiesToEven);
1636  return apf;
1637  });
1638 }
1639 
1640 //===----------------------------------------------------------------------===//
1641 // FPToUIOp
1642 //===----------------------------------------------------------------------===//
1643 
1644 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1645  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1646 }
1647 
1648 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1649  Type resType = getElementTypeOrSelf(getType());
1650  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1651  return constFoldCastOp<FloatAttr, IntegerAttr>(
1652  adaptor.getOperands(), getType(),
1653  [&bitWidth](const APFloat &a, bool &castStatus) {
1654  bool ignored;
1655  APSInt api(bitWidth, /*isUnsigned=*/true);
1656  castStatus = APFloat::opInvalidOp !=
1657  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1658  return api;
1659  });
1660 }
1661 
1662 //===----------------------------------------------------------------------===//
1663 // FPToSIOp
1664 //===----------------------------------------------------------------------===//
1665 
1666 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1667  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1668 }
1669 
1670 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1671  Type resType = getElementTypeOrSelf(getType());
1672  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1673  return constFoldCastOp<FloatAttr, IntegerAttr>(
1674  adaptor.getOperands(), getType(),
1675  [&bitWidth](const APFloat &a, bool &castStatus) {
1676  bool ignored;
1677  APSInt api(bitWidth, /*isUnsigned=*/false);
1678  castStatus = APFloat::opInvalidOp !=
1679  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1680  return api;
1681  });
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // IndexCastOp
1686 //===----------------------------------------------------------------------===//
1687 
1688 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1689  if (!areValidCastInputsAndOutputs(inputs, outputs))
1690  return false;
1691 
1692  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1693  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1694  if (!srcType || !dstType)
1695  return false;
1696 
1697  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1698  (srcType.isSignlessInteger() && dstType.isIndex());
1699 }
1700 
1701 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1702  TypeRange outputs) {
1703  return areIndexCastCompatible(inputs, outputs);
1704 }
1705 
1706 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1707  // index_cast(constant) -> constant
1708  unsigned resultBitwidth = 64; // Default for index integer attributes.
1709  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1710  resultBitwidth = intTy.getWidth();
1711 
1712  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1713  adaptor.getOperands(), getType(),
1714  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1715  return a.sextOrTrunc(resultBitwidth);
1716  });
1717 }
1718 
1719 void arith::IndexCastOp::getCanonicalizationPatterns(
1720  RewritePatternSet &patterns, MLIRContext *context) {
1721  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1722 }
1723 
1724 //===----------------------------------------------------------------------===//
1725 // IndexCastUIOp
1726 //===----------------------------------------------------------------------===//
1727 
1728 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1729  TypeRange outputs) {
1730  return areIndexCastCompatible(inputs, outputs);
1731 }
1732 
1733 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1734  // index_castui(constant) -> constant
1735  unsigned resultBitwidth = 64; // Default for index integer attributes.
1736  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1737  resultBitwidth = intTy.getWidth();
1738 
1739  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1740  adaptor.getOperands(), getType(),
1741  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1742  return a.zextOrTrunc(resultBitwidth);
1743  });
1744 }
1745 
1746 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1747  RewritePatternSet &patterns, MLIRContext *context) {
1748  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1749 }
1750 
1751 //===----------------------------------------------------------------------===//
1752 // BitcastOp
1753 //===----------------------------------------------------------------------===//
1754 
1755 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1756  if (!areValidCastInputsAndOutputs(inputs, outputs))
1757  return false;
1758 
1759  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1760  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1761  if (!srcType || !dstType)
1762  return false;
1763 
1764  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1765 }
1766 
1767 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1768  auto resType = getType();
1769  auto operand = adaptor.getIn();
1770  if (!operand)
1771  return {};
1772 
1773  /// Bitcast dense elements.
1774  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1775  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1776  /// Other shaped types unhandled.
1777  if (llvm::isa<ShapedType>(resType))
1778  return {};
1779 
1780  /// Bitcast poison.
1781  if (llvm::isa<ub::PoisonAttr>(operand))
1782  return ub::PoisonAttr::get(getContext());
1783 
1784  /// Bitcast integer or float to integer or float.
1785  APInt bits = llvm::isa<FloatAttr>(operand)
1786  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1787  : llvm::cast<IntegerAttr>(operand).getValue();
1788  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1789  "trying to fold on broken IR: operands have incompatible types");
1790 
1791  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1792  return FloatAttr::get(resType,
1793  APFloat(resFloatType.getFloatSemantics(), bits));
1794  return IntegerAttr::get(resType, bits);
1795 }
1796 
1797 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1798  MLIRContext *context) {
1799  patterns.add<BitcastOfBitcast>(context);
1800 }
1801 
1802 //===----------------------------------------------------------------------===//
1803 // CmpIOp
1804 //===----------------------------------------------------------------------===//
1805 
1806 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1807 /// comparison predicates.
1808 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1809  const APInt &lhs, const APInt &rhs) {
1810  switch (predicate) {
1811  case arith::CmpIPredicate::eq:
1812  return lhs.eq(rhs);
1813  case arith::CmpIPredicate::ne:
1814  return lhs.ne(rhs);
1815  case arith::CmpIPredicate::slt:
1816  return lhs.slt(rhs);
1817  case arith::CmpIPredicate::sle:
1818  return lhs.sle(rhs);
1819  case arith::CmpIPredicate::sgt:
1820  return lhs.sgt(rhs);
1821  case arith::CmpIPredicate::sge:
1822  return lhs.sge(rhs);
1823  case arith::CmpIPredicate::ult:
1824  return lhs.ult(rhs);
1825  case arith::CmpIPredicate::ule:
1826  return lhs.ule(rhs);
1827  case arith::CmpIPredicate::ugt:
1828  return lhs.ugt(rhs);
1829  case arith::CmpIPredicate::uge:
1830  return lhs.uge(rhs);
1831  }
1832  llvm_unreachable("unknown cmpi predicate kind");
1833 }
1834 
1835 /// Returns true if the predicate is true for two equal operands.
1836 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1837  switch (predicate) {
1838  case arith::CmpIPredicate::eq:
1839  case arith::CmpIPredicate::sle:
1840  case arith::CmpIPredicate::sge:
1841  case arith::CmpIPredicate::ule:
1842  case arith::CmpIPredicate::uge:
1843  return true;
1844  case arith::CmpIPredicate::ne:
1845  case arith::CmpIPredicate::slt:
1846  case arith::CmpIPredicate::sgt:
1847  case arith::CmpIPredicate::ult:
1848  case arith::CmpIPredicate::ugt:
1849  return false;
1850  }
1851  llvm_unreachable("unknown cmpi predicate kind");
1852 }
1853 
1854 static std::optional<int64_t> getIntegerWidth(Type t) {
1855  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1856  return intType.getWidth();
1857  }
1858  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1859  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1860  }
1861  return std::nullopt;
1862 }
1863 
1864 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1865  // cmpi(pred, x, x)
1866  if (getLhs() == getRhs()) {
1867  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1868  return getBoolAttribute(getType(), val);
1869  }
1870 
1871  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1872  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1873  // extsi(%x : i1 -> iN) != 0 -> %x
1874  std::optional<int64_t> integerWidth =
1875  getIntegerWidth(extOp.getOperand().getType());
1876  if (integerWidth && integerWidth.value() == 1 &&
1877  getPredicate() == arith::CmpIPredicate::ne)
1878  return extOp.getOperand();
1879  }
1880  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1881  // extui(%x : i1 -> iN) != 0 -> %x
1882  std::optional<int64_t> integerWidth =
1883  getIntegerWidth(extOp.getOperand().getType());
1884  if (integerWidth && integerWidth.value() == 1 &&
1885  getPredicate() == arith::CmpIPredicate::ne)
1886  return extOp.getOperand();
1887  }
1888 
1889  // arith.cmpi ne, %val, %zero : i1 -> %val
1890  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1891  getPredicate() == arith::CmpIPredicate::ne)
1892  return getLhs();
1893  }
1894 
1895  if (matchPattern(adaptor.getRhs(), m_One())) {
1896  // arith.cmpi eq, %val, %one : i1 -> %val
1897  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1898  getPredicate() == arith::CmpIPredicate::eq)
1899  return getLhs();
1900  }
1901 
1902  // Move constant to the right side.
1903  if (adaptor.getLhs() && !adaptor.getRhs()) {
1904  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1905  using Pred = CmpIPredicate;
1906  const std::pair<Pred, Pred> invPreds[] = {
1907  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1908  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1909  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1910  {Pred::ne, Pred::ne},
1911  };
1912  Pred origPred = getPredicate();
1913  for (auto pred : invPreds) {
1914  if (origPred == pred.first) {
1915  setPredicate(pred.second);
1916  Value lhs = getLhs();
1917  Value rhs = getRhs();
1918  getLhsMutable().assign(rhs);
1919  getRhsMutable().assign(lhs);
1920  return getResult();
1921  }
1922  }
1923  llvm_unreachable("unknown cmpi predicate kind");
1924  }
1925 
1926  // We are moving constants to the right side; So if lhs is constant rhs is
1927  // guaranteed to be a constant.
1928  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1929  return constFoldBinaryOp<IntegerAttr>(
1930  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1931  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1932  return APInt(1,
1933  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1934  });
1935  }
1936 
1937  return {};
1938 }
1939 
1940 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1941  MLIRContext *context) {
1942  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1943 }
1944 
1945 //===----------------------------------------------------------------------===//
1946 // CmpFOp
1947 //===----------------------------------------------------------------------===//
1948 
1949 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1950 /// comparison predicates.
1951 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1952  const APFloat &lhs, const APFloat &rhs) {
1953  auto cmpResult = lhs.compare(rhs);
1954  switch (predicate) {
1955  case arith::CmpFPredicate::AlwaysFalse:
1956  return false;
1957  case arith::CmpFPredicate::OEQ:
1958  return cmpResult == APFloat::cmpEqual;
1959  case arith::CmpFPredicate::OGT:
1960  return cmpResult == APFloat::cmpGreaterThan;
1961  case arith::CmpFPredicate::OGE:
1962  return cmpResult == APFloat::cmpGreaterThan ||
1963  cmpResult == APFloat::cmpEqual;
1964  case arith::CmpFPredicate::OLT:
1965  return cmpResult == APFloat::cmpLessThan;
1966  case arith::CmpFPredicate::OLE:
1967  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1968  case arith::CmpFPredicate::ONE:
1969  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1970  case arith::CmpFPredicate::ORD:
1971  return cmpResult != APFloat::cmpUnordered;
1972  case arith::CmpFPredicate::UEQ:
1973  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1974  case arith::CmpFPredicate::UGT:
1975  return cmpResult == APFloat::cmpUnordered ||
1976  cmpResult == APFloat::cmpGreaterThan;
1977  case arith::CmpFPredicate::UGE:
1978  return cmpResult == APFloat::cmpUnordered ||
1979  cmpResult == APFloat::cmpGreaterThan ||
1980  cmpResult == APFloat::cmpEqual;
1981  case arith::CmpFPredicate::ULT:
1982  return cmpResult == APFloat::cmpUnordered ||
1983  cmpResult == APFloat::cmpLessThan;
1984  case arith::CmpFPredicate::ULE:
1985  return cmpResult == APFloat::cmpUnordered ||
1986  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1987  case arith::CmpFPredicate::UNE:
1988  return cmpResult != APFloat::cmpEqual;
1989  case arith::CmpFPredicate::UNO:
1990  return cmpResult == APFloat::cmpUnordered;
1991  case arith::CmpFPredicate::AlwaysTrue:
1992  return true;
1993  }
1994  llvm_unreachable("unknown cmpf predicate kind");
1995 }
1996 
1997 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1998  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1999  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2000 
2001  // If one operand is NaN, making them both NaN does not change the result.
2002  if (lhs && lhs.getValue().isNaN())
2003  rhs = lhs;
2004  if (rhs && rhs.getValue().isNaN())
2005  lhs = rhs;
2006 
2007  if (!lhs || !rhs)
2008  return {};
2009 
2010  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2011  return BoolAttr::get(getContext(), val);
2012 }
2013 
2014 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2015 public:
2017 
2018  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2019  bool isUnsigned) {
2020  using namespace arith;
2021  switch (pred) {
2022  case CmpFPredicate::UEQ:
2023  case CmpFPredicate::OEQ:
2024  return CmpIPredicate::eq;
2025  case CmpFPredicate::UGT:
2026  case CmpFPredicate::OGT:
2027  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2028  case CmpFPredicate::UGE:
2029  case CmpFPredicate::OGE:
2030  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2031  case CmpFPredicate::ULT:
2032  case CmpFPredicate::OLT:
2033  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2034  case CmpFPredicate::ULE:
2035  case CmpFPredicate::OLE:
2036  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2037  case CmpFPredicate::UNE:
2038  case CmpFPredicate::ONE:
2039  return CmpIPredicate::ne;
2040  default:
2041  llvm_unreachable("Unexpected predicate!");
2042  }
2043  }
2044 
2045  LogicalResult matchAndRewrite(CmpFOp op,
2046  PatternRewriter &rewriter) const override {
2047  FloatAttr flt;
2048  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2049  return failure();
2050 
2051  const APFloat &rhs = flt.getValue();
2052 
2053  // Don't attempt to fold a nan.
2054  if (rhs.isNaN())
2055  return failure();
2056 
2057  // Get the width of the mantissa. We don't want to hack on conversions that
2058  // might lose information from the integer, e.g. "i64 -> float"
2059  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2060  int mantissaWidth = floatTy.getFPMantissaWidth();
2061  if (mantissaWidth <= 0)
2062  return failure();
2063 
2064  bool isUnsigned;
2065  Value intVal;
2066 
2067  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2068  isUnsigned = false;
2069  intVal = si.getIn();
2070  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2071  isUnsigned = true;
2072  intVal = ui.getIn();
2073  } else {
2074  return failure();
2075  }
2076 
2077  // Check to see that the input is converted from an integer type that is
2078  // small enough that preserves all bits.
2079  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2080  auto intWidth = intTy.getWidth();
2081 
2082  // Number of bits representing values, as opposed to the sign
2083  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2084 
2085  // Following test does NOT adjust intWidth downwards for signed inputs,
2086  // because the most negative value still requires all the mantissa bits
2087  // to distinguish it from one less than that value.
2088  if ((int)intWidth > mantissaWidth) {
2089  // Conversion would lose accuracy. Check if loss can impact comparison.
2090  int exponent = ilogb(rhs);
2091  if (exponent == APFloat::IEK_Inf) {
2092  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2093  if (maxExponent < (int)valueBits) {
2094  // Conversion could create infinity.
2095  return failure();
2096  }
2097  } else {
2098  // Note that if rhs is zero or NaN, then Exp is negative
2099  // and first condition is trivially false.
2100  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2101  // Conversion could affect comparison.
2102  return failure();
2103  }
2104  }
2105  }
2106 
2107  // Convert to equivalent cmpi predicate
2108  CmpIPredicate pred;
2109  switch (op.getPredicate()) {
2110  case CmpFPredicate::ORD:
2111  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2112  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2113  /*width=*/1);
2114  return success();
2115  case CmpFPredicate::UNO:
2116  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2117  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2118  /*width=*/1);
2119  return success();
2120  default:
2121  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2122  break;
2123  }
2124 
2125  if (!isUnsigned) {
2126  // If the rhs value is > SignedMax, fold the comparison. This handles
2127  // +INF and large values.
2128  APFloat signedMax(rhs.getSemantics());
2129  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2130  APFloat::rmNearestTiesToEven);
2131  if (signedMax < rhs) { // smax < 13123.0
2132  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2133  pred == CmpIPredicate::sle)
2134  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2135  /*width=*/1);
2136  else
2137  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2138  /*width=*/1);
2139  return success();
2140  }
2141  } else {
2142  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2143  // +INF and large values.
2144  APFloat unsignedMax(rhs.getSemantics());
2145  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2146  APFloat::rmNearestTiesToEven);
2147  if (unsignedMax < rhs) { // umax < 13123.0
2148  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2149  pred == CmpIPredicate::ule)
2150  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2151  /*width=*/1);
2152  else
2153  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2154  /*width=*/1);
2155  return success();
2156  }
2157  }
2158 
2159  if (!isUnsigned) {
2160  // See if the rhs value is < SignedMin.
2161  APFloat signedMin(rhs.getSemantics());
2162  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2163  APFloat::rmNearestTiesToEven);
2164  if (signedMin > rhs) { // smin > 12312.0
2165  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2166  pred == CmpIPredicate::sge)
2167  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2168  /*width=*/1);
2169  else
2170  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2171  /*width=*/1);
2172  return success();
2173  }
2174  } else {
2175  // See if the rhs value is < UnsignedMin.
2176  APFloat unsignedMin(rhs.getSemantics());
2177  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2178  APFloat::rmNearestTiesToEven);
2179  if (unsignedMin > rhs) { // umin > 12312.0
2180  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2181  pred == CmpIPredicate::uge)
2182  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2183  /*width=*/1);
2184  else
2185  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2186  /*width=*/1);
2187  return success();
2188  }
2189  }
2190 
2191  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2192  // [0, UMAX], but it may still be fractional. See if it is fractional by
2193  // casting the FP value to the integer value and back, checking for
2194  // equality. Don't do this for zero, because -0.0 is not fractional.
2195  bool ignored;
2196  APSInt rhsInt(intWidth, isUnsigned);
2197  if (APFloat::opInvalidOp ==
2198  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2199  // Undefined behavior invoked - the destination type can't represent
2200  // the input constant.
2201  return failure();
2202  }
2203 
2204  if (!rhs.isZero()) {
2205  APFloat apf(floatTy.getFloatSemantics(),
2206  APInt::getZero(floatTy.getWidth()));
2207  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2208 
2209  bool equal = apf == rhs;
2210  if (!equal) {
2211  // If we had a comparison against a fractional value, we have to adjust
2212  // the compare predicate and sometimes the value. rhsInt is rounded
2213  // towards zero at this point.
2214  switch (pred) {
2215  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2216  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2217  /*width=*/1);
2218  return success();
2219  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2220  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2221  /*width=*/1);
2222  return success();
2223  case CmpIPredicate::ule:
2224  // (float)int <= 4.4 --> int <= 4
2225  // (float)int <= -4.4 --> false
2226  if (rhs.isNegative()) {
2227  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2228  /*width=*/1);
2229  return success();
2230  }
2231  break;
2232  case CmpIPredicate::sle:
2233  // (float)int <= 4.4 --> int <= 4
2234  // (float)int <= -4.4 --> int < -4
2235  if (rhs.isNegative())
2236  pred = CmpIPredicate::slt;
2237  break;
2238  case CmpIPredicate::ult:
2239  // (float)int < -4.4 --> false
2240  // (float)int < 4.4 --> int <= 4
2241  if (rhs.isNegative()) {
2242  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2243  /*width=*/1);
2244  return success();
2245  }
2246  pred = CmpIPredicate::ule;
2247  break;
2248  case CmpIPredicate::slt:
2249  // (float)int < -4.4 --> int < -4
2250  // (float)int < 4.4 --> int <= 4
2251  if (!rhs.isNegative())
2252  pred = CmpIPredicate::sle;
2253  break;
2254  case CmpIPredicate::ugt:
2255  // (float)int > 4.4 --> int > 4
2256  // (float)int > -4.4 --> true
2257  if (rhs.isNegative()) {
2258  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2259  /*width=*/1);
2260  return success();
2261  }
2262  break;
2263  case CmpIPredicate::sgt:
2264  // (float)int > 4.4 --> int > 4
2265  // (float)int > -4.4 --> int >= -4
2266  if (rhs.isNegative())
2267  pred = CmpIPredicate::sge;
2268  break;
2269  case CmpIPredicate::uge:
2270  // (float)int >= -4.4 --> true
2271  // (float)int >= 4.4 --> int > 4
2272  if (rhs.isNegative()) {
2273  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2274  /*width=*/1);
2275  return success();
2276  }
2277  pred = CmpIPredicate::ugt;
2278  break;
2279  case CmpIPredicate::sge:
2280  // (float)int >= -4.4 --> int >= -4
2281  // (float)int >= 4.4 --> int > 4
2282  if (!rhs.isNegative())
2283  pred = CmpIPredicate::sgt;
2284  break;
2285  }
2286  }
2287  }
2288 
2289  // Lower this FP comparison into an appropriate integer version of the
2290  // comparison.
2291  rewriter.replaceOpWithNewOp<CmpIOp>(
2292  op, pred, intVal,
2293  rewriter.create<ConstantOp>(
2294  op.getLoc(), intVal.getType(),
2295  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2296  return success();
2297  }
2298 };
2299 
2300 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2301  MLIRContext *context) {
2302  patterns.insert<CmpFIntToFPConst>(context);
2303 }
2304 
2305 //===----------------------------------------------------------------------===//
2306 // SelectOp
2307 //===----------------------------------------------------------------------===//
2308 
2309 // select %arg, %c1, %c0 => extui %arg
2310 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2312 
2313  LogicalResult matchAndRewrite(arith::SelectOp op,
2314  PatternRewriter &rewriter) const override {
2315  // Cannot extui i1 to i1, or i1 to f32
2316  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2317  return failure();
2318 
2319  // select %x, c1, %c0 => extui %arg
2320  if (matchPattern(op.getTrueValue(), m_One()) &&
2321  matchPattern(op.getFalseValue(), m_Zero())) {
2322  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2323  op.getCondition());
2324  return success();
2325  }
2326 
2327  // select %x, c0, %c1 => extui (xor %arg, true)
2328  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2329  matchPattern(op.getFalseValue(), m_One())) {
2330  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2331  op, op.getType(),
2332  rewriter.create<arith::XOrIOp>(
2333  op.getLoc(), op.getCondition(),
2334  rewriter.create<arith::ConstantIntOp>(
2335  op.getLoc(), 1, op.getCondition().getType())));
2336  return success();
2337  }
2338 
2339  return failure();
2340  }
2341 };
2342 
2343 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2344  MLIRContext *context) {
2345  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2346  SelectI1ToNot, SelectToExtUI>(context);
2347 }
2348 
2349 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2350  Value trueVal = getTrueValue();
2351  Value falseVal = getFalseValue();
2352  if (trueVal == falseVal)
2353  return trueVal;
2354 
2355  Value condition = getCondition();
2356 
2357  // select true, %0, %1 => %0
2358  if (matchPattern(adaptor.getCondition(), m_One()))
2359  return trueVal;
2360 
2361  // select false, %0, %1 => %1
2362  if (matchPattern(adaptor.getCondition(), m_Zero()))
2363  return falseVal;
2364 
2365  // If either operand is fully poisoned, return the other.
2366  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2367  return falseVal;
2368 
2369  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2370  return trueVal;
2371 
2372  // select %x, true, false => %x
2373  if (getType().isSignlessInteger(1) &&
2374  matchPattern(adaptor.getTrueValue(), m_One()) &&
2375  matchPattern(adaptor.getFalseValue(), m_Zero()))
2376  return condition;
2377 
2378  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2379  auto pred = cmp.getPredicate();
2380  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2381  auto cmpLhs = cmp.getLhs();
2382  auto cmpRhs = cmp.getRhs();
2383 
2384  // %0 = arith.cmpi eq, %arg0, %arg1
2385  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2386 
2387  // %0 = arith.cmpi ne, %arg0, %arg1
2388  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2389 
2390  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2391  (cmpRhs == trueVal && cmpLhs == falseVal))
2392  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2393  }
2394  }
2395 
2396  // Constant-fold constant operands over non-splat constant condition.
2397  // select %cst_vec, %cst0, %cst1 => %cst2
2398  if (auto cond =
2399  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2400  if (auto lhs =
2401  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2402  if (auto rhs =
2403  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2404  SmallVector<Attribute> results;
2405  results.reserve(static_cast<size_t>(cond.getNumElements()));
2406  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2407  cond.value_end<BoolAttr>());
2408  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2409  lhs.value_end<Attribute>());
2410  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2411  rhs.value_end<Attribute>());
2412 
2413  for (auto [condVal, lhsVal, rhsVal] :
2414  llvm::zip_equal(condVals, lhsVals, rhsVals))
2415  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2416 
2417  return DenseElementsAttr::get(lhs.getType(), results);
2418  }
2419  }
2420  }
2421 
2422  return nullptr;
2423 }
2424 
2425 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2426  Type conditionType, resultType;
2428  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2429  parser.parseOptionalAttrDict(result.attributes) ||
2430  parser.parseColonType(resultType))
2431  return failure();
2432 
2433  // Check for the explicit condition type if this is a masked tensor or vector.
2434  if (succeeded(parser.parseOptionalComma())) {
2435  conditionType = resultType;
2436  if (parser.parseType(resultType))
2437  return failure();
2438  } else {
2439  conditionType = parser.getBuilder().getI1Type();
2440  }
2441 
2442  result.addTypes(resultType);
2443  return parser.resolveOperands(operands,
2444  {conditionType, resultType, resultType},
2445  parser.getNameLoc(), result.operands);
2446 }
2447 
2449  p << " " << getOperands();
2450  p.printOptionalAttrDict((*this)->getAttrs());
2451  p << " : ";
2452  if (ShapedType condType =
2453  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2454  p << condType << ", ";
2455  p << getType();
2456 }
2457 
2458 LogicalResult arith::SelectOp::verify() {
2459  Type conditionType = getCondition().getType();
2460  if (conditionType.isSignlessInteger(1))
2461  return success();
2462 
2463  // If the result type is a vector or tensor, the type can be a mask with the
2464  // same elements.
2465  Type resultType = getType();
2466  if (!llvm::isa<TensorType, VectorType>(resultType))
2467  return emitOpError() << "expected condition to be a signless i1, but got "
2468  << conditionType;
2469  Type shapedConditionType = getI1SameShape(resultType);
2470  if (conditionType != shapedConditionType) {
2471  return emitOpError() << "expected condition type to have the same shape "
2472  "as the result type, expected "
2473  << shapedConditionType << ", but got "
2474  << conditionType;
2475  }
2476  return success();
2477 }
2478 //===----------------------------------------------------------------------===//
2479 // ShLIOp
2480 //===----------------------------------------------------------------------===//
2481 
2482 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2483  // shli(x, 0) -> x
2484  if (matchPattern(adaptor.getRhs(), m_Zero()))
2485  return getLhs();
2486  // Don't fold if shifting more or equal than the bit width.
2487  bool bounded = false;
2488  auto result = constFoldBinaryOp<IntegerAttr>(
2489  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2490  bounded = b.ult(b.getBitWidth());
2491  return a.shl(b);
2492  });
2493  return bounded ? result : Attribute();
2494 }
2495 
2496 //===----------------------------------------------------------------------===//
2497 // ShRUIOp
2498 //===----------------------------------------------------------------------===//
2499 
2500 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2501  // shrui(x, 0) -> x
2502  if (matchPattern(adaptor.getRhs(), m_Zero()))
2503  return getLhs();
2504  // Don't fold if shifting more or equal than the bit width.
2505  bool bounded = false;
2506  auto result = constFoldBinaryOp<IntegerAttr>(
2507  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2508  bounded = b.ult(b.getBitWidth());
2509  return a.lshr(b);
2510  });
2511  return bounded ? result : Attribute();
2512 }
2513 
2514 //===----------------------------------------------------------------------===//
2515 // ShRSIOp
2516 //===----------------------------------------------------------------------===//
2517 
2518 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2519  // shrsi(x, 0) -> x
2520  if (matchPattern(adaptor.getRhs(), m_Zero()))
2521  return getLhs();
2522  // Don't fold if shifting more or equal than the bit width.
2523  bool bounded = false;
2524  auto result = constFoldBinaryOp<IntegerAttr>(
2525  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2526  bounded = b.ult(b.getBitWidth());
2527  return a.ashr(b);
2528  });
2529  return bounded ? result : Attribute();
2530 }
2531 
2532 //===----------------------------------------------------------------------===//
2533 // Atomic Enum
2534 //===----------------------------------------------------------------------===//
2535 
2536 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2537 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2538  OpBuilder &builder, Location loc,
2539  bool useOnlyFiniteValue) {
2540  switch (kind) {
2541  case AtomicRMWKind::maximumf: {
2542  const llvm::fltSemantics &semantic =
2543  llvm::cast<FloatType>(resultType).getFloatSemantics();
2544  APFloat identity = useOnlyFiniteValue
2545  ? APFloat::getLargest(semantic, /*Negative=*/true)
2546  : APFloat::getInf(semantic, /*Negative=*/true);
2547  return builder.getFloatAttr(resultType, identity);
2548  }
2549  case AtomicRMWKind::maxnumf: {
2550  const llvm::fltSemantics &semantic =
2551  llvm::cast<FloatType>(resultType).getFloatSemantics();
2552  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2553  return builder.getFloatAttr(resultType, identity);
2554  }
2555  case AtomicRMWKind::addf:
2556  case AtomicRMWKind::addi:
2557  case AtomicRMWKind::maxu:
2558  case AtomicRMWKind::ori:
2559  return builder.getZeroAttr(resultType);
2560  case AtomicRMWKind::andi:
2561  return builder.getIntegerAttr(
2562  resultType,
2563  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2564  case AtomicRMWKind::maxs:
2565  return builder.getIntegerAttr(
2566  resultType, APInt::getSignedMinValue(
2567  llvm::cast<IntegerType>(resultType).getWidth()));
2568  case AtomicRMWKind::minimumf: {
2569  const llvm::fltSemantics &semantic =
2570  llvm::cast<FloatType>(resultType).getFloatSemantics();
2571  APFloat identity = useOnlyFiniteValue
2572  ? APFloat::getLargest(semantic, /*Negative=*/false)
2573  : APFloat::getInf(semantic, /*Negative=*/false);
2574 
2575  return builder.getFloatAttr(resultType, identity);
2576  }
2577  case AtomicRMWKind::minnumf: {
2578  const llvm::fltSemantics &semantic =
2579  llvm::cast<FloatType>(resultType).getFloatSemantics();
2580  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2581  return builder.getFloatAttr(resultType, identity);
2582  }
2583  case AtomicRMWKind::mins:
2584  return builder.getIntegerAttr(
2585  resultType, APInt::getSignedMaxValue(
2586  llvm::cast<IntegerType>(resultType).getWidth()));
2587  case AtomicRMWKind::minu:
2588  return builder.getIntegerAttr(
2589  resultType,
2590  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2591  case AtomicRMWKind::muli:
2592  return builder.getIntegerAttr(resultType, 1);
2593  case AtomicRMWKind::mulf:
2594  return builder.getFloatAttr(resultType, 1);
2595  // TODO: Add remaining reduction operations.
2596  default:
2597  (void)emitOptionalError(loc, "Reduction operation type not supported");
2598  break;
2599  }
2600  return nullptr;
2601 }
2602 
2603 /// Return the identity numeric value associated to the give op.
2604 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2605  std::optional<AtomicRMWKind> maybeKind =
2607  // Floating-point operations.
2608  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2609  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2610  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2611  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2612  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2613  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2614  // Integer operations.
2615  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2616  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2617  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2618  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2619  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2620  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2621  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2622  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2623  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2624  .Default([](Operation *op) { return std::nullopt; });
2625  if (!maybeKind) {
2626  return std::nullopt;
2627  }
2628 
2629  bool useOnlyFiniteValue = false;
2630  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2631  if (fmfOpInterface) {
2632  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2633  useOnlyFiniteValue =
2634  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2635  }
2636 
2637  // Builder only used as helper for attribute creation.
2638  OpBuilder b(op->getContext());
2639  Type resultType = op->getResult(0).getType();
2640 
2641  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2642  useOnlyFiniteValue);
2643 }
2644 
2645 /// Returns the identity value associated with an AtomicRMWKind op.
2646 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2647  OpBuilder &builder, Location loc,
2648  bool useOnlyFiniteValue) {
2649  auto attr =
2650  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2651  return builder.create<arith::ConstantOp>(loc, attr);
2652 }
2653 
2654 /// Return the value obtained by applying the reduction operation kind
2655 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2656 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2657  Location loc, Value lhs, Value rhs) {
2658  switch (op) {
2659  case AtomicRMWKind::addf:
2660  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2661  case AtomicRMWKind::addi:
2662  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2663  case AtomicRMWKind::mulf:
2664  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2665  case AtomicRMWKind::muli:
2666  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2667  case AtomicRMWKind::maximumf:
2668  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2669  case AtomicRMWKind::minimumf:
2670  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2671  case AtomicRMWKind::maxnumf:
2672  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2673  case AtomicRMWKind::minnumf:
2674  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2675  case AtomicRMWKind::maxs:
2676  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2677  case AtomicRMWKind::mins:
2678  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2679  case AtomicRMWKind::maxu:
2680  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2681  case AtomicRMWKind::minu:
2682  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2683  case AtomicRMWKind::ori:
2684  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2685  case AtomicRMWKind::andi:
2686  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2687  // TODO: Add remaining reduction operations.
2688  default:
2689  (void)emitOptionalError(loc, "Reduction operation type not supported");
2690  break;
2691  }
2692  return nullptr;
2693 }
2694 
2695 //===----------------------------------------------------------------------===//
2696 // TableGen'd op method definitions
2697 //===----------------------------------------------------------------------===//
2698 
2699 #define GET_OP_CLASSES
2700 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2701 
2702 //===----------------------------------------------------------------------===//
2703 // TableGen'd enum attribute definitions
2704 //===----------------------------------------------------------------------===//
2705 
2706 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:618
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1326
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:62
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:69
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:109
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1262
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1836
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition: ArithOps.cpp:579
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1276
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1248
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:678
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:660
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:142
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:52
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:150
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1341
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1688
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1586
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1298
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:57
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:130
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:859
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1269
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:171
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1854
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1284
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1242
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:43
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:340
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1312
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1197::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2045
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:2018
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:828
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:273
static bool classof(Operation *op)
Definition: ArithOps.cpp:279
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:285
static bool classof(Operation *op)
Definition: ArithOps.cpp:291
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:252
static bool classof(Operation *op)
Definition: ArithOps.cpp:267
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2604
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1808
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2537
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2656
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2646
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition: Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:462
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:427
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition: Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2313
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes