MLIR  20.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/ADT/APSInt.h"
28 #include "llvm/ADT/FloatingPointMode.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 
37 //===----------------------------------------------------------------------===//
38 // Pattern helpers
39 //===----------------------------------------------------------------------===//
40 
/// Apply `binFn` to the APInt payloads of two integer attributes and wrap the
/// result in an IntegerAttr whose type is the type of the value `res`.
///
/// `lhs` and `rhs` are unwrapped with llvm::cast, so both must already be
/// IntegerAttrs. The APInt produced by `binFn` is used as-is (no extension or
/// truncation happens here); presumably it already has the bit width of
/// `res`'s type — TODO confirm at the call sites.
/// NOTE(review): the declarator line carrying this function's name
/// (`applyToIntegerAttrs`, per its callers) and its leading parameters is
/// missing from this listing.
41 static IntegerAttr
43  Attribute rhs,
44  function_ref<APInt(const APInt &, const APInt &)> binFn) {
45  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47  APInt value = binFn(lhsVal, rhsVal);
48  return IntegerAttr::get(res.getType(), value);
49 }
50 
51 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
52  Attribute lhs, Attribute rhs) {
53  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54 }
55 
56 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
57  Attribute lhs, Attribute rhs) {
58  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59 }
60 
61 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
62  Attribute lhs, Attribute rhs) {
63  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64 }
65 
66 // Merge overflow flags from 2 ops, selecting the most conservative combination.
67 static IntegerOverflowFlagsAttr
68 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69  IntegerOverflowFlagsAttr val2) {
70  return IntegerOverflowFlagsAttr::get(val1.getContext(),
71  val1.getValue() & val2.getValue());
72 }
73 
74 /// Invert an integer comparison predicate.
75 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76  switch (pred) {
77  case arith::CmpIPredicate::eq:
78  return arith::CmpIPredicate::ne;
79  case arith::CmpIPredicate::ne:
80  return arith::CmpIPredicate::eq;
81  case arith::CmpIPredicate::slt:
82  return arith::CmpIPredicate::sge;
83  case arith::CmpIPredicate::sle:
84  return arith::CmpIPredicate::sgt;
85  case arith::CmpIPredicate::sgt:
86  return arith::CmpIPredicate::sle;
87  case arith::CmpIPredicate::sge:
88  return arith::CmpIPredicate::slt;
89  case arith::CmpIPredicate::ult:
90  return arith::CmpIPredicate::uge;
91  case arith::CmpIPredicate::ule:
92  return arith::CmpIPredicate::ugt;
93  case arith::CmpIPredicate::ugt:
94  return arith::CmpIPredicate::ule;
95  case arith::CmpIPredicate::uge:
96  return arith::CmpIPredicate::ult;
97  }
98  llvm_unreachable("unknown cmpi predicate kind");
99 }
100 
101 /// Equivalent to
102 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103 ///
104 /// Not possible to implement as chain of calls as this would introduce a
105 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106 /// on the LLVM dialect and on translation to LLVM.
107 static llvm::RoundingMode
108 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109  switch (roundingMode) {
110  case RoundingMode::downward:
111  return llvm::RoundingMode::TowardNegative;
112  case RoundingMode::to_nearest_away:
113  return llvm::RoundingMode::NearestTiesToAway;
114  case RoundingMode::to_nearest_even:
115  return llvm::RoundingMode::NearestTiesToEven;
116  case RoundingMode::toward_zero:
117  return llvm::RoundingMode::TowardZero;
118  case RoundingMode::upward:
119  return llvm::RoundingMode::TowardPositive;
120  }
121  llvm_unreachable("Unhandled rounding mode");
122 }
123 
124 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125  return arith::CmpIPredicateAttr::get(pred.getContext(),
126  invertPredicate(pred.getValue()));
127 }
128 
129 static int64_t getScalarOrElementWidth(Type type) {
130  Type elemTy = getElementTypeOrSelf(type);
131  if (elemTy.isIntOrFloat())
132  return elemTy.getIntOrFloatBitWidth();
133 
134  return -1;
135 }
136 
137 static int64_t getScalarOrElementWidth(Value value) {
138  return getScalarOrElementWidth(value.getType());
139 }
140 
141 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142  APInt value;
143  if (matchPattern(attr, m_ConstantInt(&value)))
144  return value;
145 
146  return failure();
147 }
148 
149 static Attribute getBoolAttribute(Type type, bool value) {
150  auto boolAttr = BoolAttr::get(type.getContext(), value);
151  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
152  if (!shapedType)
153  return boolAttr;
154  return DenseElementsAttr::get(shapedType, boolAttr);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TableGen'd canonicalization patterns
159 //===----------------------------------------------------------------------===//
160 
161 namespace {
162 #include "ArithCanonicalization.inc"
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // Common helpers
167 //===----------------------------------------------------------------------===//
168 
169 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
170 static Type getI1SameShape(Type type) {
171  auto i1Type = IntegerType::get(type.getContext(), 1);
172  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
173  return shapedType.cloneWith(std::nullopt, i1Type);
174  if (llvm::isa<UnrankedTensorType>(type))
175  return UnrankedTensorType::get(i1Type);
176  return i1Type;
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // ConstantOp
181 //===----------------------------------------------------------------------===//
182 
183 void arith::ConstantOp::getAsmResultNames(
184  function_ref<void(Value, StringRef)> setNameFn) {
185  auto type = getType();
186  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
187  auto intType = llvm::dyn_cast<IntegerType>(type);
188 
189  // Sugar i1 constants with 'true' and 'false'.
190  if (intType && intType.getWidth() == 1)
191  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
192 
193  // Otherwise, build a complex name with the value and type.
194  SmallString<32> specialNameBuffer;
195  llvm::raw_svector_ostream specialName(specialNameBuffer);
196  specialName << 'c' << intCst.getValue();
197  if (intType)
198  specialName << '_' << type;
199  setNameFn(getResult(), specialName.str());
200  } else {
201  setNameFn(getResult(), "cst");
202  }
203 }
204 
205 /// TODO: disallow arith.constant to return anything other than signless integer
206 /// or float like.
207 LogicalResult arith::ConstantOp::verify() {
208  auto type = getType();
209  // The value's type must match the return type.
210  if (getValue().getType() != type) {
211  return emitOpError() << "value type " << getValue().getType()
212  << " must match return type: " << type;
213  }
214  // Integer values must be signless.
215  if (llvm::isa<IntegerType>(type) &&
216  !llvm::cast<IntegerType>(type).isSignless())
217  return emitOpError("integer return type must be signless");
218  // Any float or elements attribute are acceptable.
219  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
220  return emitOpError(
221  "value must be an integer, float, or elements attribute");
222  }
223 
224  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
225  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
226  // However, this would most likely require updating the lowerings to LLVM.
227  auto vecType = dyn_cast<VectorType>(type);
228  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
229  return emitOpError(
230  "intializing scalable vectors with elements attribute is not supported"
231  " unless it's a vector splat");
232  return success();
233 }
234 
235 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
236  // The value's type must be the same as the provided type.
237  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238  if (!typedAttr || typedAttr.getType() != type)
239  return false;
240  // Integer values must be signless.
241  if (llvm::isa<IntegerType>(type) &&
242  !llvm::cast<IntegerType>(type).isSignless())
243  return false;
244  // Integer, float, and element attributes are buildable.
245  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246 }
247 
248 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
249  Type type, Location loc) {
250  if (isBuildableWith(value, type))
251  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
252  return nullptr;
253 }
254 
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
256 
258  int64_t value, unsigned width) {
  // Build an arith.constant with result type `builder.getIntegerType(width)`
  // holding `value`, converted with builder.getIntegerAttr.
  // NOTE(review): the declarator line naming this builder (and its leading
  // parameters) is missing from this listing; presumably it is
  // arith::ConstantIntOp::build — confirm.
259  auto type = builder.getIntegerType(width);
260  arith::ConstantOp::build(builder, result, type,
261  builder.getIntegerAttr(type, value));
262 }
263 
265  int64_t value, Type type) {
  // Build an arith.constant of the given `type` holding `value`. The assert
  // below restricts `type` to signless integer types.
  // NOTE(review): the declarator line naming this builder is missing from this
  // listing; the assert message suggests it belongs to ConstantIntOp.
266  assert(type.isSignlessInteger() &&
267  "ConstantIntOp can only have signless integer type values");
268  arith::ConstantOp::build(builder, result, type,
269  builder.getIntegerAttr(type, value));
270 }
271 
// True iff `op` is a non-null arith.constant whose result type is a signless
// integer; null and other ops yield false.
// NOTE(review): the declarator line is missing from this listing; presumably
// this is arith::ConstantIntOp::classof — confirm.
273  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274  return constOp.getType().isSignlessInteger();
275  return false;
276 }
277 
279  const APFloat &value, FloatType type) {
  // Build an arith.constant of the float `type` holding `value`, wrapped with
  // builder.getFloatAttr.
  // NOTE(review): the declarator line naming this builder is missing from this
  // listing; presumably it is arith::ConstantFloatOp::build — confirm.
280  arith::ConstantOp::build(builder, result, type,
281  builder.getFloatAttr(type, value));
282 }
283 
// True iff `op` is a non-null arith.constant whose result type is a FloatType;
// null and other ops yield false.
// NOTE(review): the declarator line is missing from this listing; presumably
// this is arith::ConstantFloatOp::classof — confirm.
285  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286  return llvm::isa<FloatType>(constOp.getType());
287  return false;
288 }
289 
291  int64_t value) {
  // Build an index-typed arith.constant holding `value`.
  // NOTE(review): the declarator line naming this builder is missing from this
  // listing; presumably it is arith::ConstantIndexOp::build — confirm.
292  arith::ConstantOp::build(builder, result, builder.getIndexType(),
293  builder.getIndexAttr(value));
294 }
295 
// True iff `op` is a non-null arith.constant whose result type is index; null
// and other ops yield false.
// NOTE(review): the declarator line is missing from this listing; presumably
// this is arith::ConstantIndexOp::classof — confirm.
297  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298  return constOp.getType().isIndex();
299  return false;
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // AddIOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
307  // addi(x, 0) -> x
308  if (matchPattern(adaptor.getRhs(), m_Zero()))
309  return getLhs();
310 
311  // addi(subi(a, b), b) -> a
312  if (auto sub = getLhs().getDefiningOp<SubIOp>())
313  if (getRhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  // addi(b, subi(a, b)) -> a
317  if (auto sub = getRhs().getDefiningOp<SubIOp>())
318  if (getLhs() == sub.getRhs())
319  return sub.getLhs();
320 
321  return constFoldBinaryOp<IntegerAttr>(
322  adaptor.getOperands(),
323  [](APInt a, const APInt &b) { return std::move(a) + b; });
324 }
325 
326 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
327  MLIRContext *context) {
328  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // AddUIExtendedOp
334 //===----------------------------------------------------------------------===//
335 
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
339  return llvm::to_vector<4>(vt.getShape());
340  return std::nullopt;
341 }
342 
343 // Returns the overflow bit, assuming that `sum` is the result of unsigned
344 // addition of `operand` and another number.
345 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
346  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
347 }
348 
/// Fold addui_extended, which produces a sum and an overflow result.
/// - addui_extended(x, 0) -> (x, zero attribute of the overflow type).
/// - Two constant operands -> (constant sum, constant overflow). The overflow
///   attribute gets an i1 type of the same shape as the sum (getI1SameShape).
/// Fails otherwise.
/// NOTE(review): the line declaring the `results` output parameter is missing
/// from this listing.
349 LogicalResult
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352  Type overflowTy = getOverflow().getType();
353  // addui_extended(x, 0) -> x, false
354  if (matchPattern(getRhs(), m_Zero())) {
355  Builder builder(getContext());
356  auto falseValue = builder.getZeroAttr(overflowTy);
357 
358  results.push_back(getLhs());
359  results.push_back(falseValue);
360  return success();
361  }
362 
363  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
364  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
365  // operands. If that succeeds, calculate the overflow bit based on the sum
366  // and the first (constant) operand, `lhs`.
367  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368  adaptor.getOperands(),
369  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
      // Second fold over (sum, lhs) with an explicit i1-shaped result type.
      // NOTE(review): the last argument of this call (missing from this
      // listing) is presumably calculateUnsignedOverflow — confirm.
