1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/UB/IR/UBOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Support/LogicalResult.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
43 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
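// Example for mergeOverflowFlags: merging (nsw | nuw) with (nsw) keeps only
// the flags guaranteed by both operations, i.e. (nsw).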
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
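// Example: invertPredicate(slt) returns sge, since !(a < b) is equivalent to
// a >= b under a signed comparison.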
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
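// Examples for getI1SameShape: f32 -> i1, vector<4xf32> -> vector<4xi1>,
// tensor<2x?xi64> -> tensor<2x?xi1>.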
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant from returning anything other than signless
207 /// integer or float-like values.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // Integer values must be signless.
211  if (llvm::isa<IntegerType>(type) &&
212  !llvm::cast<IntegerType>(type).isSignless())
213  return emitOpError("integer return type must be signless");
214  // Any integer, float, or elements attribute is acceptable.
215  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216  return emitOpError(
217  "value must be an integer, float, or elements attribute");
218  }
219 
220  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
221  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
222  // However, this would most likely require updating the lowerings to LLVM.
223  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224  return emitOpError(
225  "initializing scalable vectors with elements attribute is not supported"
226  " unless it's a vector splat");
227  return success();
228 }
229 
230 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
231  // The value's type must be the same as the provided type.
232  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233  if (!typedAttr || typedAttr.getType() != type)
234  return false;
235  // Integer values must be signless.
236  if (llvm::isa<IntegerType>(type) &&
237  !llvm::cast<IntegerType>(type).isSignless())
238  return false;
239  // Integer, float, and element attributes are buildable.
240  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
241 }
242 
243 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
244  Type type, Location loc) {
245  if (isBuildableWith(value, type))
246  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
247  return nullptr;
248 }
249 
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
251 
252 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
253  int64_t value, unsigned width) {
254  auto type = builder.getIntegerType(width);
255  arith::ConstantOp::build(builder, result, type,
256  builder.getIntegerAttr(type, value));
257 }
258 
259 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
260  int64_t value, Type type) {
261  assert(type.isSignlessInteger() &&
262  "ConstantIntOp can only have signless integer type values");
263  arith::ConstantOp::build(builder, result, type,
264  builder.getIntegerAttr(type, value));
265 }
266 
267 bool arith::ConstantIntOp::classof(Operation *op) {
268  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
269  return constOp.getType().isSignlessInteger();
270  return false;
271 }
272 
273 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
274  const APFloat &value, FloatType type) {
275  arith::ConstantOp::build(builder, result, type,
276  builder.getFloatAttr(type, value));
277 }
278 
279 bool arith::ConstantFloatOp::classof(Operation *op) {
280  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
281  return llvm::isa<FloatType>(constOp.getType());
282  return false;
283 }
284 
285 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
286  int64_t value) {
287  arith::ConstantOp::build(builder, result, builder.getIndexType(),
288  builder.getIndexAttr(value));
289 }
290 
291 bool arith::ConstantIndexOp::classof(Operation *op) {
292  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
293  return constOp.getType().isIndex();
294  return false;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // AddIOp
299 //===----------------------------------------------------------------------===//
300 
301 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
302  // addi(x, 0) -> x
303  if (matchPattern(adaptor.getRhs(), m_Zero()))
304  return getLhs();
305 
306  // addi(subi(a, b), b) -> a
307  if (auto sub = getLhs().getDefiningOp<SubIOp>())
308  if (getRhs() == sub.getRhs())
309  return sub.getLhs();
310 
311  // addi(b, subi(a, b)) -> a
312  if (auto sub = getRhs().getDefiningOp<SubIOp>())
313  if (getLhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  return constFoldBinaryOp<IntegerAttr>(
317  adaptor.getOperands(),
318  [](APInt a, const APInt &b) { return std::move(a) + b; });
319 }
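// Illustrative IR for the subi/addi cancellation above (i32 assumed):
//   %d = arith.subi %a, %b : i32
//   %r = arith.addi %d, %b : i32   // %r folds to %a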
320 
321 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
322  MLIRContext *context) {
323  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
324  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // AddUIExtendedOp
329 //===----------------------------------------------------------------------===//
330 
331 std::optional<SmallVector<int64_t, 4>>
332 arith::AddUIExtendedOp::getShapeForUnroll() {
333  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
334  return llvm::to_vector<4>(vt.getShape());
335  return std::nullopt;
336 }
337 
338 // Returns the overflow bit, assuming that `sum` is the result of unsigned
339 // addition of `operand` and another number.
340 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
341  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
342 }
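// Worked example on i8: 250 + 10 wraps to 4, and 4 u< 250, so the overflow
// bit is 1; without wrap-around the sum is always u>= either operand.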
343 
344 LogicalResult
345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
346  SmallVectorImpl<OpFoldResult> &results) {
347  Type overflowTy = getOverflow().getType();
348  // addui_extended(x, 0) -> x, false
349  if (matchPattern(getRhs(), m_Zero())) {
350  Builder builder(getContext());
351  auto falseValue = builder.getZeroAttr(overflowTy);
352 
353  results.push_back(getLhs());
354  results.push_back(falseValue);
355  return success();
356  }
357 
358  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
359  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
360  // operands. If that succeeds, calculate the overflow bit based on the sum
361  // and the first (constant) operand, `lhs`.
362  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
363  adaptor.getOperands(),
364  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
365  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
366  ArrayRef({sumAttr, adaptor.getLhs()}),
367  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
368  calculateUnsignedOverflow);
369  if (!overflowAttr)
370  return failure();
371 
372  results.push_back(sumAttr);
373  results.push_back(overflowAttr);
374  return success();
375  }
376 
377  return failure();
378 }
379 
380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
381  RewritePatternSet &patterns, MLIRContext *context) {
382  patterns.add<AddUIExtendedToAddI>(context);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // SubIOp
387 //===----------------------------------------------------------------------===//
388 
389 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
390  // subi(x,x) -> 0
391  if (getOperand(0) == getOperand(1)) {
392  auto shapedType = dyn_cast<ShapedType>(getType());
393  // We can't generate a constant with a dynamic shaped tensor.
394  if (!shapedType || shapedType.hasStaticShape())
395  return Builder(getContext()).getZeroAttr(getType());
396  }
397  // subi(x,0) -> x
398  if (matchPattern(adaptor.getRhs(), m_Zero()))
399  return getLhs();
400 
401  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
402  // subi(addi(a, b), b) -> a
403  if (getRhs() == add.getRhs())
404  return add.getLhs();
405  // subi(addi(a, b), a) -> b
406  if (getRhs() == add.getLhs())
407  return add.getRhs();
408  }
409 
410  return constFoldBinaryOp<IntegerAttr>(
411  adaptor.getOperands(),
412  [](APInt a, const APInt &b) { return std::move(a) - b; });
413 }
414 
415 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
416  MLIRContext *context) {
417  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
418  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
419  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // MulIOp
424 //===----------------------------------------------------------------------===//
425 
426 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
427  // muli(x, 0) -> 0
428  if (matchPattern(adaptor.getRhs(), m_Zero()))
429  return getRhs();
430  // muli(x, 1) -> x
431  if (matchPattern(adaptor.getRhs(), m_One()))
432  return getLhs();
433  // TODO: Handle the overflow case.
434 
435  // default folder
436  return constFoldBinaryOp<IntegerAttr>(
437  adaptor.getOperands(),
438  [](const APInt &a, const APInt &b) { return a * b; });
439 }
440 
441 void arith::MulIOp::getAsmResultNames(
442  function_ref<void(Value, StringRef)> setNameFn) {
443  if (!isa<IndexType>(getType()))
444  return;
445 
446  // Match vector.vscale by name to avoid depending on the vector dialect,
447  // which would create a circular dependency.
448  auto isVscale = [](Operation *op) {
449  return op && op->getName().getStringRef() == "vector.vscale";
450  };
451 
452  IntegerAttr baseValue;
453  auto isVscaleExpr = [&](Value a, Value b) {
454  return matchPattern(a, m_Constant(&baseValue)) &&
455  isVscale(b.getDefiningOp());
456  };
457 
458  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
459  return;
460 
461  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
462  SmallString<32> specialNameBuffer;
463  llvm::raw_svector_ostream specialName(specialNameBuffer);
464  specialName << 'c' << baseValue.getInt() << "_vscale";
465  setNameFn(getResult(), specialName.str());
466 }
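// Example: an index-typed product of the constant 8 and vector.vscale (in
// either operand order) gets the result name %c8_vscale.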
467 
468 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
469  MLIRContext *context) {
470  patterns.add<MulIMulIConstant>(context);
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // MulSIExtendedOp
475 //===----------------------------------------------------------------------===//
476 
477 std::optional<SmallVector<int64_t, 4>>
478 arith::MulSIExtendedOp::getShapeForUnroll() {
479  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
480  return llvm::to_vector<4>(vt.getShape());
481  return std::nullopt;
482 }
483 
484 LogicalResult
485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
486  SmallVectorImpl<OpFoldResult> &results) {
487  // mulsi_extended(x, 0) -> 0, 0
488  if (matchPattern(adaptor.getRhs(), m_Zero())) {
489  Attribute zero = adaptor.getRhs();
490  results.push_back(zero);
491  results.push_back(zero);
492  return success();
493  }
494 
495  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
496  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
497  adaptor.getOperands(),
498  [](const APInt &a, const APInt &b) { return a * b; })) {
499  // Invoke the constant fold helper again to calculate the 'high' result.
500  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
501  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
502  return llvm::APIntOps::mulhs(a, b);
503  });
504  assert(highAttr && "Unexpected constant-folding failure");
505 
506  results.push_back(lowAttr);
507  results.push_back(highAttr);
508  return success();
509  }
510 
511  return failure();
512 }
513 
514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
515  RewritePatternSet &patterns, MLIRContext *context) {
516  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // MulUIExtendedOp
521 //===----------------------------------------------------------------------===//
522 
523 std::optional<SmallVector<int64_t, 4>>
524 arith::MulUIExtendedOp::getShapeForUnroll() {
525  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
526  return llvm::to_vector<4>(vt.getShape());
527  return std::nullopt;
528 }
529 
530 LogicalResult
531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
532  SmallVectorImpl<OpFoldResult> &results) {
533  // mului_extended(x, 0) -> 0, 0
534  if (matchPattern(adaptor.getRhs(), m_Zero())) {
535  Attribute zero = adaptor.getRhs();
536  results.push_back(zero);
537  results.push_back(zero);
538  return success();
539  }
540 
541  // mului_extended(x, 1) -> x, 0
542  if (matchPattern(adaptor.getRhs(), m_One())) {
543  Builder builder(getContext());
544  Attribute zero = builder.getZeroAttr(getLhs().getType());
545  results.push_back(getLhs());
546  results.push_back(zero);
547  return success();
548  }
549 
550  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
551  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
552  adaptor.getOperands(),
553  [](const APInt &a, const APInt &b) { return a * b; })) {
554  // Invoke the constant fold helper again to calculate the 'high' result.
555  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
556  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
557  return llvm::APIntOps::mulhu(a, b);
558  });
559  assert(highAttr && "Unexpected constant-folding failure");
560 
561  results.push_back(lowAttr);
562  results.push_back(highAttr);
563  return success();
564  }
565 
566  return failure();
567 }
568 
569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
570  RewritePatternSet &patterns, MLIRContext *context) {
571  patterns.add<MulUIExtendedToMulI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // DivUIOp
576 //===----------------------------------------------------------------------===//
577 
578 /// Fold `(a * b) / b -> a`
579 static Value foldDivMul(Value lhs, Value rhs,
580  arith::IntegerOverflowFlags ovfFlags) {
581  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
582  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
583  return {};
584 
585  if (mul.getLhs() == rhs)
586  return mul.getRhs();
587 
588  if (mul.getRhs() == rhs)
589  return mul.getLhs();
590 
591  return {};
592 }
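// Illustrative IR for the unsigned case (i32 assumed):
//   %m = arith.muli %a, %b overflow<nuw> : i32
//   %q = arith.divui %m, %b : i32   // %q folds to %a
// The nuw/nsw flag is required so the product is known not to wrap.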
593 
594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
595  // divui (x, 1) -> x.
596  if (matchPattern(adaptor.getRhs(), m_One()))
597  return getLhs();
598 
599  // (a * b) / b -> a
600  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
601  return val;
602 
603  // Don't fold if it would require a division by zero.
604  bool div0 = false;
605  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
606  [&](APInt a, const APInt &b) {
607  if (div0 || !b) {
608  div0 = true;
609  return a;
610  }
611  return a.udiv(b);
612  });
613 
614  return div0 ? Attribute() : result;
615 }
616 
617 /// Returns whether an unsigned division by `divisor` is speculatable.
618 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
619  // X / 0 => UB
620  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
621  return Speculation::Speculatable;
622 
623  return Speculation::NotSpeculatable;
624 }
625 
626 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
627  return getDivUISpeculatability(getRhs());
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // DivSIOp
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
635  // divsi (x, 1) -> x.
636  if (matchPattern(adaptor.getRhs(), m_One()))
637  return getLhs();
638 
639  // (a * b) / b -> a
640  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
641  return val;
642 
643  // Don't fold if it would overflow or if it requires a division by zero.
644  bool overflowOrDiv0 = false;
645  auto result = constFoldBinaryOp<IntegerAttr>(
646  adaptor.getOperands(), [&](APInt a, const APInt &b) {
647  if (overflowOrDiv0 || !b) {
648  overflowOrDiv0 = true;
649  return a;
650  }
651  return a.sdiv_ov(b, overflowOrDiv0);
652  });
653 
654  return overflowOrDiv0 ? Attribute() : result;
655 }
656 
657 /// Returns whether a signed division by `divisor` is speculatable. This
658 /// function conservatively assumes that all signed division by -1 are not
659 /// speculatable.
660 static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
661  // X / 0 => UB
662  // INT_MIN / -1 => UB
663  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
664  matchPattern(divisor, m_IntRangeWithoutNegOneS()))
665  return Speculation::Speculatable;
666 
667  return Speculation::NotSpeculatable;
668 }
669 
670 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
671  return getDivSISpeculatability(getRhs());
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // Ceil and floor division folding helpers
676 //===----------------------------------------------------------------------===//
677 
678 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
679  bool &overflow) {
680  // Returns (a-1)/b + 1
681  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
682  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
683  return val.sadd_ov(one, overflow);
684 }
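// Worked example: a = 7, b = 3 gives (7 - 1) / 3 + 1 = 3 = ceil(7 / 3);
// a = 6, b = 3 gives (6 - 1) / 3 + 1 = 2, matching the exact quotient.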
685 
686 //===----------------------------------------------------------------------===//
687 // CeilDivUIOp
688 //===----------------------------------------------------------------------===//
689 
690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
691  // ceildivui (x, 1) -> x.
692  if (matchPattern(adaptor.getRhs(), m_One()))
693  return getLhs();
694 
695  bool overflowOrDiv0 = false;
696  auto result = constFoldBinaryOp<IntegerAttr>(
697  adaptor.getOperands(), [&](APInt a, const APInt &b) {
698  if (overflowOrDiv0 || !b) {
699  overflowOrDiv0 = true;
700  return a;
701  }
702  APInt quotient = a.udiv(b);
703  if (!a.urem(b))
704  return quotient;
705  APInt one(a.getBitWidth(), 1, true);
706  return quotient.uadd_ov(one, overflowOrDiv0);
707  });
708 
709  return overflowOrDiv0 ? Attribute() : result;
710 }
711 
712 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
713  return getDivUISpeculatability(getRhs());
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // CeilDivSIOp
718 //===----------------------------------------------------------------------===//
719 
720 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
721  // ceildivsi (x, 1) -> x.
722  if (matchPattern(adaptor.getRhs(), m_One()))
723  return getLhs();
724 
725  // Don't fold if it would overflow or if it requires a division by zero.
726  // TODO: This hook won't fold operations where a = MININT, because
727  // negating MININT overflows. This can be improved.
728  bool overflowOrDiv0 = false;
729  auto result = constFoldBinaryOp<IntegerAttr>(
730  adaptor.getOperands(), [&](APInt a, const APInt &b) {
731  if (overflowOrDiv0 || !b) {
732  overflowOrDiv0 = true;
733  return a;
734  }
735  if (!a)
736  return a;
737  // After this point we know that neither a nor b is zero.
738  unsigned bits = a.getBitWidth();
739  APInt zero = APInt::getZero(bits);
740  bool aGtZero = a.sgt(zero);
741  bool bGtZero = b.sgt(zero);
742  if (aGtZero && bGtZero) {
743  // Both positive, return ceil(a, b).
744  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
745  }
746 
747  // No folding happens if any of the intermediate arithmetic operations
748  // overflows.
749  bool overflowNegA = false;
750  bool overflowNegB = false;
751  bool overflowDiv = false;
752  bool overflowNegRes = false;
753  if (!aGtZero && !bGtZero) {
754  // Both negative, return ceil(-a, -b).
755  APInt posA = zero.ssub_ov(a, overflowNegA);
756  APInt posB = zero.ssub_ov(b, overflowNegB);
757  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
758  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
759  return res;
760  }
761  if (!aGtZero && bGtZero) {
762  // A is negative, b is positive, return - ( -a / b).
763  APInt posA = zero.ssub_ov(a, overflowNegA);
764  APInt div = posA.sdiv_ov(b, overflowDiv);
765  APInt res = zero.ssub_ov(div, overflowNegRes);
766  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
767  return res;
768  }
769  // A is positive, b is negative, return - (a / -b).
770  APInt posB = zero.ssub_ov(b, overflowNegB);
771  APInt div = a.sdiv_ov(posB, overflowDiv);
772  APInt res = zero.ssub_ov(div, overflowNegRes);
773 
774  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
775  return res;
776  });
777 
778  return overflowOrDiv0 ? Attribute() : result;
779 }
780 
781 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
782  return getDivSISpeculatability(getRhs());
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // FloorDivSIOp
787 //===----------------------------------------------------------------------===//
788 
789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
790  // floordivsi (x, 1) -> x.
791  if (matchPattern(adaptor.getRhs(), m_One()))
792  return getLhs();
793 
794  // Don't fold if it would overflow or if it requires a division by zero.
795  bool overflowOrDiv = false;
796  auto result = constFoldBinaryOp<IntegerAttr>(
797  adaptor.getOperands(), [&](APInt a, const APInt &b) {
798  if (b.isZero()) {
799  overflowOrDiv = true;
800  return a;
801  }
802  return a.sfloordiv_ov(b, overflowOrDiv);
803  });
804 
805  return overflowOrDiv ? Attribute() : result;
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // RemUIOp
810 //===----------------------------------------------------------------------===//
811 
812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
813  // remui (x, 1) -> 0.
814  if (matchPattern(adaptor.getRhs(), m_One()))
815  return Builder(getContext()).getZeroAttr(getType());
816 
817  // Don't fold if it would require a division by zero.
818  bool div0 = false;
819  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820  [&](APInt a, const APInt &b) {
821  if (div0 || b.isZero()) {
822  div0 = true;
823  return a;
824  }
825  return a.urem(b);
826  });
827 
828  return div0 ? Attribute() : result;
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // RemSIOp
833 //===----------------------------------------------------------------------===//
834 
835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
836  // remsi (x, 1) -> 0.
837  if (matchPattern(adaptor.getRhs(), m_One()))
838  return Builder(getContext()).getZeroAttr(getType());
839 
840  // Don't fold if it would require a division by zero.
841  bool div0 = false;
842  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
843  [&](APInt a, const APInt &b) {
844  if (div0 || b.isZero()) {
845  div0 = true;
846  return a;
847  }
848  return a.srem(b);
849  });
850 
851  return div0 ? Attribute() : result;
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // AndIOp
856 //===----------------------------------------------------------------------===//
857 
858 /// Fold `and(a, and(a, b))` to `and(a, b)`
859 static Value foldAndIofAndI(arith::AndIOp op) {
860  for (bool reversePrev : {false, true}) {
861  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862  .getDefiningOp<arith::AndIOp>();
863  if (!prev)
864  continue;
865 
866  Value other = (reversePrev ? op.getLhs() : op.getRhs());
867  if (other != prev.getLhs() && other != prev.getRhs())
868  continue;
869 
870  return prev.getResult();
871  }
872  return {};
873 }
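// Illustrative IR (i32 assumed):
//   %p = arith.andi %a, %b : i32
//   %r = arith.andi %a, %p : i32   // %r folds to %p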
874 
875 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
876  /// and(x, 0) -> 0
877  if (matchPattern(adaptor.getRhs(), m_Zero()))
878  return getRhs();
879  /// and(x, allOnes) -> x
880  APInt intValue;
881  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
882  intValue.isAllOnes())
883  return getLhs();
884  /// and(x, not(x)) -> 0
885  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
886  m_ConstantInt(&intValue))) &&
887  intValue.isAllOnes())
888  return Builder(getContext()).getZeroAttr(getType());
889  /// and(not(x), x) -> 0
890  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
891  m_ConstantInt(&intValue))) &&
892  intValue.isAllOnes())
893  return Builder(getContext()).getZeroAttr(getType());
894 
895  /// and(a, and(a, b)) -> and(a, b)
896  if (Value result = foldAndIofAndI(*this))
897  return result;
898 
899  return constFoldBinaryOp<IntegerAttr>(
900  adaptor.getOperands(),
901  [](APInt a, const APInt &b) { return std::move(a) & b; });
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // OrIOp
906 //===----------------------------------------------------------------------===//
907 
908 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
909  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
910  /// or(x, 0) -> x
911  if (rhsVal.isZero())
912  return getLhs();
913  /// or(x, <all ones>) -> <all ones>
914  if (rhsVal.isAllOnes())
915  return adaptor.getRhs();
916  }
917 
918  APInt intValue;
919  /// or(x, xor(x, 1)) -> 1
920  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
921  m_ConstantInt(&intValue))) &&
922  intValue.isAllOnes())
923  return getRhs().getDefiningOp<XOrIOp>().getRhs();
924  /// or(xor(x, 1), x) -> 1
925  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
926  m_ConstantInt(&intValue))) &&
927  intValue.isAllOnes())
928  return getLhs().getDefiningOp<XOrIOp>().getRhs();
929 
930  return constFoldBinaryOp<IntegerAttr>(
931  adaptor.getOperands(),
932  [](APInt a, const APInt &b) { return std::move(a) | b; });
933 }
934 
935 //===----------------------------------------------------------------------===//
936 // XOrIOp
937 //===----------------------------------------------------------------------===//
938 
939 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
940  /// xor(x, 0) -> x
941  if (matchPattern(adaptor.getRhs(), m_Zero()))
942  return getLhs();
943  /// xor(x, x) -> 0
944  if (getLhs() == getRhs())
945  return Builder(getContext()).getZeroAttr(getType());
946  /// xor(xor(x, a), a) -> x
947  /// xor(xor(a, x), a) -> x
948  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
949  if (prev.getRhs() == getRhs())
950  return prev.getLhs();
951  if (prev.getLhs() == getRhs())
952  return prev.getRhs();
953  }
954  /// xor(a, xor(x, a)) -> x
955  /// xor(a, xor(a, x)) -> x
956  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
957  if (prev.getRhs() == getLhs())
958  return prev.getLhs();
959  if (prev.getLhs() == getLhs())
960  return prev.getRhs();
961  }
962 
963  return constFoldBinaryOp<IntegerAttr>(
964  adaptor.getOperands(),
965  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
966 }
967 
968 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
969  MLIRContext *context) {
970  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // NegFOp
975 //===----------------------------------------------------------------------===//
976 
977 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
978  /// negf(negf(x)) -> x
979  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
980  return op.getOperand();
981  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
982  [](const APFloat &a) { return -a; });
983 }
984 
985 //===----------------------------------------------------------------------===//
986 // AddFOp
987 //===----------------------------------------------------------------------===//
988 
989 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
990  // addf(x, -0) -> x
991  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
992  return getLhs();
993 
994  return constFoldBinaryOp<FloatAttr>(
995  adaptor.getOperands(),
996  [](const APFloat &a, const APFloat &b) { return a + b; });
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // SubFOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1004  // subf(x, +0) -> x
1005  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1006  return getLhs();
1007 
1008  return constFoldBinaryOp<FloatAttr>(
1009  adaptor.getOperands(),
1010  [](const APFloat &a, const APFloat &b) { return a - b; });
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // MaximumFOp
1015 //===----------------------------------------------------------------------===//
1016 
1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1018  // maximumf(x,x) -> x
1019  if (getLhs() == getRhs())
1020  return getRhs();
1021 
1022  // maximumf(x, -inf) -> x
1023  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1024  return getLhs();
1025 
1026  return constFoldBinaryOp<FloatAttr>(
1027  adaptor.getOperands(),
1028  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // MaxNumFOp
1033 //===----------------------------------------------------------------------===//
1034 
1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1036  // maxnumf(x,x) -> x
1037  if (getLhs() == getRhs())
1038  return getRhs();
1039 
1040  // maxnumf(x, NaN) -> x
1041  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1042  return getLhs();
1043 
1044  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // MaxSIOp
1049 //===----------------------------------------------------------------------===//
1050 
1051 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1052  // maxsi(x,x) -> x
1053  if (getLhs() == getRhs())
1054  return getRhs();
1055 
1056  if (APInt intValue;
1057  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1058  // maxsi(x,MAX_INT) -> MAX_INT
1059  if (intValue.isMaxSignedValue())
1060  return getRhs();
1061  // maxsi(x, MIN_INT) -> x
1062  if (intValue.isMinSignedValue())
1063  return getLhs();
1064  }
1065 
1066  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1067  [](const APInt &a, const APInt &b) {
1068  return llvm::APIntOps::smax(a, b);
1069  });
1070 }
1071 
1072 //===----------------------------------------------------------------------===//
1073 // MaxUIOp
1074 //===----------------------------------------------------------------------===//
1075 
1076 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1077  // maxui(x,x) -> x
1078  if (getLhs() == getRhs())
1079  return getRhs();
1080 
1081  if (APInt intValue;
1082  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1083  // maxui(x,MAX_INT) -> MAX_INT
1084  if (intValue.isMaxValue())
1085  return getRhs();
1086  // maxui(x, MIN_INT) -> x
1087  if (intValue.isMinValue())
1088  return getLhs();
1089  }
1090 
1091  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1092  [](const APInt &a, const APInt &b) {
1093  return llvm::APIntOps::umax(a, b);
1094  });
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // MinimumFOp
1099 //===----------------------------------------------------------------------===//
1100 
1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1102  // minimumf(x,x) -> x
1103  if (getLhs() == getRhs())
1104  return getRhs();
1105 
1106  // minimumf(x, +inf) -> x
1107  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1108  return getLhs();
1109 
1110  return constFoldBinaryOp<FloatAttr>(
1111  adaptor.getOperands(),
1112  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1113 }
1114 
1115 //===----------------------------------------------------------------------===//
1116 // MinNumFOp
1117 //===----------------------------------------------------------------------===//
1118 
1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1120  // minnumf(x,x) -> x
1121  if (getLhs() == getRhs())
1122  return getRhs();
1123 
1124  // minnumf(x, NaN) -> x
1125  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1126  return getLhs();
1127 
1128  return constFoldBinaryOp<FloatAttr>(
1129  adaptor.getOperands(),
1130  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // MinSIOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1138  // minsi(x,x) -> x
1139  if (getLhs() == getRhs())
1140  return getRhs();
1141 
1142  if (APInt intValue;
1143  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1144  // minsi(x,MIN_INT) -> MIN_INT
1145  if (intValue.isMinSignedValue())
1146  return getRhs();
1147  // minsi(x, MAX_INT) -> x
1148  if (intValue.isMaxSignedValue())
1149  return getLhs();
1150  }
1151 
1152  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1153  [](const APInt &a, const APInt &b) {
1154  return llvm::APIntOps::smin(a, b);
1155  });
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // MinUIOp
1160 //===----------------------------------------------------------------------===//
1161 
1162 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1163  // minui(x,x) -> x
1164  if (getLhs() == getRhs())
1165  return getRhs();
1166 
1167  if (APInt intValue;
1168  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1169  // minui(x,MIN_INT) -> MIN_INT
1170  if (intValue.isMinValue())
1171  return getRhs();
1172  // minui(x, MAX_INT) -> x
1173  if (intValue.isMaxValue())
1174  return getLhs();
1175  }
1176 
1177  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1178  [](const APInt &a, const APInt &b) {
1179  return llvm::APIntOps::umin(a, b);
1180  });
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // MulFOp
1185 //===----------------------------------------------------------------------===//
1186 
1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1188  // mulf(x, 1) -> x
1189  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1190  return getLhs();
1191 
1192  return constFoldBinaryOp<FloatAttr>(
1193  adaptor.getOperands(),
1194  [](const APFloat &a, const APFloat &b) { return a * b; });
1195 }
1196 
1197 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1198  MLIRContext *context) {
1199  patterns.add<MulFOfNegF>(context);
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // DivFOp
1204 //===----------------------------------------------------------------------===//
1205 
1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1207  // divf(x, 1) -> x
1208  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1209  return getLhs();
1210 
1211  return constFoldBinaryOp<FloatAttr>(
1212  adaptor.getOperands(),
1213  [](const APFloat &a, const APFloat &b) { return a / b; });
1214 }
1215 
1216 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1217  MLIRContext *context) {
1218  patterns.add<DivFOfNegF>(context);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // RemFOp
1223 //===----------------------------------------------------------------------===//
1224 
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1226  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1227  [](const APFloat &a, const APFloat &b) {
1228  APFloat result(a);
1229  // APFloat::mod() offers the remainder
1230  // behavior we want, i.e. the result has
1231  // the sign of LHS operand.
1232  (void)result.mod(b);
1233  return result;
1234  });
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // Utility functions for verifying cast ops
1239 //===----------------------------------------------------------------------===//
1240 
1241 template <typename... Types>
1242 using type_list = std::tuple<Types...> *;
1243 
1244 /// Returns a non-null type only if the provided type is one of the allowed
1245 /// types or one of the allowed shaped types of the allowed types. Returns the
1246 /// element type if a valid shaped type is provided.
1247 template <typename... ShapedTypes, typename... ElementTypes>
1248 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1249  type_list<ElementTypes...>) {
1250  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1251  return {};
1252 
1253  auto underlyingType = getElementTypeOrSelf(type);
1254  if (!llvm::isa<ElementTypes...>(underlyingType))
1255  return {};
1256 
1257  return underlyingType;
1258 }
1259 
1260 /// Get allowed underlying types for vectors and tensors.
1261 template <typename... ElementTypes>
1262 static Type getTypeIfLike(Type type) {
1263  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1264  type_list<ElementTypes...>());
1265 }
1266 
1267 /// Get allowed underlying types for vectors, tensors, and memrefs.
1268 template <typename... ElementTypes>
1269 static Type getTypeIfLikeOrMemRef(Type type) {
1270  return getUnderlyingType(type,
1271  type_list<VectorType, TensorType, MemRefType>(),
1272  type_list<ElementTypes...>());
1273 }
1274 
1275 /// Return false if both types are ranked tensor with mismatching encoding.
1276 static bool hasSameEncoding(Type typeA, Type typeB) {
1277  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1278  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1279  if (!rankedTensorA || !rankedTensorB)
1280  return true;
1281  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1282 }
1283 
1284 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1285  if (inputs.size() != 1 || outputs.size() != 1)
1286  return false;
1287  if (!hasSameEncoding(inputs.front(), outputs.front()))
1288  return false;
1289  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // Verifiers for integer and floating point extension/truncation ops
1294 //===----------------------------------------------------------------------===//
1295 
1296 // Extend ops can only extend to a wider type.
1297 template <typename ValType, typename Op>
1298 static LogicalResult verifyExtOp(Op op) {
1299  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1300  Type dstType = getElementTypeOrSelf(op.getType());
1301 
1302  if (llvm::cast<ValType>(srcType).getWidth() >=
1303  llvm::cast<ValType>(dstType).getWidth())
1304  return op.emitError("result type ")
1305  << dstType << " must be wider than operand type " << srcType;
1306 
1307  return success();
1308 }
1309 
1310 // Truncate ops can only truncate to a shorter type.
1311 template <typename ValType, typename Op>
1312 static LogicalResult verifyTruncateOp(Op op) {
1313  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1314  Type dstType = getElementTypeOrSelf(op.getType());
1315 
1316  if (llvm::cast<ValType>(srcType).getWidth() <=
1317  llvm::cast<ValType>(dstType).getWidth())
1318  return op.emitError("result type ")
1319  << dstType << " must be shorter than operand type " << srcType;
1320 
1321  return success();
1322 }
1323 
1324 /// Validate a cast that changes the width of a type.
1325 template <template <typename> class WidthComparator, typename... ElementTypes>
1326 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1327  if (!areValidCastInputsAndOutputs(inputs, outputs))
1328  return false;
1329 
1330  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1331  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1332  if (!srcType || !dstType)
1333  return false;
1334 
1335  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1336  srcType.getIntOrFloatBitWidth());
1337 }
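// For example, the extension ops instantiate this with std::greater (the
// destination element type must be strictly wider), while the truncation ops
// use std::less.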
1338 
1339 /// Attempts to convert `sourceValue` to an APFloat value with
1340 /// `targetSemantics` and `roundingMode`, without any information loss.
1341 static FailureOr<APFloat> convertFloatValue(
1342  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1343  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344  bool losesInfo = false;
1345  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346  if (losesInfo || status != APFloat::opOK)
1347  return failure();
1348 
1349  return sourceValue;
1350 }
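// Example: converting the f64 value 0.5 to f32 is exact, so the cast can be
// folded; converting the f64 value 0.1 to f32 loses information, so the fold
// is rejected.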
1351 
1352 //===----------------------------------------------------------------------===//
1353 // ExtUIOp
1354 //===----------------------------------------------------------------------===//
1355 
1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1357  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1358  getInMutable().assign(lhs.getIn());
1359  return getResult();
1360  }
1361 
1362  Type resType = getElementTypeOrSelf(getType());
1363  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1364  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1365  adaptor.getOperands(), getType(),
1366  [bitWidth](const APInt &a, bool &castStatus) {
1367  return a.zext(bitWidth);
1368  });
1369 }
1370 
1371 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1372  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1373 }
1374 
1375 LogicalResult arith::ExtUIOp::verify() {
1376  return verifyExtOp<IntegerType>(*this);
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // ExtSIOp
1381 //===----------------------------------------------------------------------===//
1382 
1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1384  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1385  getInMutable().assign(lhs.getIn());
1386  return getResult();
1387  }
1388 
1389  Type resType = getElementTypeOrSelf(getType());
1390  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1391  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1392  adaptor.getOperands(), getType(),
1393  [bitWidth](const APInt &a, bool &castStatus) {
1394  return a.sext(bitWidth);
1395  });
1396 }
1397 
1398 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1399  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1400 }
1401 
1402 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1403  MLIRContext *context) {
1404  patterns.add<ExtSIOfExtUI>(context);
1405 }
1406 
1407 LogicalResult arith::ExtSIOp::verify() {
1408  return verifyExtOp<IntegerType>(*this);
1409 }
1410 
1411 //===----------------------------------------------------------------------===//
1412 // ExtFOp
1413 //===----------------------------------------------------------------------===//
1414 
1415 /// Fold extension of float constants when there is no information loss due to the
1416 /// difference in fp semantics.
1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1418  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1419  if (truncFOp.getOperand().getType() == getType()) {
1420  arith::FastMathFlags truncFMF =
1421  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1422  bool isTruncContract =
1423  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1424  arith::FastMathFlags extFMF =
1425  getFastmath().value_or(arith::FastMathFlags::none);
1426  bool isExtContract =
1427  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1428  if (isTruncContract && isExtContract) {
1429  return truncFOp.getOperand();
1430  }
1431  }
1432  }
1433 
1434  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1435  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1436  return constFoldCastOp<FloatAttr, FloatAttr>(
1437  adaptor.getOperands(), getType(),
1438  [&targetSemantics](const APFloat &a, bool &castStatus) {
1439  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1440  if (failed(result)) {
1441  castStatus = false;
1442  return a;
1443  }
1444  return *result;
1445  });
1446 }
1447 
1448 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1449  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1450 }
1451 
1452 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1453 
1454 //===----------------------------------------------------------------------===//
1455 // TruncIOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1459  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1460  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1461  Value src = getOperand().getDefiningOp()->getOperand(0);
1462  Type srcType = getElementTypeOrSelf(src.getType());
1463  Type dstType = getElementTypeOrSelf(getType());
1464  // trunci(zexti(a)) -> trunci(a)
1465  // trunci(sexti(a)) -> trunci(a)
1466  if (llvm::cast<IntegerType>(srcType).getWidth() >
1467  llvm::cast<IntegerType>(dstType).getWidth()) {
1468  setOperand(src);
1469  return getResult();
1470  }
1471 
1472  // trunci(zexti(a)) -> a
1473  // trunci(sexti(a)) -> a
1474  if (srcType == dstType)
1475  return src;
1476  }
1477 
1478  // trunci(trunci(a)) -> trunci(a)
1479  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1480  setOperand(getOperand().getDefiningOp()->getOperand(0));
1481  return getResult();
1482  }
1483 
1484  Type resType = getElementTypeOrSelf(getType());
1485  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1486  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1487  adaptor.getOperands(), getType(),
1488  [bitWidth](const APInt &a, bool &castStatus) {
1489  return a.trunc(bitWidth);
1490  });
1491 }
1492 
1493 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1494  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1495 }
1496 
1497 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1498  MLIRContext *context) {
1499  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1500  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1501  context);
1502 }
1503 
1504 LogicalResult arith::TruncIOp::verify() {
1505  return verifyTruncateOp<IntegerType>(*this);
1506 }
1507 
1508 //===----------------------------------------------------------------------===//
1509 // TruncFOp
1510 //===----------------------------------------------------------------------===//
1511 
1512 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1513 /// can be represented without precision loss.
1514 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1515  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1516  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1517  Value src = extOp.getIn();
1518  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1519  auto intermediateType =
1520  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1521  // Check if the srcType is representable in the intermediateType.
1522  if (llvm::APFloatBase::isRepresentableBy(
1523  srcType.getFloatSemantics(),
1524  intermediateType.getFloatSemantics())) {
1525  // truncf(extf(a)) -> truncf(a)
1526  if (srcType.getWidth() > resElemType.getWidth()) {
1527  setOperand(src);
1528  return getResult();
1529  }
1530 
1531  // truncf(extf(a)) -> a
1532  if (srcType == resElemType)
1533  return src;
1534  }
1535  }
1536 
1537  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1538  return constFoldCastOp<FloatAttr, FloatAttr>(
1539  adaptor.getOperands(), getType(),
1540  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1541  RoundingMode roundingMode =
1542  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1543  llvm::RoundingMode llvmRoundingMode =
1544  convertArithRoundingModeToLLVMIR(roundingMode);
1545  FailureOr<APFloat> result =
1546  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1547  if (failed(result)) {
1548  castStatus = false;
1549  return a;
1550  }
1551  return *result;
1552  });
1553 }
1554 
1555 void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1556  MLIRContext *context) {
1557  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1558 }
1559 
1560 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1561  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1562 }
1563 
1564 LogicalResult arith::TruncFOp::verify() {
1565  return verifyTruncateOp<FloatType>(*this);
1566 }
1567 
1568 //===----------------------------------------------------------------------===//
1569 // AndIOp
1570 //===----------------------------------------------------------------------===//
1571 
1572 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1573  MLIRContext *context) {
1574  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1575 }
1576 
1577 //===----------------------------------------------------------------------===//
1578 // OrIOp
1579 //===----------------------------------------------------------------------===//
1580 
1581 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1582  MLIRContext *context) {
1583  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1584 }
1585 
1586 //===----------------------------------------------------------------------===//
1587 // Verifiers for casts between integers and floats.
1588 //===----------------------------------------------------------------------===//
1589 
1590 template <typename From, typename To>
1591 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1592  if (!areValidCastInputsAndOutputs(inputs, outputs))
1593  return false;
1594 
1595  auto srcType = getTypeIfLike<From>(inputs.front());
1596  auto dstType = getTypeIfLike<To>(outputs.back());
1597 
1598  return srcType && dstType;
1599 }
1600 
1601 //===----------------------------------------------------------------------===//
1602 // UIToFPOp
1603 //===----------------------------------------------------------------------===//
1604 
1605 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1606  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1607 }
1608 
1609 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1610  Type resEleType = getElementTypeOrSelf(getType());
1611  return constFoldCastOp<IntegerAttr, FloatAttr>(
1612  adaptor.getOperands(), getType(),
1613  [&resEleType](const APInt &a, bool &castStatus) {
1614  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1615  APFloat apf(floatTy.getFloatSemantics(),
1616  APInt::getZero(floatTy.getWidth()));
1617  apf.convertFromAPInt(a, /*IsSigned=*/false,
1618  APFloat::rmNearestTiesToEven);
1619  return apf;
1620  });
1621 }
1622 
1623 //===----------------------------------------------------------------------===//
1624 // SIToFPOp
1625 //===----------------------------------------------------------------------===//
1626 
1627 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1628  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1629 }
1630 
1631 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1632  Type resEleType = getElementTypeOrSelf(getType());
1633  return constFoldCastOp<IntegerAttr, FloatAttr>(
1634  adaptor.getOperands(), getType(),
1635  [&resEleType](const APInt &a, bool &castStatus) {
1636  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1637  APFloat apf(floatTy.getFloatSemantics(),
1638  APInt::getZero(floatTy.getWidth()));
1639  apf.convertFromAPInt(a, /*IsSigned=*/true,
1640  APFloat::rmNearestTiesToEven);
1641  return apf;
1642  });
1643 }
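As a standalone illustration (not part of this file; helper name invented) of why the uitofp and sitofp folds above differ only in the IsSigned flag passed to convertFromAPInt: the same bit pattern converts to different floats depending on signedness.

// Sketch: 0xFF is 255 when treated as unsigned, -1 when treated as signed.
static void intToFPSketch() {
  llvm::APInt bits(8, 0xFF);
  llvm::APFloat asUnsigned(llvm::APFloat::IEEEsingle(),
                           llvm::APInt::getZero(32));
  llvm::APFloat asSigned = asUnsigned;
  (void)asUnsigned.convertFromAPInt(bits, /*IsSigned=*/false,
                                    llvm::APFloat::rmNearestTiesToEven);
  (void)asSigned.convertFromAPInt(bits, /*IsSigned=*/true,
                                  llvm::APFloat::rmNearestTiesToEven);
  assert(asUnsigned.convertToFloat() == 255.0f); // arith.uitofp folding
  assert(asSigned.convertToFloat() == -1.0f);    // arith.sitofp folding
}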
1644 
1645 //===----------------------------------------------------------------------===//
1646 // FPToUIOp
1647 //===----------------------------------------------------------------------===//
1648 
1649 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1650  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1651 }
1652 
1653 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1654  Type resType = getElementTypeOrSelf(getType());
1655  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1656  return constFoldCastOp<FloatAttr, IntegerAttr>(
1657  adaptor.getOperands(), getType(),
1658  [&bitWidth](const APFloat &a, bool &castStatus) {
1659  bool ignored;
1660  APSInt api(bitWidth, /*isUnsigned=*/true);
1661  castStatus = APFloat::opInvalidOp !=
1662  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1663  return api;
1664  });
1665 }
1666 
1667 //===----------------------------------------------------------------------===//
1668 // FPToSIOp
1669 //===----------------------------------------------------------------------===//
1670 
1671 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1672  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1673 }
1674 
1675 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1676  Type resType = getElementTypeOrSelf(getType());
1677  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1678  return constFoldCastOp<FloatAttr, IntegerAttr>(
1679  adaptor.getOperands(), getType(),
1680  [&bitWidth](const APFloat &a, bool &castStatus) {
1681  bool ignored;
1682  APSInt api(bitWidth, /*isUnsigned=*/false);
1683  castStatus = APFloat::opInvalidOp !=
1684  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1685  return api;
1686  });
1687 }
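A small standalone sketch (not part of this file; helper name invented) of the opInvalidOp guard shared by both folds above: convertToInteger signals an invalid operation when the float is out of range for the destination integer, and in that case castStatus is cleared so no constant is produced.

static void fpToIntSketch() {
  bool isExact = false;
  llvm::APSInt api(8, /*isUnsigned=*/true);
  llvm::APFloat negOne(-1.0f);
  llvm::APFloat::opStatus status =
      negOne.convertToInteger(api, llvm::APFloat::rmTowardZero, &isExact);
  // -1.0 has no u8 representation, so the fold would bail out here.
  assert(status == llvm::APFloat::opInvalidOp);
}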
1688 
1689 //===----------------------------------------------------------------------===//
1690 // IndexCastOp
1691 //===----------------------------------------------------------------------===//
1692 
1693 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1694  if (!areValidCastInputsAndOutputs(inputs, outputs))
1695  return false;
1696 
1697  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1698  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1699  if (!srcType || !dstType)
1700  return false;
1701 
1702  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1703  (srcType.isSignlessInteger() && dstType.isIndex());
1704 }
1705 
1706 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1707  TypeRange outputs) {
1708  return areIndexCastCompatible(inputs, outputs);
1709 }
1710 
1711 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1712  // index_cast(constant) -> constant
1713  unsigned resultBitwidth = 64; // Default for index integer attributes.
1714  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1715  resultBitwidth = intTy.getWidth();
1716 
1717  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1718  adaptor.getOperands(), getType(),
1719  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1720  return a.sextOrTrunc(resultBitwidth);
1721  });
1722 }
1723 
1724 void arith::IndexCastOp::getCanonicalizationPatterns(
1725  RewritePatternSet &patterns, MLIRContext *context) {
1726  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1727 }
1728 
1729 //===----------------------------------------------------------------------===//
1730 // IndexCastUIOp
1731 //===----------------------------------------------------------------------===//
1732 
1733 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1734  TypeRange outputs) {
1735  return areIndexCastCompatible(inputs, outputs);
1736 }
1737 
1738 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1739  // index_castui(constant) -> constant
1740  unsigned resultBitwidth = 64; // Default for index integer attributes.
1741  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1742  resultBitwidth = intTy.getWidth();
1743 
1744  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1745  adaptor.getOperands(), getType(),
1746  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1747  return a.zextOrTrunc(resultBitwidth);
1748  });
1749 }
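For illustration (standalone sketch, not part of this file; helper name invented): the only difference between the index_cast and index_castui folds is sign- versus zero-extension, which changes the result once the value is negative under a signed interpretation.

static void indexCastSketch() {
  llvm::APInt minusSeven(32, -7, /*isSigned=*/true); // bit pattern 0xFFFFFFF9
  // arith.index_cast sign-extends: the value is still -7 at 64 bits.
  assert(minusSeven.sextOrTrunc(64).getSExtValue() == -7);
  // arith.index_castui zero-extends: the value becomes 4294967289.
  assert(minusSeven.zextOrTrunc(64).getZExtValue() == 0xFFFFFFF9u);
}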
1750 
1751 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1752  RewritePatternSet &patterns, MLIRContext *context) {
1753  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // BitcastOp
1758 //===----------------------------------------------------------------------===//
1759 
1760 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1761  if (!areValidCastInputsAndOutputs(inputs, outputs))
1762  return false;
1763 
1764  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1765  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1766  if (!srcType || !dstType)
1767  return false;
1768 
1769  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1770 }
1771 
1772 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1773  auto resType = getType();
1774  auto operand = adaptor.getIn();
1775  if (!operand)
1776  return {};
1777 
1778  /// Bitcast dense elements.
1779  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1780  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1781  /// Other shaped types unhandled.
1782  if (llvm::isa<ShapedType>(resType))
1783  return {};
1784 
1785  /// Bitcast poison.
1786  if (llvm::isa<ub::PoisonAttr>(operand))
1787  return ub::PoisonAttr::get(getContext());
1788 
1789  /// Bitcast integer or float to integer or float.
1790  APInt bits = llvm::isa<FloatAttr>(operand)
1791  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1792  : llvm::cast<IntegerAttr>(operand).getValue();
1793  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1794  "trying to fold on broken IR: operands have incompatible types");
1795 
1796  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1797  return FloatAttr::get(resType,
1798  APFloat(resFloatType.getFloatSemantics(), bits));
1799  return IntegerAttr::get(resType, bits);
1800 }
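As a standalone illustration (not part of this file; helper name invented) of the scalar case folded above: a bitcast between a float and an integer of the same width is a pure reinterpretation of bits and round-trips exactly.

static void bitcastSketch() {
  llvm::APFloat one(1.0f);
  llvm::APInt bits = one.bitcastToAPInt();
  assert(bits.getZExtValue() == 0x3F800000u); // IEEE-754 encoding of 1.0f
  llvm::APFloat roundTripped(llvm::APFloat::IEEEsingle(), bits);
  assert(roundTripped.compare(one) == llvm::APFloat::cmpEqual);
}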
1801 
1802 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1803  MLIRContext *context) {
1804  patterns.add<BitcastOfBitcast>(context);
1805 }
1806 
1807 //===----------------------------------------------------------------------===//
1808 // CmpIOp
1809 //===----------------------------------------------------------------------===//
1810 
1811 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1812 /// comparison predicates.
1813 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1814  const APInt &lhs, const APInt &rhs) {
1815  switch (predicate) {
1816  case arith::CmpIPredicate::eq:
1817  return lhs.eq(rhs);
1818  case arith::CmpIPredicate::ne:
1819  return lhs.ne(rhs);
1820  case arith::CmpIPredicate::slt:
1821  return lhs.slt(rhs);
1822  case arith::CmpIPredicate::sle:
1823  return lhs.sle(rhs);
1824  case arith::CmpIPredicate::sgt:
1825  return lhs.sgt(rhs);
1826  case arith::CmpIPredicate::sge:
1827  return lhs.sge(rhs);
1828  case arith::CmpIPredicate::ult:
1829  return lhs.ult(rhs);
1830  case arith::CmpIPredicate::ule:
1831  return lhs.ule(rhs);
1832  case arith::CmpIPredicate::ugt:
1833  return lhs.ugt(rhs);
1834  case arith::CmpIPredicate::uge:
1835  return lhs.uge(rhs);
1836  }
1837  llvm_unreachable("unknown cmpi predicate kind");
1838 }
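For illustration only (standalone sketch, not part of this file; helper name invented): the signed and unsigned predicate families interpret the same bits differently, which is why the switch above carries both.

static void cmpiPredicateSketch() {
  llvm::APInt x(8, 200), y(8, 100);
  // Unsigned view: 200 > 100.
  assert(applyCmpPredicate(arith::CmpIPredicate::ugt, x, y));
  // Signed view: 0xC8 is -56, so x also compares less than y.
  assert(applyCmpPredicate(arith::CmpIPredicate::slt, x, y));
}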
1839 
1840 /// Returns true if the predicate is true for two equal operands.
1841 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1842  switch (predicate) {
1843  case arith::CmpIPredicate::eq:
1844  case arith::CmpIPredicate::sle:
1845  case arith::CmpIPredicate::sge:
1846  case arith::CmpIPredicate::ule:
1847  case arith::CmpIPredicate::uge:
1848  return true;
1849  case arith::CmpIPredicate::ne:
1850  case arith::CmpIPredicate::slt:
1851  case arith::CmpIPredicate::sgt:
1852  case arith::CmpIPredicate::ult:
1853  case arith::CmpIPredicate::ugt:
1854  return false;
1855  }
1856  llvm_unreachable("unknown cmpi predicate kind");
1857 }
1858 
1859 static std::optional<int64_t> getIntegerWidth(Type t) {
1860  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1861  return intType.getWidth();
1862  }
1863  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1864  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1865  }
1866  return std::nullopt;
1867 }
1868 
1869 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1870  // cmpi(pred, x, x)
1871  if (getLhs() == getRhs()) {
1872  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1873  return getBoolAttribute(getType(), val);
1874  }
1875 
1876  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1877  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1878  // extsi(%x : i1 -> iN) != 0 -> %x
1879  std::optional<int64_t> integerWidth =
1880  getIntegerWidth(extOp.getOperand().getType());
1881  if (integerWidth && integerWidth.value() == 1 &&
1882  getPredicate() == arith::CmpIPredicate::ne)
1883  return extOp.getOperand();
1884  }
1885  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1886  // extui(%x : i1 -> iN) != 0 -> %x
1887  std::optional<int64_t> integerWidth =
1888  getIntegerWidth(extOp.getOperand().getType());
1889  if (integerWidth && integerWidth.value() == 1 &&
1890  getPredicate() == arith::CmpIPredicate::ne)
1891  return extOp.getOperand();
1892  }
1893 
1894  // arith.cmpi ne, %val, %zero : i1 -> %val
1895  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1896  getPredicate() == arith::CmpIPredicate::ne)
1897  return getLhs();
1898  }
1899 
1900  if (matchPattern(adaptor.getRhs(), m_One())) {
1901  // arith.cmpi eq, %val, %one : i1 -> %val
1902  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1903  getPredicate() == arith::CmpIPredicate::eq)
1904  return getLhs();
1905  }
1906 
1907  // Move constant to the right side.
1908  if (adaptor.getLhs() && !adaptor.getRhs()) {
1909  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1910  using Pred = CmpIPredicate;
1911  const std::pair<Pred, Pred> invPreds[] = {
1912  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1913  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1914  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1915  {Pred::ne, Pred::ne},
1916  };
1917  Pred origPred = getPredicate();
1918  for (auto pred : invPreds) {
1919  if (origPred == pred.first) {
1920  setPredicate(pred.second);
1921  Value lhs = getLhs();
1922  Value rhs = getRhs();
1923  getLhsMutable().assign(rhs);
1924  getRhsMutable().assign(lhs);
1925  return getResult();
1926  }
1927  }
1928  llvm_unreachable("unknown cmpi predicate kind");
1929  }
1930 
1931  // We are moving constants to the right side, so if the lhs is a constant,
1932  // the rhs is guaranteed to be a constant as well.
1933  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1934  return constFoldBinaryOp<IntegerAttr>(
1935  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1936  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1937  return APInt(1,
1938  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1939  });
1940  }
1941 
1942  return {};
1943 }
1944 
1945 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1946  MLIRContext *context) {
1947  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1948 }
1949 
1950 //===----------------------------------------------------------------------===//
1951 // CmpFOp
1952 //===----------------------------------------------------------------------===//
1953 
1954 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1955 /// comparison predicates.
1956 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1957  const APFloat &lhs, const APFloat &rhs) {
1958  auto cmpResult = lhs.compare(rhs);
1959  switch (predicate) {
1960  case arith::CmpFPredicate::AlwaysFalse:
1961  return false;
1962  case arith::CmpFPredicate::OEQ:
1963  return cmpResult == APFloat::cmpEqual;
1964  case arith::CmpFPredicate::OGT:
1965  return cmpResult == APFloat::cmpGreaterThan;
1966  case arith::CmpFPredicate::OGE:
1967  return cmpResult == APFloat::cmpGreaterThan ||
1968  cmpResult == APFloat::cmpEqual;
1969  case arith::CmpFPredicate::OLT:
1970  return cmpResult == APFloat::cmpLessThan;
1971  case arith::CmpFPredicate::OLE:
1972  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1973  case arith::CmpFPredicate::ONE:
1974  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1975  case arith::CmpFPredicate::ORD:
1976  return cmpResult != APFloat::cmpUnordered;
1977  case arith::CmpFPredicate::UEQ:
1978  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1979  case arith::CmpFPredicate::UGT:
1980  return cmpResult == APFloat::cmpUnordered ||
1981  cmpResult == APFloat::cmpGreaterThan;
1982  case arith::CmpFPredicate::UGE:
1983  return cmpResult == APFloat::cmpUnordered ||
1984  cmpResult == APFloat::cmpGreaterThan ||
1985  cmpResult == APFloat::cmpEqual;
1986  case arith::CmpFPredicate::ULT:
1987  return cmpResult == APFloat::cmpUnordered ||
1988  cmpResult == APFloat::cmpLessThan;
1989  case arith::CmpFPredicate::ULE:
1990  return cmpResult == APFloat::cmpUnordered ||
1991  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1992  case arith::CmpFPredicate::UNE:
1993  return cmpResult != APFloat::cmpEqual;
1994  case arith::CmpFPredicate::UNO:
1995  return cmpResult == APFloat::cmpUnordered;
1996  case arith::CmpFPredicate::AlwaysTrue:
1997  return true;
1998  }
1999  llvm_unreachable("unknown cmpf predicate kind");
2000 }
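A standalone sketch (not part of this file; helper name invented) of the ordered/unordered split above: any comparison involving NaN yields cmpUnordered, so ordered predicates reject such operands while their unordered counterparts accept them.

static void cmpfNaNSketch() {
  llvm::APFloat qnan = llvm::APFloat::getNaN(llvm::APFloat::IEEEsingle());
  llvm::APFloat one(1.0f);
  assert(qnan.compare(one) == llvm::APFloat::cmpUnordered);
  assert(!applyCmpPredicate(arith::CmpFPredicate::OEQ, qnan, one));
  assert(applyCmpPredicate(arith::CmpFPredicate::UNE, qnan, one));
}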
2001 
2002 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2003  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2004  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2005 
2006  // If one operand is NaN, making them both NaN does not change the result.
2007  if (lhs && lhs.getValue().isNaN())
2008  rhs = lhs;
2009  if (rhs && rhs.getValue().isNaN())
2010  lhs = rhs;
2011 
2012  if (!lhs || !rhs)
2013  return {};
2014 
2015  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2016  return BoolAttr::get(getContext(), val);
2017 }
2018 
2019 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2020 public:
2021  using OpRewritePattern<CmpFOp>::OpRewritePattern;
2022 
2023  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2024  bool isUnsigned) {
2025  using namespace arith;
2026  switch (pred) {
2027  case CmpFPredicate::UEQ:
2028  case CmpFPredicate::OEQ:
2029  return CmpIPredicate::eq;
2030  case CmpFPredicate::UGT:
2031  case CmpFPredicate::OGT:
2032  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2033  case CmpFPredicate::UGE:
2034  case CmpFPredicate::OGE:
2035  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2036  case CmpFPredicate::ULT:
2037  case CmpFPredicate::OLT:
2038  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2039  case CmpFPredicate::ULE:
2040  case CmpFPredicate::OLE:
2041  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2042  case CmpFPredicate::UNE:
2043  case CmpFPredicate::ONE:
2044  return CmpIPredicate::ne;
2045  default:
2046  llvm_unreachable("Unexpected predicate!");
2047  }
2048  }
2049 
2050  LogicalResult matchAndRewrite(CmpFOp op,
2051  PatternRewriter &rewriter) const override {
2052  FloatAttr flt;
2053  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2054  return failure();
2055 
2056  const APFloat &rhs = flt.getValue();
2057 
2058  // Don't attempt to fold a nan.
2059  if (rhs.isNaN())
2060  return failure();
2061 
2062  // Get the width of the mantissa. We don't want to hack on conversions that
2063  // might lose information from the integer, e.g. "i64 -> float"
2064  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2065  int mantissaWidth = floatTy.getFPMantissaWidth();
2066  if (mantissaWidth <= 0)
2067  return failure();
2068 
2069  bool isUnsigned;
2070  Value intVal;
2071 
2072  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2073  isUnsigned = false;
2074  intVal = si.getIn();
2075  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2076  isUnsigned = true;
2077  intVal = ui.getIn();
2078  } else {
2079  return failure();
2080  }
2081 
2082  // Check that the input is converted from an integer type small enough
2083  // that all of its bits are preserved.
2084  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2085  auto intWidth = intTy.getWidth();
2086 
2087  // Number of bits representing values, as opposed to the sign
2088  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2089 
2090  // The following test does NOT adjust intWidth downwards for signed inputs,
2091  // because the most negative value still requires all the mantissa bits
2092  // to distinguish it from one less than that value.
2093  if ((int)intWidth > mantissaWidth) {
2094  // Conversion would lose accuracy. Check if loss can impact comparison.
2095  int exponent = ilogb(rhs);
2096  if (exponent == APFloat::IEK_Inf) {
2097  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2098  if (maxExponent < (int)valueBits) {
2099  // Conversion could create infinity.
2100  return failure();
2101  }
2102  } else {
2103  // Note that if rhs is zero or NaN, then the exponent is negative
2104  // and the first condition is trivially false.
2105  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2106  // Conversion could affect comparison.
2107  return failure();
2108  }
2109  }
2110  }
2111 
2112  // Convert to equivalent cmpi predicate
2113  CmpIPredicate pred;
2114  switch (op.getPredicate()) {
2115  case CmpFPredicate::ORD:
2116  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2117  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2118  /*width=*/1);
2119  return success();
2120  case CmpFPredicate::UNO:
2121  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2122  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2123  /*width=*/1);
2124  return success();
2125  default:
2126  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2127  break;
2128  }
2129 
2130  if (!isUnsigned) {
2131  // If the rhs value is > SignedMax, fold the comparison. This handles
2132  // +INF and large values.
2133  APFloat signedMax(rhs.getSemantics());
2134  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2135  APFloat::rmNearestTiesToEven);
2136  if (signedMax < rhs) { // smax < 13123.0
2137  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2138  pred == CmpIPredicate::sle)
2139  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2140  /*width=*/1);
2141  else
2142  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2143  /*width=*/1);
2144  return success();
2145  }
2146  } else {
2147  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2148  // +INF and large values.
2149  APFloat unsignedMax(rhs.getSemantics());
2150  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2151  APFloat::rmNearestTiesToEven);
2152  if (unsignedMax < rhs) { // umax < 13123.0
2153  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2154  pred == CmpIPredicate::ule)
2155  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2156  /*width=*/1);
2157  else
2158  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2159  /*width=*/1);
2160  return success();
2161  }
2162  }
2163 
2164  if (!isUnsigned) {
2165  // See if the rhs value is < SignedMin.
2166  APFloat signedMin(rhs.getSemantics());
2167  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2168  APFloat::rmNearestTiesToEven);
2169  if (signedMin > rhs) { // smin > 12312.0
2170  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2171  pred == CmpIPredicate::sge)
2172  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2173  /*width=*/1);
2174  else
2175  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2176  /*width=*/1);
2177  return success();
2178  }
2179  } else {
2180  // See if the rhs value is < UnsignedMin.
2181  APFloat unsignedMin(rhs.getSemantics());
2182  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2183  APFloat::rmNearestTiesToEven);
2184  if (unsignedMin > rhs) { // umin > 12312.0
2185  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2186  pred == CmpIPredicate::uge)
2187  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2188  /*width=*/1);
2189  else
2190  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2191  /*width=*/1);
2192  return success();
2193  }
2194  }
2195 
2196  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2197  // [0, UMAX], but it may still be fractional. See if it is fractional by
2198  // casting the FP value to the integer value and back, checking for
2199  // equality. Don't do this for zero, because -0.0 is not fractional.
2200  bool ignored;
2201  APSInt rhsInt(intWidth, isUnsigned);
2202  if (APFloat::opInvalidOp ==
2203  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2204  // Undefined behavior invoked - the destination type can't represent
2205  // the input constant.
2206  return failure();
2207  }
2208 
2209  if (!rhs.isZero()) {
2210  APFloat apf(floatTy.getFloatSemantics(),
2211  APInt::getZero(floatTy.getWidth()));
2212  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2213 
2214  bool equal = apf == rhs;
2215  if (!equal) {
2216  // If we had a comparison against a fractional value, we have to adjust
2217  // the compare predicate and sometimes the value. rhsInt is rounded
2218  // towards zero at this point.
2219  switch (pred) {
2220  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2221  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2222  /*width=*/1);
2223  return success();
2224  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2225  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2226  /*width=*/1);
2227  return success();
2228  case CmpIPredicate::ule:
2229  // (float)int <= 4.4 --> int <= 4
2230  // (float)int <= -4.4 --> false
2231  if (rhs.isNegative()) {
2232  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2233  /*width=*/1);
2234  return success();
2235  }
2236  break;
2237  case CmpIPredicate::sle:
2238  // (float)int <= 4.4 --> int <= 4
2239  // (float)int <= -4.4 --> int < -4
2240  if (rhs.isNegative())
2241  pred = CmpIPredicate::slt;
2242  break;
2243  case CmpIPredicate::ult:
2244  // (float)int < -4.4 --> false
2245  // (float)int < 4.4 --> int <= 4
2246  if (rhs.isNegative()) {
2247  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2248  /*width=*/1);
2249  return success();
2250  }
2251  pred = CmpIPredicate::ule;
2252  break;
2253  case CmpIPredicate::slt:
2254  // (float)int < -4.4 --> int < -4
2255  // (float)int < 4.4 --> int <= 4
2256  if (!rhs.isNegative())
2257  pred = CmpIPredicate::sle;
2258  break;
2259  case CmpIPredicate::ugt:
2260  // (float)int > 4.4 --> int > 4
2261  // (float)int > -4.4 --> true
2262  if (rhs.isNegative()) {
2263  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2264  /*width=*/1);
2265  return success();
2266  }
2267  break;
2268  case CmpIPredicate::sgt:
2269  // (float)int > 4.4 --> int > 4
2270  // (float)int > -4.4 --> int >= -4
2271  if (rhs.isNegative())
2272  pred = CmpIPredicate::sge;
2273  break;
2274  case CmpIPredicate::uge:
2275  // (float)int >= -4.4 --> true
2276  // (float)int >= 4.4 --> int > 4
2277  if (rhs.isNegative()) {
2278  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2279  /*width=*/1);
2280  return success();
2281  }
2282  pred = CmpIPredicate::ugt;
2283  break;
2284  case CmpIPredicate::sge:
2285  // (float)int >= -4.4 --> int >= -4
2286  // (float)int >= 4.4 --> int > 4
2287  if (!rhs.isNegative())
2288  pred = CmpIPredicate::sgt;
2289  break;
2290  }
2291  }
2292  }
2293 
2294  // Lower this FP comparison into an appropriate integer version of the
2295  // comparison.
2296  rewriter.replaceOpWithNewOp<CmpIOp>(
2297  op, pred, intVal,
2298  rewriter.create<ConstantOp>(
2299  op.getLoc(), intVal.getType(),
2300  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2301  return success();
2302  }
2303 };
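To illustrate the mantissa-width guard in the pattern above, a standalone sketch (not part of this file; helper name invented): once the integer type is wider than the f32 mantissa, two distinct integers can map to the same float, so rewriting the cmpf into a cmpi without that guard could change the result.

static void mantissaWidthSketch() {
  // 2^24 + 1 does not fit in the 24-bit f32 mantissa and rounds to 2^24.
  llvm::APFloat viaSIToFP(llvm::APFloat::IEEEsingle(),
                          llvm::APInt::getZero(32));
  (void)viaSIToFP.convertFromAPInt(llvm::APInt(32, (1 << 24) + 1),
                                   /*IsSigned=*/true,
                                   llvm::APFloat::rmNearestTiesToEven);
  llvm::APFloat twoPow24(16777216.0f);
  assert(viaSIToFP.compare(twoPow24) == llvm::APFloat::cmpEqual);
}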
2304 
2305 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2306  MLIRContext *context) {
2307  patterns.insert<CmpFIntToFPConst>(context);
2308 }
2309 
2310 //===----------------------------------------------------------------------===//
2311 // SelectOp
2312 //===----------------------------------------------------------------------===//
2313 
2314 // select %arg, %c1, %c0 => extui %arg
2315 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2316  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2317 
2318  LogicalResult matchAndRewrite(arith::SelectOp op,
2319  PatternRewriter &rewriter) const override {
2320  // Cannot extui i1 to i1, or i1 to f32
2321  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2322  return failure();
2323 
2324  // select %x, %c1, %c0 => extui %x
2325  if (matchPattern(op.getTrueValue(), m_One()) &&
2326  matchPattern(op.getFalseValue(), m_Zero())) {
2327  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2328  op.getCondition());
2329  return success();
2330  }
2331 
2332  // select %x, %c0, %c1 => extui (xor %x, true)
2333  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2334  matchPattern(op.getFalseValue(), m_One())) {
2335  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2336  op, op.getType(),
2337  rewriter.create<arith::XOrIOp>(
2338  op.getLoc(), op.getCondition(),
2339  rewriter.create<arith::ConstantIntOp>(
2340  op.getLoc(), 1, op.getCondition().getType())));
2341  return success();
2342  }
2343 
2344  return failure();
2345  }
2346 };
2347 
2348 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2349  MLIRContext *context) {
2350  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2351  SelectI1ToNot, SelectToExtUI>(context);
2352 }
2353 
2354 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2355  Value trueVal = getTrueValue();
2356  Value falseVal = getFalseValue();
2357  if (trueVal == falseVal)
2358  return trueVal;
2359 
2360  Value condition = getCondition();
2361 
2362  // select true, %0, %1 => %0
2363  if (matchPattern(adaptor.getCondition(), m_One()))
2364  return trueVal;
2365 
2366  // select false, %0, %1 => %1
2367  if (matchPattern(adaptor.getCondition(), m_Zero()))
2368  return falseVal;
2369 
2370  // If either operand is fully poisoned, return the other.
2371  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2372  return falseVal;
2373 
2374  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2375  return trueVal;
2376 
2377  // select %x, true, false => %x
2378  if (getType().isSignlessInteger(1) &&
2379  matchPattern(adaptor.getTrueValue(), m_One()) &&
2380  matchPattern(adaptor.getFalseValue(), m_Zero()))
2381  return condition;
2382 
2383  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2384  auto pred = cmp.getPredicate();
2385  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2386  auto cmpLhs = cmp.getLhs();
2387  auto cmpRhs = cmp.getRhs();
2388 
2389  // %0 = arith.cmpi eq, %arg0, %arg1
2390  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2391 
2392  // %0 = arith.cmpi ne, %arg0, %arg1
2393  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2394 
2395  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2396  (cmpRhs == trueVal && cmpLhs == falseVal))
2397  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2398  }
2399  }
2400 
2401  // Constant-fold constant operands over non-splat constant condition.
2402  // select %cst_vec, %cst0, %cst1 => %cst2
2403  if (auto cond =
2404  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2405  if (auto lhs =
2406  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2407  if (auto rhs =
2408  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2409  SmallVector<Attribute> results;
2410  results.reserve(static_cast<size_t>(cond.getNumElements()));
2411  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2412  cond.value_end<BoolAttr>());
2413  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2414  lhs.value_end<Attribute>());
2415  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2416  rhs.value_end<Attribute>());
2417 
2418  for (auto [condVal, lhsVal, rhsVal] :
2419  llvm::zip_equal(condVals, lhsVals, rhsVals))
2420  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2421 
2422  return DenseElementsAttr::get(lhs.getType(), results);
2423  }
2424  }
2425  }
2426 
2427  return nullptr;
2428 }
2429 
2430 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2431  Type conditionType, resultType;
2432  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2433  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2434  parser.parseOptionalAttrDict(result.attributes) ||
2435  parser.parseColonType(resultType))
2436  return failure();
2437 
2438  // Check for the explicit condition type if this is a masked tensor or vector.
2439  if (succeeded(parser.parseOptionalComma())) {
2440  conditionType = resultType;
2441  if (parser.parseType(resultType))
2442  return failure();
2443  } else {
2444  conditionType = parser.getBuilder().getI1Type();
2445  }
2446 
2447  result.addTypes(resultType);
2448  return parser.resolveOperands(operands,
2449  {conditionType, resultType, resultType},
2450  parser.getNameLoc(), result.operands);
2451 }
2452 
2453 void arith::SelectOp::print(OpAsmPrinter &p) {
2454  p << " " << getOperands();
2455  p.printOptionalAttrDict((*this)->getAttrs());
2456  p << " : ";
2457  if (ShapedType condType =
2458  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2459  p << condType << ", ";
2460  p << getType();
2461 }
2462 
2463 LogicalResult arith::SelectOp::verify() {
2464  Type conditionType = getCondition().getType();
2465  if (conditionType.isSignlessInteger(1))
2466  return success();
2467 
2468  // If the result type is a vector or tensor, the type can be a mask with the
2469  // same elements.
2470  Type resultType = getType();
2471  if (!llvm::isa<TensorType, VectorType>(resultType))
2472  return emitOpError() << "expected condition to be a signless i1, but got "
2473  << conditionType;
2474  Type shapedConditionType = getI1SameShape(resultType);
2475  if (conditionType != shapedConditionType) {
2476  return emitOpError() << "expected condition type to have the same shape "
2477  "as the result type, expected "
2478  << shapedConditionType << ", but got "
2479  << conditionType;
2480  }
2481  return success();
2482 }
2483 //===----------------------------------------------------------------------===//
2484 // ShLIOp
2485 //===----------------------------------------------------------------------===//
2486 
2487 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2488  // shli(x, 0) -> x
2489  if (matchPattern(adaptor.getRhs(), m_Zero()))
2490  return getLhs();
2491  // Don't fold if shifting more or equal than the bit width.
2492  bool bounded = false;
2493  auto result = constFoldBinaryOp<IntegerAttr>(
2494  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2495  bounded = b.ult(b.getBitWidth());
2496  return a.shl(b);
2497  });
2498  return bounded ? result : Attribute();
2499 }
2500 
2501 //===----------------------------------------------------------------------===//
2502 // ShRUIOp
2503 //===----------------------------------------------------------------------===//
2504 
2505 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2506  // shrui(x, 0) -> x
2507  if (matchPattern(adaptor.getRhs(), m_Zero()))
2508  return getLhs();
2509  // Don't fold if shifting more or equal than the bit width.
2510  bool bounded = false;
2511  auto result = constFoldBinaryOp<IntegerAttr>(
2512  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2513  bounded = b.ult(b.getBitWidth());
2514  return a.lshr(b);
2515  });
2516  return bounded ? result : Attribute();
2517 }
2518 
2519 //===----------------------------------------------------------------------===//
2520 // ShRSIOp
2521 //===----------------------------------------------------------------------===//
2522 
2523 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2524  // shrsi(x, 0) -> x
2525  if (matchPattern(adaptor.getRhs(), m_Zero()))
2526  return getLhs();
2527  // Don't fold if shifting more or equal than the bit width.
2528  bool bounded = false;
2529  auto result = constFoldBinaryOp<IntegerAttr>(
2530  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2531  bounded = b.ult(b.getBitWidth());
2532  return a.ashr(b);
2533  });
2534  return bounded ? result : Attribute();
2535 }
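For illustration (standalone sketch, not part of this file; helper name invented): the three shift folders above share the same guard and only fold when the shift amount is strictly less than the bit width; larger shift amounts are left alone because the arith dialect documents their result as poison.

static void shiftBoundSketch() {
  llvm::APInt amount(8, 9);
  // Shifting an i8 by 9 is out of range, so `bounded` stays false and the
  // folders return a null Attribute instead of a constant.
  assert(!amount.ult(amount.getBitWidth()));
  llvm::APInt inRange(8, 3);
  assert(llvm::APInt(8, 1).shl(inRange) == 8); // shli(1, 3) folds to 8
}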
2536 
2537 //===----------------------------------------------------------------------===//
2538 // Atomic Enum
2539 //===----------------------------------------------------------------------===//
2540 
2541 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2542 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2543  OpBuilder &builder, Location loc,
2544  bool useOnlyFiniteValue) {
2545  switch (kind) {
2546  case AtomicRMWKind::maximumf: {
2547  const llvm::fltSemantics &semantic =
2548  llvm::cast<FloatType>(resultType).getFloatSemantics();
2549  APFloat identity = useOnlyFiniteValue
2550  ? APFloat::getLargest(semantic, /*Negative=*/true)
2551  : APFloat::getInf(semantic, /*Negative=*/true);
2552  return builder.getFloatAttr(resultType, identity);
2553  }
2554  case AtomicRMWKind::maxnumf: {
2555  const llvm::fltSemantics &semantic =
2556  llvm::cast<FloatType>(resultType).getFloatSemantics();
2557  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2558  return builder.getFloatAttr(resultType, identity);
2559  }
2560  case AtomicRMWKind::addf:
2561  case AtomicRMWKind::addi:
2562  case AtomicRMWKind::maxu:
2563  case AtomicRMWKind::ori:
2564  return builder.getZeroAttr(resultType);
2565  case AtomicRMWKind::andi:
2566  return builder.getIntegerAttr(
2567  resultType,
2568  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2569  case AtomicRMWKind::maxs:
2570  return builder.getIntegerAttr(
2571  resultType, APInt::getSignedMinValue(
2572  llvm::cast<IntegerType>(resultType).getWidth()));
2573  case AtomicRMWKind::minimumf: {
2574  const llvm::fltSemantics &semantic =
2575  llvm::cast<FloatType>(resultType).getFloatSemantics();
2576  APFloat identity = useOnlyFiniteValue
2577  ? APFloat::getLargest(semantic, /*Negative=*/false)
2578  : APFloat::getInf(semantic, /*Negative=*/false);
2579 
2580  return builder.getFloatAttr(resultType, identity);
2581  }
2582  case AtomicRMWKind::minnumf: {
2583  const llvm::fltSemantics &semantic =
2584  llvm::cast<FloatType>(resultType).getFloatSemantics();
2585  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2586  return builder.getFloatAttr(resultType, identity);
2587  }
2588  case AtomicRMWKind::mins:
2589  return builder.getIntegerAttr(
2590  resultType, APInt::getSignedMaxValue(
2591  llvm::cast<IntegerType>(resultType).getWidth()));
2592  case AtomicRMWKind::minu:
2593  return builder.getIntegerAttr(
2594  resultType,
2595  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2596  case AtomicRMWKind::muli:
2597  return builder.getIntegerAttr(resultType, 1);
2598  case AtomicRMWKind::mulf:
2599  return builder.getFloatAttr(resultType, 1);
2600  // TODO: Add remaining reduction operations.
2601  default:
2602  (void)emitOptionalError(loc, "Reduction operation type not supported");
2603  break;
2604  }
2605  return nullptr;
2606 }
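As a standalone illustration (not part of this file; helper name invented) of the maximumf case above: the identity is negative infinity in general, and the most negative finite value is only a valid substitute when infinities are excluded (the useOnlyFiniteValue / ninf fast-math case).

static void maximumfIdentitySketch() {
  const llvm::fltSemantics &sem = llvm::APFloat::IEEEsingle();
  llvm::APFloat generalId = llvm::APFloat::getInf(sem, /*Negative=*/true);
  llvm::APFloat finiteId = llvm::APFloat::getLargest(sem, /*Negative=*/true);
  // maximumf(x, -inf) == x for every x, so -inf is the neutral element.
  assert(generalId.isInfinity() && generalId.isNegative());
  // Only valid when the reduction is known not to see infinite inputs.
  assert(!finiteId.isInfinity() && finiteId.isNegative());
}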
2607 
2608 /// Return the identity numeric value associated with the given op.
2609 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2610  std::optional<AtomicRMWKind> maybeKind =
2611  TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2612  // Floating-point operations.
2613  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2614  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2615  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2616  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2617  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2618  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2619  // Integer operations.
2620  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2621  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2622  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2623  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2624  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2625  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2626  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2627  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2628  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2629  .Default([](Operation *op) { return std::nullopt; });
2630  if (!maybeKind) {
2631  return std::nullopt;
2632  }
2633 
2634  bool useOnlyFiniteValue = false;
2635  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2636  if (fmfOpInterface) {
2637  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2638  useOnlyFiniteValue =
2639  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2640  }
2641 
2642  // Builder only used as helper for attribute creation.
2643  OpBuilder b(op->getContext());
2644  Type resultType = op->getResult(0).getType();
2645 
2646  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2647  useOnlyFiniteValue);
2648 }
2649 
2650 /// Returns the identity value associated with an AtomicRMWKind op.
2651 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2652  OpBuilder &builder, Location loc,
2653  bool useOnlyFiniteValue) {
2654  auto attr =
2655  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2656  return builder.create<arith::ConstantOp>(loc, attr);
2657 }
2658 
2659 /// Return the value obtained by applying the reduction operation kind
2660 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2661 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2662  Location loc, Value lhs, Value rhs) {
2663  switch (op) {
2664  case AtomicRMWKind::addf:
2665  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2666  case AtomicRMWKind::addi:
2667  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2668  case AtomicRMWKind::mulf:
2669  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2670  case AtomicRMWKind::muli:
2671  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2672  case AtomicRMWKind::maximumf:
2673  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2674  case AtomicRMWKind::minimumf:
2675  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2676  case AtomicRMWKind::maxnumf:
2677  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2678  case AtomicRMWKind::minnumf:
2679  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2680  case AtomicRMWKind::maxs:
2681  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2682  case AtomicRMWKind::mins:
2683  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2684  case AtomicRMWKind::maxu:
2685  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2686  case AtomicRMWKind::minu:
2687  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2688  case AtomicRMWKind::ori:
2689  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2690  case AtomicRMWKind::andi:
2691  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2692  // TODO: Add remaining reduction operations.
2693  default:
2694  (void)emitOptionalError(loc, "Reduction operation type not supported");
2695  break;
2696  }
2697  return nullptr;
2698 }
2699 
2700 //===----------------------------------------------------------------------===//
2701 // TableGen'd op method definitions
2702 //===----------------------------------------------------------------------===//
2703 
2704 #define GET_OP_CLASSES
2705 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2706 
2707 //===----------------------------------------------------------------------===//
2708 // TableGen'd enum attribute definitions
2709 //===----------------------------------------------------------------------===//
2710 
2711 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"