370  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371  ArrayRef({sumAttr, adaptor.getLhs()}),
372  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
374  if (!overflowAttr)
375  return failure();
376 
377  results.push_back(sumAttr);
378  results.push_back(overflowAttr);
379  return success();
380  }
381 
382  return failure();
383 }
384 
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
386  RewritePatternSet &patterns, MLIRContext *context) {
387  patterns.add<AddUIExtendedToAddI>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // SubIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
395  // subi(x,x) -> 0
396  if (getOperand(0) == getOperand(1))
397  return Builder(getContext()).getZeroAttr(getType());
398  // subi(x,0) -> x
399  if (matchPattern(adaptor.getRhs(), m_Zero()))
400  return getLhs();
401 
402  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
403  // subi(addi(a, b), b) -> a
404  if (getRhs() == add.getRhs())
405  return add.getLhs();
406  // subi(addi(a, b), a) -> b
407  if (getRhs() == add.getLhs())
408  return add.getRhs();
409  }
410 
411  return constFoldBinaryOp<IntegerAttr>(
412  adaptor.getOperands(),
413  [](APInt a, const APInt &b) { return std::move(a) - b; });
414 }
415 
416 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
417  MLIRContext *context) {
418  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
419  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
420  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // MulIOp
425 //===----------------------------------------------------------------------===//
426 
427 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
428  // muli(x, 0) -> 0
429  if (matchPattern(adaptor.getRhs(), m_Zero()))
430  return getRhs();
431  // muli(x, 1) -> x
432  if (matchPattern(adaptor.getRhs(), m_One()))
433  return getLhs();
434  // TODO: Handle the overflow case.
435 
436  // default folder
437  return constFoldBinaryOp<IntegerAttr>(
438  adaptor.getOperands(),
439  [](const APInt &a, const APInt &b) { return a * b; });
440 }
441 
442 void arith::MulIOp::getAsmResultNames(
443  function_ref<void(Value, StringRef)> setNameFn) {
444  if (!isa<IndexType>(getType()))
445  return;
446 
447  // Match vector.vscale by name to avoid depending on the vector dialect (which
448  // is a circular dependency).
449  auto isVscale = [](Operation *op) {
450  return op && op->getName().getStringRef() == "vector.vscale";
451  };
452 
453  IntegerAttr baseValue;
454  auto isVscaleExpr = [&](Value a, Value b) {
455  return matchPattern(a, m_Constant(&baseValue)) &&
456  isVscale(b.getDefiningOp());
457  };
458 
459  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
460  return;
461 
462  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
463  SmallString<32> specialNameBuffer;
464  llvm::raw_svector_ostream specialName(specialNameBuffer);
465  specialName << 'c' << baseValue.getInt() << "_vscale";
466  setNameFn(getResult(), specialName.str());
467 }
468 
469 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
470  MLIRContext *context) {
471  patterns.add<MulIMulIConstant>(context);
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // MulSIExtendedOp
476 //===----------------------------------------------------------------------===//
477 
478 std::optional<SmallVector<int64_t, 4>>
479 arith::MulSIExtendedOp::getShapeForUnroll() {
480  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
481  return llvm::to_vector<4>(vt.getShape());
482  return std::nullopt;
483 }
484 
/// Fold mulsi_extended, which produces the low and high halves of a signed
/// product.
/// - mulsi_extended(x, 0) -> (0, 0). The zero RHS attribute is reused for both
///   results.
/// - Two constant operands -> (low `a * b`, high llvm::APIntOps::mulhs(a, b)).
///   The second fold is asserted to succeed whenever the first one does.
/// Fails otherwise.
/// NOTE(review): the line declaring the `results` output parameter is missing
/// from this listing.
485 LogicalResult
486 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
488  // mulsi_extended(x, 0) -> 0, 0
489  if (matchPattern(adaptor.getRhs(), m_Zero())) {
490  Attribute zero = adaptor.getRhs();
491  results.push_back(zero);
492  results.push_back(zero);
493  return success();
494  }
495 
496  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
497  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
498  adaptor.getOperands(),
499  [](const APInt &a, const APInt &b) { return a * b; })) {
500  // Invoke the constant fold helper again to calculate the 'high' result.
501  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
502  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
503  return llvm::APIntOps::mulhs(a, b);
504  });
505  assert(highAttr && "Unexpected constant-folding failure");
506 
507  results.push_back(lowAttr);
508  results.push_back(highAttr);
509  return success();
510  }
511 
512  return failure();
513 }
514 
515 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516  RewritePatternSet &patterns, MLIRContext *context) {
517  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // MulUIExtendedOp
522 //===----------------------------------------------------------------------===//
523 
524 std::optional<SmallVector<int64_t, 4>>
525 arith::MulUIExtendedOp::getShapeForUnroll() {
526  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
527  return llvm::to_vector<4>(vt.getShape());
528  return std::nullopt;
529 }
530 
/// Fold mului_extended, which produces the low and high halves of an unsigned
/// product.
/// - mului_extended(x, 0) -> (0, 0). The zero RHS attribute is reused for both
///   results.
/// - mului_extended(x, 1) -> (x, zero attribute of x's type).
/// - Two constant operands -> (low `a * b`, high llvm::APIntOps::mulhu(a, b)).
///   The second fold is asserted to succeed whenever the first one does.
/// Fails otherwise.
/// NOTE(review): the line declaring the `results` output parameter is missing
/// from this listing.
531 LogicalResult
532 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
534  // mului_extended(x, 0) -> 0, 0
535  if (matchPattern(adaptor.getRhs(), m_Zero())) {
536  Attribute zero = adaptor.getRhs();
537  results.push_back(zero);
538  results.push_back(zero);
539  return success();
540  }
541 
542  // mului_extended(x, 1) -> x, 0
543  if (matchPattern(adaptor.getRhs(), m_One())) {
544  Builder builder(getContext());
545  Attribute zero = builder.getZeroAttr(getLhs().getType());
546  results.push_back(getLhs());
547  results.push_back(zero);
548  return success();
549  }
550 
551  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
552  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
553  adaptor.getOperands(),
554  [](const APInt &a, const APInt &b) { return a * b; })) {
555  // Invoke the constant fold helper again to calculate the 'high' result.
556  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
557  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
558  return llvm::APIntOps::mulhu(a, b);
559  });
560  assert(highAttr && "Unexpected constant-folding failure");
561 
562  results.push_back(lowAttr);
563  results.push_back(highAttr);
564  return success();
565  }
566 
567  return failure();
568 }
569 
570 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571  RewritePatternSet &patterns, MLIRContext *context) {
572  patterns.add<MulUIExtendedToMulI>(context);
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // DivUIOp
577 //===----------------------------------------------------------------------===//
578 
579 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
580  // divui (x, 1) -> x.
581  if (matchPattern(adaptor.getRhs(), m_One()))
582  return getLhs();
583 
584  // Don't fold if it would require a division by zero.
585  bool div0 = false;
586  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
587  [&](APInt a, const APInt &b) {
588  if (div0 || !b) {
589  div0 = true;
590  return a;
591  }
592  return a.udiv(b);
593  });
594 
595  return div0 ? Attribute() : result;
596 }
597 
598 /// Returns whether an unsigned division by `divisor` is speculatable.
/// Division by zero is UB. The visible condition therefore requires `divisor`
/// to match m_IntRangeWithoutZeroU, presumably meaning its known unsigned range
/// excludes zero (confirm against the matcher).
/// NOTE(review): the declarator and both `return` statements are missing from
/// this listing. Presumably it returns Speculatable when the condition holds
/// and NotSpeculatable otherwise.
600  // X / 0 => UB
601  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
603 
605 }
606 
607 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
608  return getDivUISpeculatability(getRhs());
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // DivSIOp
613 //===----------------------------------------------------------------------===//
614 
615 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
616  // divsi (x, 1) -> x.
617  if (matchPattern(adaptor.getRhs(), m_One()))
618  return getLhs();
619 
620  // Don't fold if it would overflow or if it requires a division by zero.
621  bool overflowOrDiv0 = false;
622  auto result = constFoldBinaryOp<IntegerAttr>(
623  adaptor.getOperands(), [&](APInt a, const APInt &b) {
624  if (overflowOrDiv0 || !b) {
625  overflowOrDiv0 = true;
626  return a;
627  }
628  return a.sdiv_ov(b, overflowOrDiv0);
629  });
630 
631  return overflowOrDiv0 ? Attribute() : result;
632 }
633 
634 /// Returns whether a signed division by `divisor` is speculatable. This
635 /// function conservatively assumes that all signed divisions by -1 are not
636 /// speculatable.
/// Both X / 0 and INT_MIN / -1 are UB. The visible half of the condition
/// requires `divisor` to match m_IntRangeWithoutZeroS (presumably: known
/// nonzero). The missing continuation presumably also rules out -1.
/// NOTE(review): the declarator, the second half of the condition and both
/// `return` statements are missing from this listing — confirm against the
/// full source.
638  // X / 0 => UB
639  // INT_MIN / -1 => UB
640  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
643 
645 }
646 
647 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
648  return getDivSISpeculatability(getRhs());
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // Ceil and floor division folding helpers
653 //===----------------------------------------------------------------------===//
654 
655 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
656  bool &overflow) {
657  // Returns (a-1)/b + 1
658  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
659  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
660  return val.sadd_ov(one, overflow);
661 }
662 
663 //===----------------------------------------------------------------------===//
664 // CeilDivUIOp
665 //===----------------------------------------------------------------------===//
666 
667 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
668  // ceildivui (x, 1) -> x.
669  if (matchPattern(adaptor.getRhs(), m_One()))
670  return getLhs();
671 
672  bool overflowOrDiv0 = false;
673  auto result = constFoldBinaryOp<IntegerAttr>(
674  adaptor.getOperands(), [&](APInt a, const APInt &b) {
675  if (overflowOrDiv0 || !b) {
676  overflowOrDiv0 = true;
677  return a;
678  }
679  APInt quotient = a.udiv(b);
680  if (!a.urem(b))
681  return quotient;
682  APInt one(a.getBitWidth(), 1, true);
683  return quotient.uadd_ov(one, overflowOrDiv0);
684  });
685 
686  return overflowOrDiv0 ? Attribute() : result;
687 }
688 
689 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
690  return getDivUISpeculatability(getRhs());
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // CeilDivSIOp
695 //===----------------------------------------------------------------------===//
696 
/// Fold arith.ceildivsi (signed division rounding toward +infinity).
/// - ceildivsi(x, 1) -> x.
/// - Constant operands are folded element-wise by reducing each sign case to a
///   division of positive values:
///     a == 0        -> 0
///     a > 0, b > 0  -> ceil(a / b) via signedCeilNonnegInputs
///     a < 0, b < 0  -> ceil(-a / -b) via signedCeilNonnegInputs
///     a < 0, b > 0  -> -((-a) / b)
///     a > 0, b < 0  -> -(a / -b)
///   Nothing is folded if any element divides by zero or if any negation or
///   division overflows.
697 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
698  // ceildivsi (x, 1) -> x.
699  if (matchPattern(adaptor.getRhs(), m_One()))
700  return getLhs();
701 
702  // Don't fold if it would overflow or if it requires a division by zero.
703  // TODO: This hook won't fold operations where a = MININT, because
704  // negating MININT overflows. This can be improved.
705  bool overflowOrDiv0 = false;
706  auto result = constFoldBinaryOp<IntegerAttr>(
707  adaptor.getOperands(), [&](APInt a, const APInt &b) {
         // The flag is checked first: once any element fails, all later
         // elements are skipped. The plain assignments to `overflowOrDiv0`
         // below therefore never clear an earlier failure.
708  if (overflowOrDiv0 || !b) {
709  overflowOrDiv0 = true;
710  return a;
711  }
712  if (!a)
713  return a;
714  // After this point we know that neither a nor b is zero.
715  unsigned bits = a.getBitWidth();
716  APInt zero = APInt::getZero(bits);
717  bool aGtZero = a.sgt(zero);
718  bool bGtZero = b.sgt(zero);
719  if (aGtZero && bGtZero) {
720  // Both positive, return ceil(a / b).
721  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
722  }
723 
724  // No folding happens if any of the intermediate arithmetic operations
725  // overflows.
         // Below, `zero.ssub_ov(x, flag)` computes -x and flags the overflow
         // that occurs when negating the minimum signed value.
726  bool overflowNegA = false;
727  bool overflowNegB = false;
728  bool overflowDiv = false;
729  bool overflowNegRes = false;
730  if (!aGtZero && !bGtZero) {
731  // Both negative, return ceil(-a / -b).
732  APInt posA = zero.ssub_ov(a, overflowNegA);
733  APInt posB = zero.ssub_ov(b, overflowNegB);
734  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
735  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
736  return res;
737  }
738  if (!aGtZero && bGtZero) {
739  // A is negative, b is positive, return - ( -a / b).
740  APInt posA = zero.ssub_ov(a, overflowNegA);
741  APInt div = posA.sdiv_ov(b, overflowDiv);
742  APInt res = zero.ssub_ov(div, overflowNegRes);
743  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
744  return res;
745  }
746  // A is positive, b is negative, return - (a / -b).
747  APInt posB = zero.ssub_ov(b, overflowNegB);
748  APInt div = a.sdiv_ov(posB, overflowDiv);
749  APInt res = zero.ssub_ov(div, overflowNegRes);
750 
751  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
752  return res;
753  });
754 
755  return overflowOrDiv0 ? Attribute() : result;
756 }
757 
758 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
759  return getDivSISpeculatability(getRhs());
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // FloorDivSIOp
764 //===----------------------------------------------------------------------===//
765 
766 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
767  // floordivsi (x, 1) -> x.
768  if (matchPattern(adaptor.getRhs(), m_One()))
769  return getLhs();
770 
771  // Don't fold if it would overflow or if it requires a division by zero.
772  bool overflowOrDiv = false;
773  auto result = constFoldBinaryOp<IntegerAttr>(
774  adaptor.getOperands(), [&](APInt a, const APInt &b) {
775  if (b.isZero()) {
776  overflowOrDiv = true;
777  return a;
778  }
779  return a.sfloordiv_ov(b, overflowOrDiv);
780  });
781 
782  return overflowOrDiv ? Attribute() : result;
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // RemUIOp
787 //===----------------------------------------------------------------------===//
788 
789 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
790  // remui (x, 1) -> 0.
791  if (matchPattern(adaptor.getRhs(), m_One()))
792  return Builder(getContext()).getZeroAttr(getType());
793 
794  // Don't fold if it would require a division by zero.
795  bool div0 = false;
796  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
797  [&](APInt a, const APInt &b) {
798  if (div0 || b.isZero()) {
799  div0 = true;
800  return a;
801  }
802  return a.urem(b);
803  });
804 
805  return div0 ? Attribute() : result;
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // RemSIOp
810 //===----------------------------------------------------------------------===//
811 
812 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
813  // remsi (x, 1) -> 0.
814  if (matchPattern(adaptor.getRhs(), m_One()))
815  return Builder(getContext()).getZeroAttr(getType());
816 
817  // Don't fold if it would require a division by zero.
818  bool div0 = false;
819  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820  [&](APInt a, const APInt &b) {
821  if (div0 || b.isZero()) {
822  div0 = true;
823  return a;
824  }
825  return a.srem(b);
826  });
827 
828  return div0 ? Attribute() : result;
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // AndIOp
833 //===----------------------------------------------------------------------===//
834 
835 /// Fold `and(a, and(a, b))` to `and(a, b)`
836 static Value foldAndIofAndI(arith::AndIOp op) {
837  for (bool reversePrev : {false, true}) {
838  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
839  .getDefiningOp<arith::AndIOp>();
840  if (!prev)
841  continue;
842 
843  Value other = (reversePrev ? op.getLhs() : op.getRhs());
844  if (other != prev.getLhs() && other != prev.getRhs())
845  continue;
846 
847  return prev.getResult();
848  }
849  return {};
850 }
851 
852 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
853  /// and(x, 0) -> 0
854  if (matchPattern(adaptor.getRhs(), m_Zero()))
855  return getRhs();
856  /// and(x, allOnes) -> x
857  APInt intValue;
858  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
859  intValue.isAllOnes())
860  return getLhs();
861  /// and(x, not(x)) -> 0
862  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
863  m_ConstantInt(&intValue))) &&
864  intValue.isAllOnes())
865  return Builder(getContext()).getZeroAttr(getType());
866  /// and(not(x), x) -> 0
867  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
868  m_ConstantInt(&intValue))) &&
869  intValue.isAllOnes())
870  return Builder(getContext()).getZeroAttr(getType());
871 
872  /// and(a, and(a, b)) -> and(a, b)
873  if (Value result = foldAndIofAndI(*this))
874  return result;
875 
876  return constFoldBinaryOp<IntegerAttr>(
877  adaptor.getOperands(),
878  [](APInt a, const APInt &b) { return std::move(a) & b; });
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // OrIOp
883 //===----------------------------------------------------------------------===//
884 
885 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
886  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
887  /// or(x, 0) -> x
888  if (rhsVal.isZero())
889  return getLhs();
890  /// or(x, <all ones>) -> <all ones>
891  if (rhsVal.isAllOnes())
892  return adaptor.getRhs();
893  }
894 
895  APInt intValue;
896  /// or(x, xor(x, 1)) -> 1
897  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
898  m_ConstantInt(&intValue))) &&
899  intValue.isAllOnes())
900  return getRhs().getDefiningOp<XOrIOp>().getRhs();
901  /// or(xor(x, 1), x) -> 1
902  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
903  m_ConstantInt(&intValue))) &&
904  intValue.isAllOnes())
905  return getLhs().getDefiningOp<XOrIOp>().getRhs();
906 
907  return constFoldBinaryOp<IntegerAttr>(
908  adaptor.getOperands(),
909  [](APInt a, const APInt &b) { return std::move(a) | b; });
910 }
911 
912 //===----------------------------------------------------------------------===//
913 // XOrIOp
914 //===----------------------------------------------------------------------===//
915 
916 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
917  /// xor(x, 0) -> x
918  if (matchPattern(adaptor.getRhs(), m_Zero()))
919  return getLhs();
920  /// xor(x, x) -> 0
921  if (getLhs() == getRhs())
922  return Builder(getContext()).getZeroAttr(getType());
923  /// xor(xor(x, a), a) -> x
924  /// xor(xor(a, x), a) -> x
925  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
926  if (prev.getRhs() == getRhs())
927  return prev.getLhs();
928  if (prev.getLhs() == getRhs())
929  return prev.getRhs();
930  }
931  /// xor(a, xor(x, a)) -> x
932  /// xor(a, xor(a, x)) -> x
933  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
934  if (prev.getRhs() == getLhs())
935  return prev.getLhs();
936  if (prev.getLhs() == getLhs())
937  return prev.getRhs();
938  }
939 
940  return constFoldBinaryOp<IntegerAttr>(
941  adaptor.getOperands(),
942  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
943 }
944 
945 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
946  MLIRContext *context) {
947  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
948 }
949 
950 //===----------------------------------------------------------------------===//
951 // NegFOp
952 //===----------------------------------------------------------------------===//
953 
954 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
955  /// negf(negf(x)) -> x
956  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
957  return op.getOperand();
958  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
959  [](const APFloat &a) { return -a; });
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // AddFOp
964 //===----------------------------------------------------------------------===//
965 
966 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
967  // addf(x, -0) -> x
968  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
969  return getLhs();
970 
971  return constFoldBinaryOp<FloatAttr>(
972  adaptor.getOperands(),
973  [](const APFloat &a, const APFloat &b) { return a + b; });
974 }
975 
976 //===----------------------------------------------------------------------===//
977 // SubFOp
978 //===----------------------------------------------------------------------===//
979 
980 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
981  // subf(x, +0) -> x
982  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
983  return getLhs();
984 
985  return constFoldBinaryOp<FloatAttr>(
986  adaptor.getOperands(),
987  [](const APFloat &a, const APFloat &b) { return a - b; });
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // MaximumFOp
992 //===----------------------------------------------------------------------===//
993 
994 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
995  // maximumf(x,x) -> x
996  if (getLhs() == getRhs())
997  return getRhs();
998 
999  // maximumf(x, -inf) -> x
1000  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1001  return getLhs();
1002 
1003  return constFoldBinaryOp<FloatAttr>(
1004  adaptor.getOperands(),
1005  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1006 }
1007 
1008 //===----------------------------------------------------------------------===//
1009 // MaxNumFOp
1010 //===----------------------------------------------------------------------===//
1011 
1012 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1013  // maxnumf(x,x) -> x
1014  if (getLhs() == getRhs())
1015  return getRhs();
1016 
1017  // maxnumf(x, -inf) -> x
1018  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1019  return getLhs();
1020 
1021  return constFoldBinaryOp<FloatAttr>(
1022  adaptor.getOperands(),
1023  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // MaxSIOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1031  // maxsi(x,x) -> x
1032  if (getLhs() == getRhs())
1033  return getRhs();
1034 
1035  if (APInt intValue;
1036  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1037  // maxsi(x,MAX_INT) -> MAX_INT
1038  if (intValue.isMaxSignedValue())
1039  return getRhs();
1040  // maxsi(x, MIN_INT) -> x
1041  if (intValue.isMinSignedValue())
1042  return getLhs();
1043  }
1044 
1045  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1046  [](const APInt &a, const APInt &b) {
1047  return llvm::APIntOps::smax(a, b);
1048  });
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // MaxUIOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1056  // maxui(x,x) -> x
1057  if (getLhs() == getRhs())
1058  return getRhs();
1059 
1060  if (APInt intValue;
1061  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1062  // maxui(x,MAX_INT) -> MAX_INT
1063  if (intValue.isMaxValue())
1064  return getRhs();
1065  // maxui(x, MIN_INT) -> x
1066  if (intValue.isMinValue())
1067  return getLhs();
1068  }
1069 
1070  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1071  [](const APInt &a, const APInt &b) {
1072  return llvm::APIntOps::umax(a, b);
1073  });
1074 }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // MinimumFOp
1078 //===----------------------------------------------------------------------===//
1079 
1080 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1081  // minimumf(x,x) -> x
1082  if (getLhs() == getRhs())
1083  return getRhs();
1084 
1085  // minimumf(x, +inf) -> x
1086  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1087  return getLhs();
1088 
1089  return constFoldBinaryOp<FloatAttr>(
1090  adaptor.getOperands(),
1091  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1092 }
1093 
1094 //===----------------------------------------------------------------------===//
1095 // MinNumFOp
1096 //===----------------------------------------------------------------------===//
1097 
1098 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1099  // minnumf(x,x) -> x
1100  if (getLhs() == getRhs())
1101  return getRhs();
1102 
1103  // minnumf(x, +inf) -> x
1104  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1105  return getLhs();
1106 
1107  return constFoldBinaryOp<FloatAttr>(
1108  adaptor.getOperands(),
1109  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // MinSIOp
1114 //===----------------------------------------------------------------------===//
1115 
1116 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1117  // minsi(x,x) -> x
1118  if (getLhs() == getRhs())
1119  return getRhs();
1120 
1121  if (APInt intValue;
1122  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1123  // minsi(x,MIN_INT) -> MIN_INT
1124  if (intValue.isMinSignedValue())
1125  return getRhs();
1126  // minsi(x, MAX_INT) -> x
1127  if (intValue.isMaxSignedValue())
1128  return getLhs();
1129  }
1130 
1131  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1132  [](const APInt &a, const APInt &b) {
1133  return llvm::APIntOps::smin(a, b);
1134  });
1135 }
1136 
1137 //===----------------------------------------------------------------------===//
1138 // MinUIOp
1139 //===----------------------------------------------------------------------===//
1140 
1141 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1142  // minui(x,x) -> x
1143  if (getLhs() == getRhs())
1144  return getRhs();
1145 
1146  if (APInt intValue;
1147  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1148  // minui(x,MIN_INT) -> MIN_INT
1149  if (intValue.isMinValue())
1150  return getRhs();
1151  // minui(x, MAX_INT) -> x
1152  if (intValue.isMaxValue())
1153  return getLhs();
1154  }
1155 
1156  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1157  [](const APInt &a, const APInt &b) {
1158  return llvm::APIntOps::umin(a, b);
1159  });
1160 }
1161 
1162 //===----------------------------------------------------------------------===//
1163 // MulFOp
1164 //===----------------------------------------------------------------------===//
1165 
1166 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1167  // mulf(x, 1) -> x
1168  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1169  return getLhs();
1170 
1171  return constFoldBinaryOp<FloatAttr>(
1172  adaptor.getOperands(),
1173  [](const APFloat &a, const APFloat &b) { return a * b; });
1174 }
1175 
1176 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1177  MLIRContext *context) {
1178  patterns.add<MulFOfNegF>(context);
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // DivFOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1186  // divf(x, 1) -> x
1187  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1188  return getLhs();
1189 
1190  return constFoldBinaryOp<FloatAttr>(
1191  adaptor.getOperands(),
1192  [](const APFloat &a, const APFloat &b) { return a / b; });
1193 }
1194 
1195 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1196  MLIRContext *context) {
1197  patterns.add<DivFOfNegF>(context);
1198 }
1199 
1200 //===----------------------------------------------------------------------===//
1201 // RemFOp
1202 //===----------------------------------------------------------------------===//
1203 
1204 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1205  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1206  [](const APFloat &a, const APFloat &b) {
1207  APFloat result(a);
1208  // APFloat::mod() offers the remainder
1209  // behavior we want, i.e. the result has
1210  // the sign of LHS operand.
1211  (void)result.mod(b);
1212  return result;
1213  });
1214 }
1215 
1216 //===----------------------------------------------------------------------===//
1217 // Utility functions for verifying cast ops
1218 //===----------------------------------------------------------------------===//
1219 
/// Carries a pack of types as an ordinary (pointer-typed) value so that it can
/// be passed as a function argument: a pointer to a tuple of the types.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
1222 
1223 /// Returns a non-null type only if the provided type is one of the allowed
1224 /// types or one of the allowed shaped types of the allowed types. Returns the
1225 /// element type if a valid shaped type is provided.
1226 template <typename... ShapedTypes, typename... ElementTypes>
1229  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1230  return {};
1231 
1232  auto underlyingType = getElementTypeOrSelf(type);
1233  if (!llvm::isa<ElementTypes...>(underlyingType))
1234  return {};
1235 
1236  return underlyingType;
1237 }
1238 
1239 /// Get allowed underlying types for vectors and tensors.
1240 template <typename... ElementTypes>
1241 static Type getTypeIfLike(Type type) {
1244 }
1245 
1246 /// Get allowed underlying types for vectors, tensors, and memrefs.
1247 template <typename... ElementTypes>
1249  return getUnderlyingType(type,
1252 }
1253 
1254 /// Return false if both types are ranked tensor with mismatching encoding.
1255 static bool hasSameEncoding(Type typeA, Type typeB) {
1256  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1257  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1258  if (!rankedTensorA || !rankedTensorB)
1259  return true;
1260  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1261 }
1262 
1263 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1264  if (inputs.size() != 1 || outputs.size() != 1)
1265  return false;
1266  if (!hasSameEncoding(inputs.front(), outputs.front()))
1267  return false;
1268  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1269 }
1270 
1271 //===----------------------------------------------------------------------===//
1272 // Verifiers for integer and floating point extension/truncation ops
1273 //===----------------------------------------------------------------------===//
1274 
1275 // Extend ops can only extend to a wider type.
1276 template <typename ValType, typename Op>
1277 static LogicalResult verifyExtOp(Op op) {
1278  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1279  Type dstType = getElementTypeOrSelf(op.getType());
1280 
1281  if (llvm::cast<ValType>(srcType).getWidth() >=
1282  llvm::cast<ValType>(dstType).getWidth())
1283  return op.emitError("result type ")
1284  << dstType << " must be wider than operand type " << srcType;
1285 
1286  return success();
1287 }
1288 
1289 // Truncate ops can only truncate to a shorter type.
1290 template <typename ValType, typename Op>
1291 static LogicalResult verifyTruncateOp(Op op) {
1292  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1293  Type dstType = getElementTypeOrSelf(op.getType());
1294 
1295  if (llvm::cast<ValType>(srcType).getWidth() <=
1296  llvm::cast<ValType>(dstType).getWidth())
1297  return op.emitError("result type ")
1298  << dstType << " must be shorter than operand type " << srcType;
1299 
1300  return success();
1301 }
1302 
1303 /// Validate a cast that changes the width of a type.
1304 template <template <typename> class WidthComparator, typename... ElementTypes>
1305 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1306  if (!areValidCastInputsAndOutputs(inputs, outputs))
1307  return false;
1308 
1309  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1310  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1311  if (!srcType || !dstType)
1312  return false;
1313 
1314  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1315  srcType.getIntOrFloatBitWidth());
1316 }
1317 
1318 /// Attempts to convert `sourceValue` to an APFloat value with
1319 /// `targetSemantics` and `roundingMode`, without any information loss.
1320 static FailureOr<APFloat> convertFloatValue(
1321  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1322  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1323  bool losesInfo = false;
1324  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1325  if (losesInfo || status != APFloat::opOK)
1326  return failure();
1327 
1328  return sourceValue;
1329 }
1330 
1331 //===----------------------------------------------------------------------===//
1332 // ExtUIOp
1333 //===----------------------------------------------------------------------===//
1334 
1335 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1336  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1337  getInMutable().assign(lhs.getIn());
1338  return getResult();
1339  }
1340 
1341  Type resType = getElementTypeOrSelf(getType());
1342  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1343  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1344  adaptor.getOperands(), getType(),
1345  [bitWidth](const APInt &a, bool &castStatus) {
1346  return a.zext(bitWidth);
1347  });
1348 }
1349 
1350 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1351  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1352 }
1353 
1354 LogicalResult arith::ExtUIOp::verify() {
1355  return verifyExtOp<IntegerType>(*this);
1356 }
1357 
1358 //===----------------------------------------------------------------------===//
1359 // ExtSIOp
1360 //===----------------------------------------------------------------------===//
1361 
1362 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1363  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1364  getInMutable().assign(lhs.getIn());
1365  return getResult();
1366  }
1367 
1368  Type resType = getElementTypeOrSelf(getType());
1369  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1370  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1371  adaptor.getOperands(), getType(),
1372  [bitWidth](const APInt &a, bool &castStatus) {
1373  return a.sext(bitWidth);
1374  });
1375 }
1376 
1377 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1378  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1379 }
1380 
1381 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1382  MLIRContext *context) {
1383  patterns.add<ExtSIOfExtUI>(context);
1384 }
1385 
1386 LogicalResult arith::ExtSIOp::verify() {
1387  return verifyExtOp<IntegerType>(*this);
1388 }
1389 
1390 //===----------------------------------------------------------------------===//
1391 // ExtFOp
1392 //===----------------------------------------------------------------------===//
1393 
1394 /// Fold extension of float constants when there is no information loss due the
1395 /// difference in fp semantics.
1396 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1397  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1398  if (truncFOp.getOperand().getType() == getType()) {
1399  arith::FastMathFlags truncFMF =
1400  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1401  bool isTruncContract =
1402  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1403  arith::FastMathFlags extFMF =
1404  getFastmath().value_or(arith::FastMathFlags::none);
1405  bool isExtContract =
1406  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1407  if (isTruncContract && isExtContract) {
1408  return truncFOp.getOperand();
1409  }
1410  }
1411  }
1412 
1413  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1414  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1415  return constFoldCastOp<FloatAttr, FloatAttr>(
1416  adaptor.getOperands(), getType(),
1417  [&targetSemantics](const APFloat &a, bool &castStatus) {
1418  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1419  if (failed(result)) {
1420  castStatus = false;
1421  return a;
1422  }
1423  return *result;
1424  });
1425 }
1426 
1427 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1428  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1429 }
1430 
1431 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1432 
1433 //===----------------------------------------------------------------------===//
1434 // TruncIOp
1435 //===----------------------------------------------------------------------===//
1436 
1437 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1438  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1439  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1440  Value src = getOperand().getDefiningOp()->getOperand(0);
1441  Type srcType = getElementTypeOrSelf(src.getType());
1442  Type dstType = getElementTypeOrSelf(getType());
1443  // trunci(zexti(a)) -> trunci(a)
1444  // trunci(sexti(a)) -> trunci(a)
1445  if (llvm::cast<IntegerType>(srcType).getWidth() >
1446  llvm::cast<IntegerType>(dstType).getWidth()) {
1447  setOperand(src);
1448  return getResult();
1449  }
1450 
1451  // trunci(zexti(a)) -> a
1452  // trunci(sexti(a)) -> a
1453  if (srcType == dstType)
1454  return src;
1455  }
1456 
1457  // trunci(trunci(a)) -> trunci(a))
1458  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1459  setOperand(getOperand().getDefiningOp()->getOperand(0));
1460  return getResult();
1461  }
1462 
1463  Type resType = getElementTypeOrSelf(getType());
1464  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1465  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1466  adaptor.getOperands(), getType(),
1467  [bitWidth](const APInt &a, bool &castStatus) {
1468  return a.trunc(bitWidth);
1469  });
1470 }
1471 
1472 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1473  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1474 }
1475 
1476 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1477  MLIRContext *context) {
1478  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1479  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1480  context);
1481 }
1482 
1483 LogicalResult arith::TruncIOp::verify() {
1484  return verifyTruncateOp<IntegerType>(*this);
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // TruncFOp
1489 //===----------------------------------------------------------------------===//
1490 
1491 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1492 /// can be represented without precision loss.
1493 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1494  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1495  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1496  return constFoldCastOp<FloatAttr, FloatAttr>(
1497  adaptor.getOperands(), getType(),
1498  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1499  RoundingMode roundingMode =
1500  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1501  llvm::RoundingMode llvmRoundingMode =
1502  convertArithRoundingModeToLLVMIR(roundingMode);
1503  FailureOr<APFloat> result =
1504  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1505  if (failed(result)) {
1506  castStatus = false;
1507  return a;
1508  }
1509  return *result;
1510  });
1511 }
1512 
1513 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1514  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1515 }
1516 
1517 LogicalResult arith::TruncFOp::verify() {
1518  return verifyTruncateOp<FloatType>(*this);
1519 }
1520 
1521 //===----------------------------------------------------------------------===//
1522 // AndIOp
1523 //===----------------------------------------------------------------------===//
1524 
1525 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1526  MLIRContext *context) {
1527  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1528 }
1529 
1530 //===----------------------------------------------------------------------===//
1531 // OrIOp
1532 //===----------------------------------------------------------------------===//
1533 
1534 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1535  MLIRContext *context) {
1536  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1537 }
1538 
1539 //===----------------------------------------------------------------------===//
1540 // Verifiers for casts between integers and floats.
1541 //===----------------------------------------------------------------------===//
1542 
1543 template <typename From, typename To>
1544 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1545  if (!areValidCastInputsAndOutputs(inputs, outputs))
1546  return false;
1547 
1548  auto srcType = getTypeIfLike<From>(inputs.front());
1549  auto dstType = getTypeIfLike<To>(outputs.back());
1550 
1551  return srcType && dstType;
1552 }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // UIToFPOp
1556 //===----------------------------------------------------------------------===//
1557 
1558 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1559  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1560 }
1561 
1562 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1563  Type resEleType = getElementTypeOrSelf(getType());
1564  return constFoldCastOp<IntegerAttr, FloatAttr>(
1565  adaptor.getOperands(), getType(),
1566  [&resEleType](const APInt &a, bool &castStatus) {
1567  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1568  APFloat apf(floatTy.getFloatSemantics(),
1569  APInt::getZero(floatTy.getWidth()));
1570  apf.convertFromAPInt(a, /*IsSigned=*/false,
1571  APFloat::rmNearestTiesToEven);
1572  return apf;
1573  });
1574 }
1575 
1576 //===----------------------------------------------------------------------===//
1577 // SIToFPOp
1578 //===----------------------------------------------------------------------===//
1579 
1580 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1581  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1582 }
1583 
1584 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1585  Type resEleType = getElementTypeOrSelf(getType());
1586  return constFoldCastOp<IntegerAttr, FloatAttr>(
1587  adaptor.getOperands(), getType(),
1588  [&resEleType](const APInt &a, bool &castStatus) {
1589  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1590  APFloat apf(floatTy.getFloatSemantics(),
1591  APInt::getZero(floatTy.getWidth()));
1592  apf.convertFromAPInt(a, /*IsSigned=*/true,
1593  APFloat::rmNearestTiesToEven);
1594  return apf;
1595  });
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // FPToUIOp
1600 //===----------------------------------------------------------------------===//
1601 
1602 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1603  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1604 }
1605 
1606 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1607  Type resType = getElementTypeOrSelf(getType());
1608  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1609  return constFoldCastOp<FloatAttr, IntegerAttr>(
1610  adaptor.getOperands(), getType(),
1611  [&bitWidth](const APFloat &a, bool &castStatus) {
1612  bool ignored;
1613  APSInt api(bitWidth, /*isUnsigned=*/true);
1614  castStatus = APFloat::opInvalidOp !=
1615  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1616  return api;
1617  });
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // FPToSIOp
1622 //===----------------------------------------------------------------------===//
1623 
1624 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1625  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1626 }
1627 
1628 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1629  Type resType = getElementTypeOrSelf(getType());
1630  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1631  return constFoldCastOp<FloatAttr, IntegerAttr>(
1632  adaptor.getOperands(), getType(),
1633  [&bitWidth](const APFloat &a, bool &castStatus) {
1634  bool ignored;
1635  APSInt api(bitWidth, /*isUnsigned=*/false);
1636  castStatus = APFloat::opInvalidOp !=
1637  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1638  return api;
1639  });
1640 }
1641 
1642 //===----------------------------------------------------------------------===//
1643 // IndexCastOp
1644 //===----------------------------------------------------------------------===//
1645 
1646 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1647  if (!areValidCastInputsAndOutputs(inputs, outputs))
1648  return false;
1649 
1650  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1651  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1652  if (!srcType || !dstType)
1653  return false;
1654 
1655  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1656  (srcType.isSignlessInteger() && dstType.isIndex());
1657 }
1658 
1659 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1660  TypeRange outputs) {
1661  return areIndexCastCompatible(inputs, outputs);
1662 }
1663 
1664 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1665  // index_cast(constant) -> constant
1666  unsigned resultBitwidth = 64; // Default for index integer attributes.
1667  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1668  resultBitwidth = intTy.getWidth();
1669 
1670  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1671  adaptor.getOperands(), getType(),
1672  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1673  return a.sextOrTrunc(resultBitwidth);
1674  });
1675 }
1676 
1677 void arith::IndexCastOp::getCanonicalizationPatterns(
1678  RewritePatternSet &patterns, MLIRContext *context) {
1679  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1680 }
1681 
1682 //===----------------------------------------------------------------------===//
1683 // IndexCastUIOp
1684 //===----------------------------------------------------------------------===//
1685 
1686 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1687  TypeRange outputs) {
1688  return areIndexCastCompatible(inputs, outputs);
1689 }
1690 
1691 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1692  // index_castui(constant) -> constant
1693  unsigned resultBitwidth = 64; // Default for index integer attributes.
1694  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1695  resultBitwidth = intTy.getWidth();
1696 
1697  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1698  adaptor.getOperands(), getType(),
1699  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1700  return a.zextOrTrunc(resultBitwidth);
1701  });
1702 }
1703 
1704 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1705  RewritePatternSet &patterns, MLIRContext *context) {
1706  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1707 }
1708 
1709 //===----------------------------------------------------------------------===//
1710 // BitcastOp
1711 //===----------------------------------------------------------------------===//
1712 
1713 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1714  if (!areValidCastInputsAndOutputs(inputs, outputs))
1715  return false;
1716 
1717  auto srcType =
1718  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1719  auto dstType =
1720  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1721  if (!srcType || !dstType)
1722  return false;
1723 
1724  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1725 }
1726 
1727 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1728  auto resType = getType();
1729  auto operand = adaptor.getIn();
1730  if (!operand)
1731  return {};
1732 
1733  /// Bitcast dense elements.
1734  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1735  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1736  /// Other shaped types unhandled.
1737  if (llvm::isa<ShapedType>(resType))
1738  return {};
1739 
1740  /// Bitcast integer or float to integer or float.
1741  APInt bits = llvm::isa<FloatAttr>(operand)
1742  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1743  : llvm::cast<IntegerAttr>(operand).getValue();
1744  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1745  "trying to fold on broken IR: operands have incompatible types");
1746 
1747  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1748  return FloatAttr::get(resType,
1749  APFloat(resFloatType.getFloatSemantics(), bits));
1750  return IntegerAttr::get(resType, bits);
1751 }
1752 
1753 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1754  MLIRContext *context) {
1755  patterns.add<BitcastOfBitcast>(context);
1756 }
1757 
1758 //===----------------------------------------------------------------------===//
1759 // CmpIOp
1760 //===----------------------------------------------------------------------===//
1761 
1762 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1763 /// comparison predicates.
1764 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1765  const APInt &lhs, const APInt &rhs) {
1766  switch (predicate) {
1767  case arith::CmpIPredicate::eq:
1768  return lhs.eq(rhs);
1769  case arith::CmpIPredicate::ne:
1770  return lhs.ne(rhs);
1771  case arith::CmpIPredicate::slt:
1772  return lhs.slt(rhs);
1773  case arith::CmpIPredicate::sle:
1774  return lhs.sle(rhs);
1775  case arith::CmpIPredicate::sgt:
1776  return lhs.sgt(rhs);
1777  case arith::CmpIPredicate::sge:
1778  return lhs.sge(rhs);
1779  case arith::CmpIPredicate::ult:
1780  return lhs.ult(rhs);
1781  case arith::CmpIPredicate::ule:
1782  return lhs.ule(rhs);
1783  case arith::CmpIPredicate::ugt:
1784  return lhs.ugt(rhs);
1785  case arith::CmpIPredicate::uge:
1786  return lhs.uge(rhs);
1787  }
1788  llvm_unreachable("unknown cmpi predicate kind");
1789 }
1790 
1791 /// Returns true if the predicate is true for two equal operands.
1792 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1793  switch (predicate) {
1794  case arith::CmpIPredicate::eq:
1795  case arith::CmpIPredicate::sle:
1796  case arith::CmpIPredicate::sge:
1797  case arith::CmpIPredicate::ule:
1798  case arith::CmpIPredicate::uge:
1799  return true;
1800  case arith::CmpIPredicate::ne:
1801  case arith::CmpIPredicate::slt:
1802  case arith::CmpIPredicate::sgt:
1803  case arith::CmpIPredicate::ult:
1804  case arith::CmpIPredicate::ugt:
1805  return false;
1806  }
1807  llvm_unreachable("unknown cmpi predicate kind");
1808 }
1809 
1810 static std::optional<int64_t> getIntegerWidth(Type t) {
1811  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1812  return intType.getWidth();
1813  }
1814  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1815  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1816  }
1817  return std::nullopt;
1818 }
1819 
1820 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1821  // cmpi(pred, x, x)
1822  if (getLhs() == getRhs()) {
1823  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1824  return getBoolAttribute(getType(), val);
1825  }
1826 
1827  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1828  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1829  // extsi(%x : i1 -> iN) != 0 -> %x
1830  std::optional<int64_t> integerWidth =
1831  getIntegerWidth(extOp.getOperand().getType());
1832  if (integerWidth && integerWidth.value() == 1 &&
1833  getPredicate() == arith::CmpIPredicate::ne)
1834  return extOp.getOperand();
1835  }
1836  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1837  // extui(%x : i1 -> iN) != 0 -> %x
1838  std::optional<int64_t> integerWidth =
1839  getIntegerWidth(extOp.getOperand().getType());
1840  if (integerWidth && integerWidth.value() == 1 &&
1841  getPredicate() == arith::CmpIPredicate::ne)
1842  return extOp.getOperand();
1843  }
1844  }
1845 
1846  // Move constant to the right side.
1847  if (adaptor.getLhs() && !adaptor.getRhs()) {
1848  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1849  using Pred = CmpIPredicate;
1850  const std::pair<Pred, Pred> invPreds[] = {
1851  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1852  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1853  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1854  {Pred::ne, Pred::ne},
1855  };
1856  Pred origPred = getPredicate();
1857  for (auto pred : invPreds) {
1858  if (origPred == pred.first) {
1859  setPredicate(pred.second);
1860  Value lhs = getLhs();
1861  Value rhs = getRhs();
1862  getLhsMutable().assign(rhs);
1863  getRhsMutable().assign(lhs);
1864  return getResult();
1865  }
1866  }
1867  llvm_unreachable("unknown cmpi predicate kind");
1868  }
1869 
1870  // We are moving constants to the right side; So if lhs is constant rhs is
1871  // guaranteed to be a constant.
1872  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1873  return constFoldBinaryOp<IntegerAttr>(
1874  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1875  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1876  return APInt(1,
1877  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1878  });
1879  }
1880 
1881  return {};
1882 }
1883 
1884 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1885  MLIRContext *context) {
1886  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1887 }
1888 
1889 //===----------------------------------------------------------------------===//
1890 // CmpFOp
1891 //===----------------------------------------------------------------------===//
1892 
1893 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1894 /// comparison predicates.
1895 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1896  const APFloat &lhs, const APFloat &rhs) {
1897  auto cmpResult = lhs.compare(rhs);
1898  switch (predicate) {
1899  case arith::CmpFPredicate::AlwaysFalse:
1900  return false;
1901  case arith::CmpFPredicate::OEQ:
1902  return cmpResult == APFloat::cmpEqual;
1903  case arith::CmpFPredicate::OGT:
1904  return cmpResult == APFloat::cmpGreaterThan;
1905  case arith::CmpFPredicate::OGE:
1906  return cmpResult == APFloat::cmpGreaterThan ||
1907  cmpResult == APFloat::cmpEqual;
1908  case arith::CmpFPredicate::OLT:
1909  return cmpResult == APFloat::cmpLessThan;
1910  case arith::CmpFPredicate::OLE:
1911  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1912  case arith::CmpFPredicate::ONE:
1913  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1914  case arith::CmpFPredicate::ORD:
1915  return cmpResult != APFloat::cmpUnordered;
1916  case arith::CmpFPredicate::UEQ:
1917  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1918  case arith::CmpFPredicate::UGT:
1919  return cmpResult == APFloat::cmpUnordered ||
1920  cmpResult == APFloat::cmpGreaterThan;
1921  case arith::CmpFPredicate::UGE:
1922  return cmpResult == APFloat::cmpUnordered ||
1923  cmpResult == APFloat::cmpGreaterThan ||
1924  cmpResult == APFloat::cmpEqual;
1925  case arith::CmpFPredicate::ULT:
1926  return cmpResult == APFloat::cmpUnordered ||
1927  cmpResult == APFloat::cmpLessThan;
1928  case arith::CmpFPredicate::ULE:
1929  return cmpResult == APFloat::cmpUnordered ||
1930  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1931  case arith::CmpFPredicate::UNE:
1932  return cmpResult != APFloat::cmpEqual;
1933  case arith::CmpFPredicate::UNO:
1934  return cmpResult == APFloat::cmpUnordered;
1935  case arith::CmpFPredicate::AlwaysTrue:
1936  return true;
1937  }
1938  llvm_unreachable("unknown cmpf predicate kind");
1939 }
1940 
1941 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1942  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1943  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1944 
1945  // If one operand is NaN, making them both NaN does not change the result.
1946  if (lhs && lhs.getValue().isNaN())
1947  rhs = lhs;
1948  if (rhs && rhs.getValue().isNaN())
1949  lhs = rhs;
1950 
1951  if (!lhs || !rhs)
1952  return {};
1953 
1954  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1955  return BoolAttr::get(getContext(), val);
1956 }
1957 
1958 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1959 public:
1961 
1962  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1963  bool isUnsigned) {
1964  using namespace arith;
1965  switch (pred) {
1966  case CmpFPredicate::UEQ:
1967  case CmpFPredicate::OEQ:
1968  return CmpIPredicate::eq;
1969  case CmpFPredicate::UGT:
1970  case CmpFPredicate::OGT:
1971  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1972  case CmpFPredicate::UGE:
1973  case CmpFPredicate::OGE:
1974  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1975  case CmpFPredicate::ULT:
1976  case CmpFPredicate::OLT:
1977  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1978  case CmpFPredicate::ULE:
1979  case CmpFPredicate::OLE:
1980  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1981  case CmpFPredicate::UNE:
1982  case CmpFPredicate::ONE:
1983  return CmpIPredicate::ne;
1984  default:
1985  llvm_unreachable("Unexpected predicate!");
1986  }
1987  }
1988 
1989  LogicalResult matchAndRewrite(CmpFOp op,
1990  PatternRewriter &rewriter) const override {
1991  FloatAttr flt;
1992  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1993  return failure();
1994 
1995  const APFloat &rhs = flt.getValue();
1996 
1997  // Don't attempt to fold a nan.
1998  if (rhs.isNaN())
1999  return failure();
2000 
2001  // Get the width of the mantissa. We don't want to hack on conversions that
2002  // might lose information from the integer, e.g. "i64 -> float"
2003  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2004  int mantissaWidth = floatTy.getFPMantissaWidth();
2005  if (mantissaWidth <= 0)
2006  return failure();
2007 
2008  bool isUnsigned;
2009  Value intVal;
2010 
2011  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2012  isUnsigned = false;
2013  intVal = si.getIn();
2014  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2015  isUnsigned = true;
2016  intVal = ui.getIn();
2017  } else {
2018  return failure();
2019  }
2020 
2021  // Check to see that the input is converted from an integer type that is
2022  // small enough that preserves all bits.
2023  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2024  auto intWidth = intTy.getWidth();
2025 
2026  // Number of bits representing values, as opposed to the sign
2027  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2028 
2029  // Following test does NOT adjust intWidth downwards for signed inputs,
2030  // because the most negative value still requires all the mantissa bits
2031  // to distinguish it from one less than that value.
2032  if ((int)intWidth > mantissaWidth) {
2033  // Conversion would lose accuracy. Check if loss can impact comparison.
2034  int exponent = ilogb(rhs);
2035  if (exponent == APFloat::IEK_Inf) {
2036  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2037  if (maxExponent < (int)valueBits) {
2038  // Conversion could create infinity.
2039  return failure();
2040  }
2041  } else {
2042  // Note that if rhs is zero or NaN, then Exp is negative
2043  // and first condition is trivially false.
2044  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2045  // Conversion could affect comparison.
2046  return failure();
2047  }
2048  }
2049  }
2050 
2051  // Convert to equivalent cmpi predicate
2052  CmpIPredicate pred;
2053  switch (op.getPredicate()) {
2054  case CmpFPredicate::ORD:
2055  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2056  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2057  /*width=*/1);
2058  return success();
2059  case CmpFPredicate::UNO:
2060  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2061  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2062  /*width=*/1);
2063  return success();
2064  default:
2065  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2066  break;
2067  }
2068 
2069  if (!isUnsigned) {
2070  // If the rhs value is > SignedMax, fold the comparison. This handles
2071  // +INF and large values.
2072  APFloat signedMax(rhs.getSemantics());
2073  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2074  APFloat::rmNearestTiesToEven);
2075  if (signedMax < rhs) { // smax < 13123.0
2076  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2077  pred == CmpIPredicate::sle)
2078  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2079  /*width=*/1);
2080  else
2081  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2082  /*width=*/1);
2083  return success();
2084  }
2085  } else {
2086  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2087  // +INF and large values.
2088  APFloat unsignedMax(rhs.getSemantics());
2089  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2090  APFloat::rmNearestTiesToEven);
2091  if (unsignedMax < rhs) { // umax < 13123.0
2092  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2093  pred == CmpIPredicate::ule)
2094  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2095  /*width=*/1);
2096  else
2097  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2098  /*width=*/1);
2099  return success();
2100  }
2101  }
2102 
2103  if (!isUnsigned) {
2104  // See if the rhs value is < SignedMin.
2105  APFloat signedMin(rhs.getSemantics());
2106  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2107  APFloat::rmNearestTiesToEven);
2108  if (signedMin > rhs) { // smin > 12312.0
2109  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2110  pred == CmpIPredicate::sge)
2111  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2112  /*width=*/1);
2113  else
2114  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2115  /*width=*/1);
2116  return success();
2117  }
2118  } else {
2119  // See if the rhs value is < UnsignedMin.
2120  APFloat unsignedMin(rhs.getSemantics());
2121  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2122  APFloat::rmNearestTiesToEven);
2123  if (unsignedMin > rhs) { // umin > 12312.0
2124  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2125  pred == CmpIPredicate::uge)
2126  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2127  /*width=*/1);
2128  else
2129  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2130  /*width=*/1);
2131  return success();
2132  }
2133  }
2134 
2135  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2136  // [0, UMAX], but it may still be fractional. See if it is fractional by
2137  // casting the FP value to the integer value and back, checking for
2138  // equality. Don't do this for zero, because -0.0 is not fractional.
2139  bool ignored;
2140  APSInt rhsInt(intWidth, isUnsigned);
2141  if (APFloat::opInvalidOp ==
2142  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2143  // Undefined behavior invoked - the destination type can't represent
2144  // the input constant.
2145  return failure();
2146  }
2147 
2148  if (!rhs.isZero()) {
2149  APFloat apf(floatTy.getFloatSemantics(),
2150  APInt::getZero(floatTy.getWidth()));
2151  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2152 
2153  bool equal = apf == rhs;
2154  if (!equal) {
2155  // If we had a comparison against a fractional value, we have to adjust
2156  // the compare predicate and sometimes the value. rhsInt is rounded
2157  // towards zero at this point.
2158  switch (pred) {
2159  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2160  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2161  /*width=*/1);
2162  return success();
2163  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2164  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2165  /*width=*/1);
2166  return success();
2167  case CmpIPredicate::ule:
2168  // (float)int <= 4.4 --> int <= 4
2169  // (float)int <= -4.4 --> false
2170  if (rhs.isNegative()) {
2171  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2172  /*width=*/1);
2173  return success();
2174  }
2175  break;
2176  case CmpIPredicate::sle:
2177  // (float)int <= 4.4 --> int <= 4
2178  // (float)int <= -4.4 --> int < -4
2179  if (rhs.isNegative())
2180  pred = CmpIPredicate::slt;
2181  break;
2182  case CmpIPredicate::ult:
2183  // (float)int < -4.4 --> false
2184  // (float)int < 4.4 --> int <= 4
2185  if (rhs.isNegative()) {
2186  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2187  /*width=*/1);
2188  return success();
2189  }
2190  pred = CmpIPredicate::ule;
2191  break;
2192  case CmpIPredicate::slt:
2193  // (float)int < -4.4 --> int < -4
2194  // (float)int < 4.4 --> int <= 4
2195  if (!rhs.isNegative())
2196  pred = CmpIPredicate::sle;
2197  break;
2198  case CmpIPredicate::ugt:
2199  // (float)int > 4.4 --> int > 4
2200  // (float)int > -4.4 --> true
2201  if (rhs.isNegative()) {
2202  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2203  /*width=*/1);
2204  return success();
2205  }
2206  break;
2207  case CmpIPredicate::sgt:
2208  // (float)int > 4.4 --> int > 4
2209  // (float)int > -4.4 --> int >= -4
2210  if (rhs.isNegative())
2211  pred = CmpIPredicate::sge;
2212  break;
2213  case CmpIPredicate::uge:
2214  // (float)int >= -4.4 --> true
2215  // (float)int >= 4.4 --> int > 4
2216  if (rhs.isNegative()) {
2217  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2218  /*width=*/1);
2219  return success();
2220  }
2221  pred = CmpIPredicate::ugt;
2222  break;
2223  case CmpIPredicate::sge:
2224  // (float)int >= -4.4 --> int >= -4
2225  // (float)int >= 4.4 --> int > 4
2226  if (!rhs.isNegative())
2227  pred = CmpIPredicate::sgt;
2228  break;
2229  }
2230  }
2231  }
2232 
2233  // Lower this FP comparison into an appropriate integer version of the
2234  // comparison.
2235  rewriter.replaceOpWithNewOp<CmpIOp>(
2236  op, pred, intVal,
2237  rewriter.create<ConstantOp>(
2238  op.getLoc(), intVal.getType(),
2239  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2240  return success();
2241  }
2242 };
2243 
2244 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2245  MLIRContext *context) {
2246  patterns.insert<CmpFIntToFPConst>(context);
2247 }
2248 
2249 //===----------------------------------------------------------------------===//
2250 // SelectOp
2251 //===----------------------------------------------------------------------===//
2252 
2253 // select %arg, %c1, %c0 => extui %arg
2254 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2256 
2257  LogicalResult matchAndRewrite(arith::SelectOp op,
2258  PatternRewriter &rewriter) const override {
2259  // Cannot extui i1 to i1, or i1 to f32
2260  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2261  return failure();
2262 
2263  // select %x, c1, %c0 => extui %arg
2264  if (matchPattern(op.getTrueValue(), m_One()) &&
2265  matchPattern(op.getFalseValue(), m_Zero())) {
2266  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2267  op.getCondition());
2268  return success();
2269  }
2270 
2271  // select %x, c0, %c1 => extui (xor %arg, true)
2272  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2273  matchPattern(op.getFalseValue(), m_One())) {
2274  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2275  op, op.getType(),
2276  rewriter.create<arith::XOrIOp>(
2277  op.getLoc(), op.getCondition(),
2278  rewriter.create<arith::ConstantIntOp>(
2279  op.getLoc(), 1, op.getCondition().getType())));
2280  return success();
2281  }
2282 
2283  return failure();
2284  }
2285 };
2286 
2287 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2288  MLIRContext *context) {
2289  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2290  SelectI1ToNot, SelectToExtUI>(context);
2291 }
2292 
2293 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2294  Value trueVal = getTrueValue();
2295  Value falseVal = getFalseValue();
2296  if (trueVal == falseVal)
2297  return trueVal;
2298 
2299  Value condition = getCondition();
2300 
2301  // select true, %0, %1 => %0
2302  if (matchPattern(adaptor.getCondition(), m_One()))
2303  return trueVal;
2304 
2305  // select false, %0, %1 => %1
2306  if (matchPattern(adaptor.getCondition(), m_Zero()))
2307  return falseVal;
2308 
2309  // If either operand is fully poisoned, return the other.
2310  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2311  return falseVal;
2312 
2313  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2314  return trueVal;
2315 
2316  // select %x, true, false => %x
2317  if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
2318  matchPattern(adaptor.getFalseValue(), m_Zero()))
2319  return condition;
2320 
2321  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2322  auto pred = cmp.getPredicate();
2323  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2324  auto cmpLhs = cmp.getLhs();
2325  auto cmpRhs = cmp.getRhs();
2326 
2327  // %0 = arith.cmpi eq, %arg0, %arg1
2328  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2329 
2330  // %0 = arith.cmpi ne, %arg0, %arg1
2331  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2332 
2333  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2334  (cmpRhs == trueVal && cmpLhs == falseVal))
2335  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2336  }
2337  }
2338 
2339  // Constant-fold constant operands over non-splat constant condition.
2340  // select %cst_vec, %cst0, %cst1 => %cst2
2341  if (auto cond =
2342  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2343  if (auto lhs =
2344  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2345  if (auto rhs =
2346  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2347  SmallVector<Attribute> results;
2348  results.reserve(static_cast<size_t>(cond.getNumElements()));
2349  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2350  cond.value_end<BoolAttr>());
2351  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2352  lhs.value_end<Attribute>());
2353  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2354  rhs.value_end<Attribute>());
2355 
2356  for (auto [condVal, lhsVal, rhsVal] :
2357  llvm::zip_equal(condVals, lhsVals, rhsVals))
2358  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2359 
2360  return DenseElementsAttr::get(lhs.getType(), results);
2361  }
2362  }
2363  }
2364 
2365  return nullptr;
2366 }
2367 
2368 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2369  Type conditionType, resultType;
2371  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2372  parser.parseOptionalAttrDict(result.attributes) ||
2373  parser.parseColonType(resultType))
2374  return failure();
2375 
2376  // Check for the explicit condition type if this is a masked tensor or vector.
2377  if (succeeded(parser.parseOptionalComma())) {
2378  conditionType = resultType;
2379  if (parser.parseType(resultType))
2380  return failure();
2381  } else {
2382  conditionType = parser.getBuilder().getI1Type();
2383  }
2384 
2385  result.addTypes(resultType);
2386  return parser.resolveOperands(operands,
2387  {conditionType, resultType, resultType},
2388  parser.getNameLoc(), result.operands);
2389 }
2390 
2392  p << " " << getOperands();
2393  p.printOptionalAttrDict((*this)->getAttrs());
2394  p << " : ";
2395  if (ShapedType condType =
2396  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2397  p << condType << ", ";
2398  p << getType();
2399 }
2400 
2401 LogicalResult arith::SelectOp::verify() {
2402  Type conditionType = getCondition().getType();
2403  if (conditionType.isSignlessInteger(1))
2404  return success();
2405 
2406  // If the result type is a vector or tensor, the type can be a mask with the
2407  // same elements.
2408  Type resultType = getType();
2409  if (!llvm::isa<TensorType, VectorType>(resultType))
2410  return emitOpError() << "expected condition to be a signless i1, but got "
2411  << conditionType;
2412  Type shapedConditionType = getI1SameShape(resultType);
2413  if (conditionType != shapedConditionType) {
2414  return emitOpError() << "expected condition type to have the same shape "
2415  "as the result type, expected "
2416  << shapedConditionType << ", but got "
2417  << conditionType;
2418  }
2419  return success();
2420 }
2421 //===----------------------------------------------------------------------===//
2422 // ShLIOp
2423 //===----------------------------------------------------------------------===//
2424 
2425 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2426  // shli(x, 0) -> x
2427  if (matchPattern(adaptor.getRhs(), m_Zero()))
2428  return getLhs();
2429  // Don't fold if shifting more or equal than the bit width.
2430  bool bounded = false;
2431  auto result = constFoldBinaryOp<IntegerAttr>(
2432  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2433  bounded = b.ult(b.getBitWidth());
2434  return a.shl(b);
2435  });
2436  return bounded ? result : Attribute();
2437 }
2438 
2439 //===----------------------------------------------------------------------===//
2440 // ShRUIOp
2441 //===----------------------------------------------------------------------===//
2442 
2443 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2444  // shrui(x, 0) -> x
2445  if (matchPattern(adaptor.getRhs(), m_Zero()))
2446  return getLhs();
2447  // Don't fold if shifting more or equal than the bit width.
2448  bool bounded = false;
2449  auto result = constFoldBinaryOp<IntegerAttr>(
2450  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2451  bounded = b.ult(b.getBitWidth());
2452  return a.lshr(b);
2453  });
2454  return bounded ? result : Attribute();
2455 }
2456 
2457 //===----------------------------------------------------------------------===//
2458 // ShRSIOp
2459 //===----------------------------------------------------------------------===//
2460 
2461 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2462  // shrsi(x, 0) -> x
2463  if (matchPattern(adaptor.getRhs(), m_Zero()))
2464  return getLhs();
2465  // Don't fold if shifting more or equal than the bit width.
2466  bool bounded = false;
2467  auto result = constFoldBinaryOp<IntegerAttr>(
2468  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2469  bounded = b.ult(b.getBitWidth());
2470  return a.ashr(b);
2471  });
2472  return bounded ? result : Attribute();
2473 }
2474 
2475 //===----------------------------------------------------------------------===//
2476 // Atomic Enum
2477 //===----------------------------------------------------------------------===//
2478 
2479 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2480 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2481  OpBuilder &builder, Location loc,
2482  bool useOnlyFiniteValue) {
2483  switch (kind) {
2484  case AtomicRMWKind::maximumf: {
2485  const llvm::fltSemantics &semantic =
2486  llvm::cast<FloatType>(resultType).getFloatSemantics();
2487  APFloat identity = useOnlyFiniteValue
2488  ? APFloat::getLargest(semantic, /*Negative=*/true)
2489  : APFloat::getInf(semantic, /*Negative=*/true);
2490  return builder.getFloatAttr(resultType, identity);
2491  }
2492  case AtomicRMWKind::maxnumf: {
2493  const llvm::fltSemantics &semantic =
2494  llvm::cast<FloatType>(resultType).getFloatSemantics();
2495  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2496  return builder.getFloatAttr(resultType, identity);
2497  }
2498  case AtomicRMWKind::addf:
2499  case AtomicRMWKind::addi:
2500  case AtomicRMWKind::maxu:
2501  case AtomicRMWKind::ori:
2502  return builder.getZeroAttr(resultType);
2503  case AtomicRMWKind::andi:
2504  return builder.getIntegerAttr(
2505  resultType,
2506  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2507  case AtomicRMWKind::maxs:
2508  return builder.getIntegerAttr(
2509  resultType, APInt::getSignedMinValue(
2510  llvm::cast<IntegerType>(resultType).getWidth()));
2511  case AtomicRMWKind::minimumf: {
2512  const llvm::fltSemantics &semantic =
2513  llvm::cast<FloatType>(resultType).getFloatSemantics();
2514  APFloat identity = useOnlyFiniteValue
2515  ? APFloat::getLargest(semantic, /*Negative=*/false)
2516  : APFloat::getInf(semantic, /*Negative=*/false);
2517 
2518  return builder.getFloatAttr(resultType, identity);
2519  }
2520  case AtomicRMWKind::minnumf: {
2521  const llvm::fltSemantics &semantic =
2522  llvm::cast<FloatType>(resultType).getFloatSemantics();
2523  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2524  return builder.getFloatAttr(resultType, identity);
2525  }
2526  case AtomicRMWKind::mins:
2527  return builder.getIntegerAttr(
2528  resultType, APInt::getSignedMaxValue(
2529  llvm::cast<IntegerType>(resultType).getWidth()));
2530  case AtomicRMWKind::minu:
2531  return builder.getIntegerAttr(
2532  resultType,
2533  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2534  case AtomicRMWKind::muli:
2535  return builder.getIntegerAttr(resultType, 1);
2536  case AtomicRMWKind::mulf:
2537  return builder.getFloatAttr(resultType, 1);
2538  // TODO: Add remaining reduction operations.
2539  default:
2540  (void)emitOptionalError(loc, "Reduction operation type not supported");
2541  break;
2542  }
2543  return nullptr;
2544 }
2545 
2546 /// Return the identity numeric value associated to the give op.
2547 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2548  std::optional<AtomicRMWKind> maybeKind =
2550  // Floating-point operations.
2551  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2552  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2553  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2554  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2555  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2556  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2557  // Integer operations.
2558  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2559  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2560  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2561  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2562  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2563  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2564  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2565  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2566  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2567  .Default([](Operation *op) { return std::nullopt; });
2568  if (!maybeKind) {
2569  return std::nullopt;
2570  }
2571 
2572  bool useOnlyFiniteValue = false;
2573  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2574  if (fmfOpInterface) {
2575  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2576  useOnlyFiniteValue =
2577  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2578  }
2579 
2580  // Builder only used as helper for attribute creation.
2581  OpBuilder b(op->getContext());
2582  Type resultType = op->getResult(0).getType();
2583 
2584  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2585  useOnlyFiniteValue);
2586 }
2587 
2588 /// Returns the identity value associated with an AtomicRMWKind op.
2589 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2590  OpBuilder &builder, Location loc,
2591  bool useOnlyFiniteValue) {
2592  auto attr =
2593  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2594  return builder.create<arith::ConstantOp>(loc, attr);
2595 }
2596 
2597 /// Return the value obtained by applying the reduction operation kind
2598 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2599 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2600  Location loc, Value lhs, Value rhs) {
2601  switch (op) {
2602  case AtomicRMWKind::addf:
2603  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2604  case AtomicRMWKind::addi:
2605  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2606  case AtomicRMWKind::mulf:
2607  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2608  case AtomicRMWKind::muli:
2609  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2610  case AtomicRMWKind::maximumf:
2611  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2612  case AtomicRMWKind::minimumf:
2613  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2614  case AtomicRMWKind::maxnumf:
2615  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2616  case AtomicRMWKind::minnumf:
2617  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2618  case AtomicRMWKind::maxs:
2619  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2620  case AtomicRMWKind::mins:
2621  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2622  case AtomicRMWKind::maxu:
2623  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2624  case AtomicRMWKind::minu:
2625  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2626  case AtomicRMWKind::ori:
2627  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2628  case AtomicRMWKind::andi:
2629  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2630  // TODO: Add remaining reduction operations.
2631  default:
2632  (void)emitOptionalError(loc, "Reduction operation type not supported");
2633  break;
2634  }
2635  return nullptr;
2636 }
2637 
2638 //===----------------------------------------------------------------------===//
2639 // TableGen'd op method definitions
2640 //===----------------------------------------------------------------------===//
2641 
2642 #define GET_OP_CLASSES
2643 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2644 
2645 //===----------------------------------------------------------------------===//
2646 // TableGen'd enum attribute definitions
2647 //===----------------------------------------------------------------------===//
2648 
2649 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:599
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1305
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1241
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1792
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensors with mismatching encodings.
Definition: ArithOps.cpp:1255
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1227
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:655
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:637
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:141
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:149
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1320
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1646
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1544
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1277
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:129
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b).
Definition: ArithOps.cpp:836
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1248
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:170
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1810
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1263
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1221
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:345
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1291
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1989
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1962
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:136
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:250
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:273
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:99
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:343
IntegerType getI1Type()
Definition: Builders.cpp:85
IndexType getIndexType()
Definition: Builders.cpp:83
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:212
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This provides public APIs that all operations should have.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:70
bool isIndex() const
Definition: Types.cpp:59
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:122
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:278
static bool classof(Operation *op)
Definition: ArithOps.cpp:284
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:290
static bool classof(Operation *op)
Definition: ArithOps.cpp:296
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:257
static bool classof(Operation *op)
Definition: ArithOps.cpp:272
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2547
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1764
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2480
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2599
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2589
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
auto m_Val(Value v)
Definition: Matchers.h:534
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:466
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:496
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:437
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:473
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:430
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:457
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:422
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float one.
Definition: Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition: Matchers.h:450
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2257
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes