MLIR  18.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 
25 #include "llvm/ADT/APInt.h"
26 #include "llvm/ADT/APSInt.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 
32 using namespace mlir;
33 using namespace mlir::arith;
34 
35 //===----------------------------------------------------------------------===//
36 // Pattern helpers
37 //===----------------------------------------------------------------------===//
38 
39 static IntegerAttr
41  Attribute rhs,
42  function_ref<APInt(const APInt &, const APInt &)> binFn) {
43  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
44  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
45  APInt value = binFn(lhsVal, rhsVal);
46  return IntegerAttr::get(res.getType(), value);
47 }
48 
49 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
50  Attribute lhs, Attribute rhs) {
51  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
52 }
53 
54 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
55  Attribute lhs, Attribute rhs) {
56  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
57 }
58 
59 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
60  Attribute lhs, Attribute rhs) {
61  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
62 }
63 
64 /// Invert an integer comparison predicate.
65 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
66  switch (pred) {
67  case arith::CmpIPredicate::eq:
68  return arith::CmpIPredicate::ne;
69  case arith::CmpIPredicate::ne:
70  return arith::CmpIPredicate::eq;
71  case arith::CmpIPredicate::slt:
72  return arith::CmpIPredicate::sge;
73  case arith::CmpIPredicate::sle:
74  return arith::CmpIPredicate::sgt;
75  case arith::CmpIPredicate::sgt:
76  return arith::CmpIPredicate::sle;
77  case arith::CmpIPredicate::sge:
78  return arith::CmpIPredicate::slt;
79  case arith::CmpIPredicate::ult:
80  return arith::CmpIPredicate::uge;
81  case arith::CmpIPredicate::ule:
82  return arith::CmpIPredicate::ugt;
83  case arith::CmpIPredicate::ugt:
84  return arith::CmpIPredicate::ule;
85  case arith::CmpIPredicate::uge:
86  return arith::CmpIPredicate::ult;
87  }
88  llvm_unreachable("unknown cmpi predicate kind");
89 }
90 
91 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
92  return arith::CmpIPredicateAttr::get(pred.getContext(),
93  invertPredicate(pred.getValue()));
94 }
95 
96 static int64_t getScalarOrElementWidth(Type type) {
97  Type elemTy = getElementTypeOrSelf(type);
98  if (elemTy.isIntOrFloat())
99  return elemTy.getIntOrFloatBitWidth();
100 
101  return -1;
102 }
103 
104 static int64_t getScalarOrElementWidth(Value value) {
105  return getScalarOrElementWidth(value.getType());
106 }
107 
109  APInt value;
110  if (matchPattern(attr, m_ConstantInt(&value)))
111  return value;
112 
113  return failure();
114 }
115 
116 static Attribute getBoolAttribute(Type type, bool value) {
117  auto boolAttr = BoolAttr::get(type.getContext(), value);
118  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
119  if (!shapedType)
120  return boolAttr;
121  return DenseElementsAttr::get(shapedType, boolAttr);
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // TableGen'd canonicalization patterns
126 //===----------------------------------------------------------------------===//
127 
128 namespace {
129 #include "ArithCanonicalization.inc"
130 } // namespace
131 
132 //===----------------------------------------------------------------------===//
133 // Common helpers
134 //===----------------------------------------------------------------------===//
135 
136 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
137 static Type getI1SameShape(Type type) {
138  auto i1Type = IntegerType::get(type.getContext(), 1);
139  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
140  return shapedType.cloneWith(std::nullopt, i1Type);
141  if (llvm::isa<UnrankedTensorType>(type))
142  return UnrankedTensorType::get(i1Type);
143  return i1Type;
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // ConstantOp
148 //===----------------------------------------------------------------------===//
149 
150 void arith::ConstantOp::getAsmResultNames(
151  function_ref<void(Value, StringRef)> setNameFn) {
152  auto type = getType();
153  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
154  auto intType = llvm::dyn_cast<IntegerType>(type);
155 
156  // Sugar i1 constants with 'true' and 'false'.
157  if (intType && intType.getWidth() == 1)
158  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
159 
160  // Otherwise, build a complex name with the value and type.
161  SmallString<32> specialNameBuffer;
162  llvm::raw_svector_ostream specialName(specialNameBuffer);
163  specialName << 'c' << intCst.getValue();
164  if (intType)
165  specialName << '_' << type;
166  setNameFn(getResult(), specialName.str());
167  } else {
168  setNameFn(getResult(), "cst");
169  }
170 }
171 
172 /// TODO: disallow arith.constant to return anything other than signless integer
173 /// or float like.
175  auto type = getType();
176  // The value's type must match the return type.
177  if (getValue().getType() != type) {
178  return emitOpError() << "value type " << getValue().getType()
179  << " must match return type: " << type;
180  }
181  // Integer values must be signless.
182  if (llvm::isa<IntegerType>(type) &&
183  !llvm::cast<IntegerType>(type).isSignless())
184  return emitOpError("integer return type must be signless");
185  // Any float or elements attribute are acceptable.
186  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
187  return emitOpError(
188  "value must be an integer, float, or elements attribute");
189  }
190  return success();
191 }
192 
193 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
194  // The value's type must be the same as the provided type.
195  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
196  if (!typedAttr || typedAttr.getType() != type)
197  return false;
198  // Integer values must be signless.
199  if (llvm::isa<IntegerType>(type) &&
200  !llvm::cast<IntegerType>(type).isSignless())
201  return false;
202  // Integer, float, and element attributes are buildable.
203  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
204 }
205 
206 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
207  Type type, Location loc) {
208  if (isBuildableWith(value, type))
209  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
210  return nullptr;
211 }
212 
213 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
214 
216  int64_t value, unsigned width) {
217  auto type = builder.getIntegerType(width);
218  arith::ConstantOp::build(builder, result, type,
219  builder.getIntegerAttr(type, value));
220 }
221 
223  int64_t value, Type type) {
224  assert(type.isSignlessInteger() &&
225  "ConstantIntOp can only have signless integer type values");
226  arith::ConstantOp::build(builder, result, type,
227  builder.getIntegerAttr(type, value));
228 }
229 
231  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
232  return constOp.getType().isSignlessInteger();
233  return false;
234 }
235 
237  const APFloat &value, FloatType type) {
238  arith::ConstantOp::build(builder, result, type,
239  builder.getFloatAttr(type, value));
240 }
241 
243  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
244  return llvm::isa<FloatType>(constOp.getType());
245  return false;
246 }
247 
249  int64_t value) {
250  arith::ConstantOp::build(builder, result, builder.getIndexType(),
251  builder.getIndexAttr(value));
252 }
253 
255  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
256  return constOp.getType().isIndex();
257  return false;
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // AddIOp
262 //===----------------------------------------------------------------------===//
263 
264 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
265  // addi(x, 0) -> x
266  if (matchPattern(adaptor.getRhs(), m_Zero()))
267  return getLhs();
268 
269  // addi(subi(a, b), b) -> a
270  if (auto sub = getLhs().getDefiningOp<SubIOp>())
271  if (getRhs() == sub.getRhs())
272  return sub.getLhs();
273 
274  // addi(b, subi(a, b)) -> a
275  if (auto sub = getRhs().getDefiningOp<SubIOp>())
276  if (getLhs() == sub.getRhs())
277  return sub.getLhs();
278 
279  return constFoldBinaryOp<IntegerAttr>(
280  adaptor.getOperands(),
281  [](APInt a, const APInt &b) { return std::move(a) + b; });
282 }
283 
284 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
285  MLIRContext *context) {
286  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
287  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // AddUIExtendedOp
292 //===----------------------------------------------------------------------===//
293 
294 std::optional<SmallVector<int64_t, 4>>
295 arith::AddUIExtendedOp::getShapeForUnroll() {
296  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
297  return llvm::to_vector<4>(vt.getShape());
298  return std::nullopt;
299 }
300 
301 // Returns the overflow bit, assuming that `sum` is the result of unsigned
302 // addition of `operand` and another number.
303 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
304  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
305 }
306 
308 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
310  Type overflowTy = getOverflow().getType();
311  // addui_extended(x, 0) -> x, false
312  if (matchPattern(getRhs(), m_Zero())) {
313  Builder builder(getContext());
314  auto falseValue = builder.getZeroAttr(overflowTy);
315 
316  results.push_back(getLhs());
317  results.push_back(falseValue);
318  return success();
319  }
320 
321  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
322  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
323  // operands. If that succeeds, calculate the overflow bit based on the sum
324  // and the first (constant) operand, `lhs`.
325  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
326  adaptor.getOperands(),
327  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
328  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
329  ArrayRef({sumAttr, adaptor.getLhs()}),
330  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
332  if (!overflowAttr)
333  return failure();
334 
335  results.push_back(sumAttr);
336  results.push_back(overflowAttr);
337  return success();
338  }
339 
340  return failure();
341 }
342 
343 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
344  RewritePatternSet &patterns, MLIRContext *context) {
345  patterns.add<AddUIExtendedToAddI>(context);
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // SubIOp
350 //===----------------------------------------------------------------------===//
351 
352 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
353  // subi(x,x) -> 0
354  if (getOperand(0) == getOperand(1))
355  return Builder(getContext()).getZeroAttr(getType());
356  // subi(x,0) -> x
357  if (matchPattern(adaptor.getRhs(), m_Zero()))
358  return getLhs();
359 
360  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
361  // subi(addi(a, b), b) -> a
362  if (getRhs() == add.getRhs())
363  return add.getLhs();
364  // subi(addi(a, b), a) -> b
365  if (getRhs() == add.getLhs())
366  return add.getRhs();
367  }
368 
369  return constFoldBinaryOp<IntegerAttr>(
370  adaptor.getOperands(),
371  [](APInt a, const APInt &b) { return std::move(a) - b; });
372 }
373 
374 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
375  MLIRContext *context) {
376  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
377  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
378  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // MulIOp
383 //===----------------------------------------------------------------------===//
384 
385 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
386  // muli(x, 0) -> 0
387  if (matchPattern(adaptor.getRhs(), m_Zero()))
388  return getRhs();
389  // muli(x, 1) -> x
390  if (matchPattern(adaptor.getRhs(), m_One()))
391  return getLhs();
392  // TODO: Handle the overflow case.
393 
394  // default folder
395  return constFoldBinaryOp<IntegerAttr>(
396  adaptor.getOperands(),
397  [](const APInt &a, const APInt &b) { return a * b; });
398 }
399 
400 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
401  MLIRContext *context) {
402  patterns.add<MulIMulIConstant>(context);
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // MulSIExtendedOp
407 //===----------------------------------------------------------------------===//
408 
409 std::optional<SmallVector<int64_t, 4>>
410 arith::MulSIExtendedOp::getShapeForUnroll() {
411  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
412  return llvm::to_vector<4>(vt.getShape());
413  return std::nullopt;
414 }
415 
417 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
419  // mulsi_extended(x, 0) -> 0, 0
420  if (matchPattern(adaptor.getRhs(), m_Zero())) {
421  Attribute zero = adaptor.getRhs();
422  results.push_back(zero);
423  results.push_back(zero);
424  return success();
425  }
426 
427  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
428  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
429  adaptor.getOperands(),
430  [](const APInt &a, const APInt &b) { return a * b; })) {
431  // Invoke the constant fold helper again to calculate the 'high' result.
432  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
433  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
434  unsigned bitWidth = a.getBitWidth();
435  APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
436  return fullProduct.extractBits(bitWidth, bitWidth);
437  });
438  assert(highAttr && "Unexpected constant-folding failure");
439 
440  results.push_back(lowAttr);
441  results.push_back(highAttr);
442  return success();
443  }
444 
445  return failure();
446 }
447 
448 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
449  RewritePatternSet &patterns, MLIRContext *context) {
450  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // MulUIExtendedOp
455 //===----------------------------------------------------------------------===//
456 
457 std::optional<SmallVector<int64_t, 4>>
458 arith::MulUIExtendedOp::getShapeForUnroll() {
459  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
460  return llvm::to_vector<4>(vt.getShape());
461  return std::nullopt;
462 }
463 
465 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
467  // mului_extended(x, 0) -> 0, 0
468  if (matchPattern(adaptor.getRhs(), m_Zero())) {
469  Attribute zero = adaptor.getRhs();
470  results.push_back(zero);
471  results.push_back(zero);
472  return success();
473  }
474 
475  // mului_extended(x, 1) -> x, 0
476  if (matchPattern(adaptor.getRhs(), m_One())) {
477  Builder builder(getContext());
478  Attribute zero = builder.getZeroAttr(getLhs().getType());
479  results.push_back(getLhs());
480  results.push_back(zero);
481  return success();
482  }
483 
484  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
485  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
486  adaptor.getOperands(),
487  [](const APInt &a, const APInt &b) { return a * b; })) {
488  // Invoke the constant fold helper again to calculate the 'high' result.
489  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
490  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
491  unsigned bitWidth = a.getBitWidth();
492  APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
493  return fullProduct.extractBits(bitWidth, bitWidth);
494  });
495  assert(highAttr && "Unexpected constant-folding failure");
496 
497  results.push_back(lowAttr);
498  results.push_back(highAttr);
499  return success();
500  }
501 
502  return failure();
503 }
504 
505 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
506  RewritePatternSet &patterns, MLIRContext *context) {
507  patterns.add<MulUIExtendedToMulI>(context);
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // DivUIOp
512 //===----------------------------------------------------------------------===//
513 
514 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
515  // divui (x, 1) -> x.
516  if (matchPattern(adaptor.getRhs(), m_One()))
517  return getLhs();
518 
519  // Don't fold if it would require a division by zero.
520  bool div0 = false;
521  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
522  [&](APInt a, const APInt &b) {
523  if (div0 || !b) {
524  div0 = true;
525  return a;
526  }
527  return a.udiv(b);
528  });
529 
530  return div0 ? Attribute() : result;
531 }
532 
533 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
534  // X / 0 => UB
535  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // DivSIOp
541 //===----------------------------------------------------------------------===//
542 
543 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
544  // divsi (x, 1) -> x.
545  if (matchPattern(adaptor.getRhs(), m_One()))
546  return getLhs();
547 
548  // Don't fold if it would overflow or if it requires a division by zero.
549  bool overflowOrDiv0 = false;
550  auto result = constFoldBinaryOp<IntegerAttr>(
551  adaptor.getOperands(), [&](APInt a, const APInt &b) {
552  if (overflowOrDiv0 || !b) {
553  overflowOrDiv0 = true;
554  return a;
555  }
556  return a.sdiv_ov(b, overflowOrDiv0);
557  });
558 
559  return overflowOrDiv0 ? Attribute() : result;
560 }
561 
562 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
563  bool mayHaveUB = true;
564 
565  APInt constRHS;
566  // X / 0 => UB
567  // INT_MIN / -1 => UB
568  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
569  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
570 
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // Ceil and floor division folding helpers
576 //===----------------------------------------------------------------------===//
577 
578 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
579  bool &overflow) {
580  // Returns (a-1)/b + 1
581  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
582  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
583  return val.sadd_ov(one, overflow);
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // CeilDivUIOp
588 //===----------------------------------------------------------------------===//
589 
590 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
591  // ceildivui (x, 1) -> x.
592  if (matchPattern(adaptor.getRhs(), m_One()))
593  return getLhs();
594 
595  bool overflowOrDiv0 = false;
596  auto result = constFoldBinaryOp<IntegerAttr>(
597  adaptor.getOperands(), [&](APInt a, const APInt &b) {
598  if (overflowOrDiv0 || !b) {
599  overflowOrDiv0 = true;
600  return a;
601  }
602  APInt quotient = a.udiv(b);
603  if (!a.urem(b))
604  return quotient;
605  APInt one(a.getBitWidth(), 1, true);
606  return quotient.uadd_ov(one, overflowOrDiv0);
607  });
608 
609  return overflowOrDiv0 ? Attribute() : result;
610 }
611 
612 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
613  // X / 0 => UB
614  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // CeilDivSIOp
620 //===----------------------------------------------------------------------===//
621 
622 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
623  // ceildivsi (x, 1) -> x.
624  if (matchPattern(adaptor.getRhs(), m_One()))
625  return getLhs();
626 
627  // Don't fold if it would overflow or if it requires a division by zero.
628  bool overflowOrDiv0 = false;
629  auto result = constFoldBinaryOp<IntegerAttr>(
630  adaptor.getOperands(), [&](APInt a, const APInt &b) {
631  if (overflowOrDiv0 || !b) {
632  overflowOrDiv0 = true;
633  return a;
634  }
635  if (!a)
636  return a;
637  // After this point we know that neither a or b are zero.
638  unsigned bits = a.getBitWidth();
639  APInt zero = APInt::getZero(bits);
640  bool aGtZero = a.sgt(zero);
641  bool bGtZero = b.sgt(zero);
642  if (aGtZero && bGtZero) {
643  // Both positive, return ceil(a, b).
644  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
645  }
646  if (!aGtZero && !bGtZero) {
647  // Both negative, return ceil(-a, -b).
648  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
649  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
650  return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
651  }
652  if (!aGtZero && bGtZero) {
653  // A is negative, b is positive, return - ( -a / b).
654  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
655  APInt div = posA.sdiv_ov(b, overflowOrDiv0);
656  return zero.ssub_ov(div, overflowOrDiv0);
657  }
658  // A is positive, b is negative, return - (a / -b).
659  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
660  APInt div = a.sdiv_ov(posB, overflowOrDiv0);
661  return zero.ssub_ov(div, overflowOrDiv0);
662  });
663 
664  return overflowOrDiv0 ? Attribute() : result;
665 }
666 
667 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
668  bool mayHaveUB = true;
669 
670  APInt constRHS;
671  // X / 0 => UB
672  // INT_MIN / -1 => UB
673  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
674  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
675 
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // FloorDivSIOp
681 //===----------------------------------------------------------------------===//
682 
683 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
684  // floordivsi (x, 1) -> x.
685  if (matchPattern(adaptor.getRhs(), m_One()))
686  return getLhs();
687 
688  // Don't fold if it would overflow or if it requires a division by zero.
689  bool overflowOrDiv0 = false;
690  auto result = constFoldBinaryOp<IntegerAttr>(
691  adaptor.getOperands(), [&](APInt a, const APInt &b) {
692  if (overflowOrDiv0 || !b) {
693  overflowOrDiv0 = true;
694  return a;
695  }
696  if (!a)
697  return a;
698  // After this point we know that neither a or b are zero.
699  unsigned bits = a.getBitWidth();
700  APInt zero = APInt::getZero(bits);
701  bool aGtZero = a.sgt(zero);
702  bool bGtZero = b.sgt(zero);
703  if (aGtZero && bGtZero) {
704  // Both positive, return a / b.
705  return a.sdiv_ov(b, overflowOrDiv0);
706  }
707  if (!aGtZero && !bGtZero) {
708  // Both negative, return -a / -b.
709  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
710  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
711  return posA.sdiv_ov(posB, overflowOrDiv0);
712  }
713  if (!aGtZero && bGtZero) {
714  // A is negative, b is positive, return - ceil(-a, b).
715  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
716  APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
717  return zero.ssub_ov(ceil, overflowOrDiv0);
718  }
719  // A is positive, b is negative, return - ceil(a, -b).
720  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
721  APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
722  return zero.ssub_ov(ceil, overflowOrDiv0);
723  });
724 
725  return overflowOrDiv0 ? Attribute() : result;
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // RemUIOp
730 //===----------------------------------------------------------------------===//
731 
732 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
733  // remui (x, 1) -> 0.
734  if (matchPattern(adaptor.getRhs(), m_One()))
735  return Builder(getContext()).getZeroAttr(getType());
736 
737  // Don't fold if it would require a division by zero.
738  bool div0 = false;
739  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
740  [&](APInt a, const APInt &b) {
741  if (div0 || b.isZero()) {
742  div0 = true;
743  return a;
744  }
745  return a.urem(b);
746  });
747 
748  return div0 ? Attribute() : result;
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // RemSIOp
753 //===----------------------------------------------------------------------===//
754 
755 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
756  // remsi (x, 1) -> 0.
757  if (matchPattern(adaptor.getRhs(), m_One()))
758  return Builder(getContext()).getZeroAttr(getType());
759 
760  // Don't fold if it would require a division by zero.
761  bool div0 = false;
762  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
763  [&](APInt a, const APInt &b) {
764  if (div0 || b.isZero()) {
765  div0 = true;
766  return a;
767  }
768  return a.srem(b);
769  });
770 
771  return div0 ? Attribute() : result;
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // AndIOp
776 //===----------------------------------------------------------------------===//
777 
778 /// Fold `and(a, and(a, b))` to `and(a, b)`
779 static Value foldAndIofAndI(arith::AndIOp op) {
780  for (bool reversePrev : {false, true}) {
781  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
782  .getDefiningOp<arith::AndIOp>();
783  if (!prev)
784  continue;
785 
786  Value other = (reversePrev ? op.getLhs() : op.getRhs());
787  if (other != prev.getLhs() && other != prev.getRhs())
788  continue;
789 
790  return prev.getResult();
791  }
792  return {};
793 }
794 
795 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
796  /// and(x, 0) -> 0
797  if (matchPattern(adaptor.getRhs(), m_Zero()))
798  return getRhs();
799  /// and(x, allOnes) -> x
800  APInt intValue;
801  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
802  intValue.isAllOnes())
803  return getLhs();
804  /// and(x, not(x)) -> 0
805  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
806  m_ConstantInt(&intValue))) &&
807  intValue.isAllOnes())
808  return Builder(getContext()).getZeroAttr(getType());
809  /// and(not(x), x) -> 0
810  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
811  m_ConstantInt(&intValue))) &&
812  intValue.isAllOnes())
813  return Builder(getContext()).getZeroAttr(getType());
814 
815  /// and(a, and(a, b)) -> and(a, b)
816  if (Value result = foldAndIofAndI(*this))
817  return result;
818 
819  return constFoldBinaryOp<IntegerAttr>(
820  adaptor.getOperands(),
821  [](APInt a, const APInt &b) { return std::move(a) & b; });
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // OrIOp
826 //===----------------------------------------------------------------------===//
827 
828 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
829  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
830  /// or(x, 0) -> x
831  if (rhsVal.isZero())
832  return getLhs();
833  /// or(x, <all ones>) -> <all ones>
834  if (rhsVal.isAllOnes())
835  return adaptor.getRhs();
836  }
837 
838  APInt intValue;
839  /// or(x, xor(x, 1)) -> 1
840  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
841  m_ConstantInt(&intValue))) &&
842  intValue.isAllOnes())
843  return getRhs().getDefiningOp<XOrIOp>().getRhs();
844  /// or(xor(x, 1), x) -> 1
845  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
846  m_ConstantInt(&intValue))) &&
847  intValue.isAllOnes())
848  return getLhs().getDefiningOp<XOrIOp>().getRhs();
849 
850  return constFoldBinaryOp<IntegerAttr>(
851  adaptor.getOperands(),
852  [](APInt a, const APInt &b) { return std::move(a) | b; });
853 }
854 
855 //===----------------------------------------------------------------------===//
856 // XOrIOp
857 //===----------------------------------------------------------------------===//
858 
859 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
860  /// xor(x, 0) -> x
861  if (matchPattern(adaptor.getRhs(), m_Zero()))
862  return getLhs();
863  /// xor(x, x) -> 0
864  if (getLhs() == getRhs())
865  return Builder(getContext()).getZeroAttr(getType());
866  /// xor(xor(x, a), a) -> x
867  /// xor(xor(a, x), a) -> x
868  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
869  if (prev.getRhs() == getRhs())
870  return prev.getLhs();
871  if (prev.getLhs() == getRhs())
872  return prev.getRhs();
873  }
874  /// xor(a, xor(x, a)) -> x
875  /// xor(a, xor(a, x)) -> x
876  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
877  if (prev.getRhs() == getLhs())
878  return prev.getLhs();
879  if (prev.getLhs() == getLhs())
880  return prev.getRhs();
881  }
882 
883  return constFoldBinaryOp<IntegerAttr>(
884  adaptor.getOperands(),
885  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
886 }
887 
888 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
889  MLIRContext *context) {
890  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // NegFOp
895 //===----------------------------------------------------------------------===//
896 
897 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
898  /// negf(negf(x)) -> x
899  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
900  return op.getOperand();
901  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
902  [](const APFloat &a) { return -a; });
903 }
904 
905 //===----------------------------------------------------------------------===//
906 // AddFOp
907 //===----------------------------------------------------------------------===//
908 
909 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
910  // addf(x, -0) -> x
911  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
912  return getLhs();
913 
914  return constFoldBinaryOp<FloatAttr>(
915  adaptor.getOperands(),
916  [](const APFloat &a, const APFloat &b) { return a + b; });
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // SubFOp
921 //===----------------------------------------------------------------------===//
922 
923 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
924  // subf(x, +0) -> x
925  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
926  return getLhs();
927 
928  return constFoldBinaryOp<FloatAttr>(
929  adaptor.getOperands(),
930  [](const APFloat &a, const APFloat &b) { return a - b; });
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // MaximumFOp
935 //===----------------------------------------------------------------------===//
936 
937 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
938  // maximumf(x,x) -> x
939  if (getLhs() == getRhs())
940  return getRhs();
941 
942  // maximumf(x, -inf) -> x
943  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
944  return getLhs();
945 
946  return constFoldBinaryOp<FloatAttr>(
947  adaptor.getOperands(),
948  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // MaxNumFOp
953 //===----------------------------------------------------------------------===//
954 
955 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
956  // maxnumf(x,x) -> x
957  if (getLhs() == getRhs())
958  return getRhs();
959 
960  // maxnumf(x, -inf) -> x
961  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
962  return getLhs();
963 
964  return constFoldBinaryOp<FloatAttr>(
965  adaptor.getOperands(),
966  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
967 }
968 
969 
970 //===----------------------------------------------------------------------===//
971 // MaxSIOp
972 //===----------------------------------------------------------------------===//
973 
974 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
975  // maxsi(x,x) -> x
976  if (getLhs() == getRhs())
977  return getRhs();
978 
979  if (APInt intValue;
980  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
981  // maxsi(x,MAX_INT) -> MAX_INT
982  if (intValue.isMaxSignedValue())
983  return getRhs();
984  // maxsi(x, MIN_INT) -> x
985  if (intValue.isMinSignedValue())
986  return getLhs();
987  }
988 
989  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
990  [](const APInt &a, const APInt &b) {
991  return llvm::APIntOps::smax(a, b);
992  });
993 }
994 
995 //===----------------------------------------------------------------------===//
996 // MaxUIOp
997 //===----------------------------------------------------------------------===//
998 
999 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1000  // maxui(x,x) -> x
1001  if (getLhs() == getRhs())
1002  return getRhs();
1003 
1004  if (APInt intValue;
1005  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1006  // maxui(x,MAX_INT) -> MAX_INT
1007  if (intValue.isMaxValue())
1008  return getRhs();
1009  // maxui(x, MIN_INT) -> x
1010  if (intValue.isMinValue())
1011  return getLhs();
1012  }
1013 
1014  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1015  [](const APInt &a, const APInt &b) {
1016  return llvm::APIntOps::umax(a, b);
1017  });
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // MinimumFOp
1022 //===----------------------------------------------------------------------===//
1023 
1024 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1025  // minimumf(x,x) -> x
1026  if (getLhs() == getRhs())
1027  return getRhs();
1028 
1029  // minimumf(x, +inf) -> x
1030  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1031  return getLhs();
1032 
1033  return constFoldBinaryOp<FloatAttr>(
1034  adaptor.getOperands(),
1035  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // MinNumFOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1043  // minnumf(x,x) -> x
1044  if (getLhs() == getRhs())
1045  return getRhs();
1046 
1047  // minnumf(x, +inf) -> x
1048  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1049  return getLhs();
1050 
1051  return constFoldBinaryOp<FloatAttr>(
1052  adaptor.getOperands(),
1053  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1054 }
1055 
1056 //===----------------------------------------------------------------------===//
1057 // MinSIOp
1058 //===----------------------------------------------------------------------===//
1059 
1060 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1061  // minsi(x,x) -> x
1062  if (getLhs() == getRhs())
1063  return getRhs();
1064 
1065  if (APInt intValue;
1066  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1067  // minsi(x,MIN_INT) -> MIN_INT
1068  if (intValue.isMinSignedValue())
1069  return getRhs();
1070  // minsi(x, MAX_INT) -> x
1071  if (intValue.isMaxSignedValue())
1072  return getLhs();
1073  }
1074 
1075  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1076  [](const APInt &a, const APInt &b) {
1077  return llvm::APIntOps::smin(a, b);
1078  });
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // MinUIOp
1083 //===----------------------------------------------------------------------===//
1084 
1085 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1086  // minui(x,x) -> x
1087  if (getLhs() == getRhs())
1088  return getRhs();
1089 
1090  if (APInt intValue;
1091  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1092  // minui(x,MIN_INT) -> MIN_INT
1093  if (intValue.isMinValue())
1094  return getRhs();
1095  // minui(x, MAX_INT) -> x
1096  if (intValue.isMaxValue())
1097  return getLhs();
1098  }
1099 
1100  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1101  [](const APInt &a, const APInt &b) {
1102  return llvm::APIntOps::umin(a, b);
1103  });
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // MulFOp
1108 //===----------------------------------------------------------------------===//
1109 
1110 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1111  // mulf(x, 1) -> x
1112  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1113  return getLhs();
1114 
1115  return constFoldBinaryOp<FloatAttr>(
1116  adaptor.getOperands(),
1117  [](const APFloat &a, const APFloat &b) { return a * b; });
1118 }
1119 
1120 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1121  MLIRContext *context) {
1122  patterns.add<MulFOfNegF>(context);
1123 }
1124 
1125 //===----------------------------------------------------------------------===//
1126 // DivFOp
1127 //===----------------------------------------------------------------------===//
1128 
1129 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1130  // divf(x, 1) -> x
1131  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1132  return getLhs();
1133 
1134  return constFoldBinaryOp<FloatAttr>(
1135  adaptor.getOperands(),
1136  [](const APFloat &a, const APFloat &b) { return a / b; });
1137 }
1138 
1139 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1140  MLIRContext *context) {
1141  patterns.add<DivFOfNegF>(context);
1142 }
1143 
1144 //===----------------------------------------------------------------------===//
1145 // RemFOp
1146 //===----------------------------------------------------------------------===//
1147 
1148 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1149  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1150  [](const APFloat &a, const APFloat &b) {
1151  APFloat result(a);
1152  (void)result.remainder(b);
1153  return result;
1154  });
1155 }
1156 
1157 //===----------------------------------------------------------------------===//
1158 // Utility functions for verifying cast ops
1159 //===----------------------------------------------------------------------===//
1160 
/// Carries a pack of types in a single type, `std::tuple<Types...> *`.
/// NOTE(review): presumably passed as a (null) tag argument so that functions
/// such as getUnderlyingType can deduce several independent type packs —
/// confirm at the call sites.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1163 
/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
///
/// `ShapedTypes` lists the shaped container types that are accepted.
/// `ElementTypes` lists the accepted scalar/element types. A shaped type that
/// is not one of `ShapedTypes`, or whose element type is not one of
/// `ElementTypes`, yields a null Type.
template <typename... ShapedTypes, typename... ElementTypes>
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  // The element type for shaped types; `type` itself for scalars.
  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}
1179 
/// Get allowed underlying types for vectors and tensors.
///
/// Returns the scalar type, or the element type of a vector/tensor, when it is
/// one of `ElementTypes`, and a null Type otherwise.
/// NOTE(review): presumably forwards to getUnderlyingType with vector and
/// tensor as the allowed shaped types — confirm against the full body.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
}
1186 
/// Get allowed underlying types for vectors, tensors, and memrefs.
///
/// Like getTypeIfLike, but memref element types are accepted as well. Delegates
/// to getUnderlyingType; returns a null Type when the type is not allowed.
template <typename... ElementTypes>
  return getUnderlyingType(type,
}
1194 
1195 /// Return false if both types are ranked tensor with mismatching encoding.
1196 static bool hasSameEncoding(Type typeA, Type typeB) {
1197  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1198  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1199  if (!rankedTensorA || !rankedTensorB)
1200  return true;
1201  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1202 }
1203 
1204 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1205  if (inputs.size() != 1 || outputs.size() != 1)
1206  return false;
1207  if (!hasSameEncoding(inputs.front(), outputs.front()))
1208  return false;
1209  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // Verifiers for integer and floating point extension/truncation ops
1214 //===----------------------------------------------------------------------===//
1215 
// Extend ops can only extend to a wider type.
//
// Unwraps shaped types to their element types and casts both to `ValType`
// (used with IntegerType and FloatType in this file). Emits an error on `op`
// unless the result element width is strictly greater than the operand's.
template <typename ValType, typename Op>
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  // Equal widths are rejected as well: an extension must strictly widen.
  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}
1229 
// Truncate ops can only truncate to a shorter type.
//
// Unwraps shaped types to their element types and casts both to `ValType`
// (used with IntegerType and FloatType in this file). Emits an error on `op`
// unless the result element width is strictly smaller than the operand's.
template <typename ValType, typename Op>
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  // Equal widths are rejected as well: a truncation must strictly narrow.
  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be shorter than operand type " << srcType;

  return success();
}
1243 
1244 /// Validate a cast that changes the width of a type.
1245 template <template <typename> class WidthComparator, typename... ElementTypes>
1246 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1247  if (!areValidCastInputsAndOutputs(inputs, outputs))
1248  return false;
1249 
1250  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1251  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1252  if (!srcType || !dstType)
1253  return false;
1254 
1255  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1256  srcType.getIntOrFloatBitWidth());
1257 }
1258 
1259 //===----------------------------------------------------------------------===//
1260 // ExtUIOp
1261 //===----------------------------------------------------------------------===//
1262 
1263 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1264  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1265  getInMutable().assign(lhs.getIn());
1266  return getResult();
1267  }
1268 
1269  Type resType = getElementTypeOrSelf(getType());
1270  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1271  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1272  adaptor.getOperands(), getType(),
1273  [bitWidth](const APInt &a, bool &castStatus) {
1274  return a.zext(bitWidth);
1275  });
1276 }
1277 
1278 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1279  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1280 }
1281 
  // The result integer element type must be strictly wider than the operand's.
  return verifyExtOp<IntegerType>(*this);
}
1285 
1286 //===----------------------------------------------------------------------===//
1287 // ExtSIOp
1288 //===----------------------------------------------------------------------===//
1289 
1290 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1291  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1292  getInMutable().assign(lhs.getIn());
1293  return getResult();
1294  }
1295 
1296  Type resType = getElementTypeOrSelf(getType());
1297  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1298  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1299  adaptor.getOperands(), getType(),
1300  [bitWidth](const APInt &a, bool &castStatus) {
1301  return a.sext(bitWidth);
1302  });
1303 }
1304 
1305 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1306  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1307 }
1308 
1309 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1310  MLIRContext *context) {
1311  patterns.add<ExtSIOfExtUI>(context);
1312 }
1313 
  // The result integer element type must be strictly wider than the operand's.
  return verifyExtOp<IntegerType>(*this);
}
1317 
1318 //===----------------------------------------------------------------------===//
1319 // ExtFOp
1320 //===----------------------------------------------------------------------===//
1321 
1322 /// Always fold extension of FP constants.
1323 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1324  auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
1325  if (!constOperand)
1326  return {};
1327 
1328  // Convert to target type via 'double'.
1329  return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
1330 }
1331 
1332 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1333  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1334 }
1335 
1336 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1337 
1338 //===----------------------------------------------------------------------===//
1339 // TruncIOp
1340 //===----------------------------------------------------------------------===//
1341 
1342 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1343  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1344  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1345  Value src = getOperand().getDefiningOp()->getOperand(0);
1346  Type srcType = getElementTypeOrSelf(src.getType());
1347  Type dstType = getElementTypeOrSelf(getType());
1348  // trunci(zexti(a)) -> trunci(a)
1349  // trunci(sexti(a)) -> trunci(a)
1350  if (llvm::cast<IntegerType>(srcType).getWidth() >
1351  llvm::cast<IntegerType>(dstType).getWidth()) {
1352  setOperand(src);
1353  return getResult();
1354  }
1355  // trunci(zexti(a)) -> a
1356  // trunci(sexti(a)) -> a
1357  return src;
1358  }
1359 
1360  // trunci(trunci(a)) -> trunci(a))
1361  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1362  setOperand(getOperand().getDefiningOp()->getOperand(0));
1363  return getResult();
1364  }
1365 
1366  Type resType = getElementTypeOrSelf(getType());
1367  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1368  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1369  adaptor.getOperands(), getType(),
1370  [bitWidth](const APInt &a, bool &castStatus) {
1371  return a.trunc(bitWidth);
1372  });
1373 }
1374 
1375 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1376  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1377 }
1378 
1379 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1380  MLIRContext *context) {
1381  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1382  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1383  context);
1384 }
1385 
  // The result integer element type must be strictly narrower than the
  // operand's.
  return verifyTruncateOp<IntegerType>(*this);
}
1389 
1390 //===----------------------------------------------------------------------===//
1391 // TruncFOp
1392 //===----------------------------------------------------------------------===//
1393 
1394 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
1395 /// can be represented without precision loss or rounding.
1396 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1397  auto constOperand = adaptor.getIn();
1398  if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
1399  return {};
1400 
1401  // Convert to target type via 'double'.
1402  double sourceValue =
1403  llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
1404  auto targetAttr = FloatAttr::get(getType(), sourceValue);
1405 
1406  // Propagate if constant's value does not change after truncation.
1407  if (sourceValue == targetAttr.getValue().convertToDouble())
1408  return targetAttr;
1409 
1410  return {};
1411 }
1412 
1413 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1414  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1415 }
1416 
  // The result float element type must be strictly narrower than the
  // operand's.
  return verifyTruncateOp<FloatType>(*this);
}
1420 
1421 //===----------------------------------------------------------------------===//
1422 // AndIOp
1423 //===----------------------------------------------------------------------===//
1424 
1425 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1426  MLIRContext *context) {
1427  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1428 }
1429 
1430 //===----------------------------------------------------------------------===//
1431 // OrIOp
1432 //===----------------------------------------------------------------------===//
1433 
1434 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1435  MLIRContext *context) {
1436  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1437 }
1438 
1439 //===----------------------------------------------------------------------===//
1440 // Verifiers for casts between integers and floats.
1441 //===----------------------------------------------------------------------===//
1442 
1443 template <typename From, typename To>
1444 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1445  if (!areValidCastInputsAndOutputs(inputs, outputs))
1446  return false;
1447 
1448  auto srcType = getTypeIfLike<From>(inputs.front());
1449  auto dstType = getTypeIfLike<To>(outputs.back());
1450 
1451  return srcType && dstType;
1452 }
1453 
1454 //===----------------------------------------------------------------------===//
1455 // UIToFPOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1459  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1460 }
1461 
1462 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1463  Type resEleType = getElementTypeOrSelf(getType());
1464  return constFoldCastOp<IntegerAttr, FloatAttr>(
1465  adaptor.getOperands(), getType(),
1466  [&resEleType](const APInt &a, bool &castStatus) {
1467  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1468  APFloat apf(floatTy.getFloatSemantics(),
1469  APInt::getZero(floatTy.getWidth()));
1470  apf.convertFromAPInt(a, /*IsSigned=*/false,
1471  APFloat::rmNearestTiesToEven);
1472  return apf;
1473  });
1474 }
1475 
1476 //===----------------------------------------------------------------------===//
1477 // SIToFPOp
1478 //===----------------------------------------------------------------------===//
1479 
1480 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1481  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1482 }
1483 
1484 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1485  Type resEleType = getElementTypeOrSelf(getType());
1486  return constFoldCastOp<IntegerAttr, FloatAttr>(
1487  adaptor.getOperands(), getType(),
1488  [&resEleType](const APInt &a, bool &castStatus) {
1489  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1490  APFloat apf(floatTy.getFloatSemantics(),
1491  APInt::getZero(floatTy.getWidth()));
1492  apf.convertFromAPInt(a, /*IsSigned=*/true,
1493  APFloat::rmNearestTiesToEven);
1494  return apf;
1495  });
1496 }
1497 //===----------------------------------------------------------------------===//
1498 // FPToUIOp
1499 //===----------------------------------------------------------------------===//
1500 
1501 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1502  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1503 }
1504 
1505 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1506  Type resType = getElementTypeOrSelf(getType());
1507  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1508  return constFoldCastOp<FloatAttr, IntegerAttr>(
1509  adaptor.getOperands(), getType(),
1510  [&bitWidth](const APFloat &a, bool &castStatus) {
1511  bool ignored;
1512  APSInt api(bitWidth, /*isUnsigned=*/true);
1513  castStatus = APFloat::opInvalidOp !=
1514  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1515  return api;
1516  });
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // FPToSIOp
1521 //===----------------------------------------------------------------------===//
1522 
1523 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1524  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1525 }
1526 
1527 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1528  Type resType = getElementTypeOrSelf(getType());
1529  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1530  return constFoldCastOp<FloatAttr, IntegerAttr>(
1531  adaptor.getOperands(), getType(),
1532  [&bitWidth](const APFloat &a, bool &castStatus) {
1533  bool ignored;
1534  APSInt api(bitWidth, /*isUnsigned=*/false);
1535  castStatus = APFloat::opInvalidOp !=
1536  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1537  return api;
1538  });
1539 }
1540 
1541 //===----------------------------------------------------------------------===//
1542 // IndexCastOp
1543 //===----------------------------------------------------------------------===//
1544 
1545 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1546  if (!areValidCastInputsAndOutputs(inputs, outputs))
1547  return false;
1548 
1549  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1550  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1551  if (!srcType || !dstType)
1552  return false;
1553 
1554  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1555  (srcType.isSignlessInteger() && dstType.isIndex());
1556 }
1557 
1558 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1559  TypeRange outputs) {
1560  return areIndexCastCompatible(inputs, outputs);
1561 }
1562 
1563 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1564  // index_cast(constant) -> constant
1565  unsigned resultBitwidth = 64; // Default for index integer attributes.
1566  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1567  resultBitwidth = intTy.getWidth();
1568 
1569  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1570  adaptor.getOperands(), getType(),
1571  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1572  return a.sextOrTrunc(resultBitwidth);
1573  });
1574 }
1575 
1576 void arith::IndexCastOp::getCanonicalizationPatterns(
1577  RewritePatternSet &patterns, MLIRContext *context) {
1578  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1579 }
1580 
1581 //===----------------------------------------------------------------------===//
1582 // IndexCastUIOp
1583 //===----------------------------------------------------------------------===//
1584 
1585 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1586  TypeRange outputs) {
1587  return areIndexCastCompatible(inputs, outputs);
1588 }
1589 
1590 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1591  // index_castui(constant) -> constant
1592  unsigned resultBitwidth = 64; // Default for index integer attributes.
1593  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1594  resultBitwidth = intTy.getWidth();
1595 
1596  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1597  adaptor.getOperands(), getType(),
1598  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1599  return a.zextOrTrunc(resultBitwidth);
1600  });
1601 }
1602 
1603 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1604  RewritePatternSet &patterns, MLIRContext *context) {
1605  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1606 }
1607 
1608 //===----------------------------------------------------------------------===//
1609 // BitcastOp
1610 //===----------------------------------------------------------------------===//
1611 
1612 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1613  if (!areValidCastInputsAndOutputs(inputs, outputs))
1614  return false;
1615 
1616  auto srcType =
1617  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1618  auto dstType =
1619  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1620  if (!srcType || !dstType)
1621  return false;
1622 
1623  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1624 }
1625 
1626 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1627  auto resType = getType();
1628  auto operand = adaptor.getIn();
1629  if (!operand)
1630  return {};
1631 
1632  /// Bitcast dense elements.
1633  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1634  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1635  /// Other shaped types unhandled.
1636  if (llvm::isa<ShapedType>(resType))
1637  return {};
1638 
1639  /// Bitcast integer or float to integer or float.
1640  APInt bits = llvm::isa<FloatAttr>(operand)
1641  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1642  : llvm::cast<IntegerAttr>(operand).getValue();
1643 
1644  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1645  return FloatAttr::get(resType,
1646  APFloat(resFloatType.getFloatSemantics(), bits));
1647  return IntegerAttr::get(resType, bits);
1648 }
1649 
1650 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1651  MLIRContext *context) {
1652  patterns.add<BitcastOfBitcast>(context);
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // CmpIOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1660 /// comparison predicates.
1661 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1662  const APInt &lhs, const APInt &rhs) {
1663  switch (predicate) {
1664  case arith::CmpIPredicate::eq:
1665  return lhs.eq(rhs);
1666  case arith::CmpIPredicate::ne:
1667  return lhs.ne(rhs);
1668  case arith::CmpIPredicate::slt:
1669  return lhs.slt(rhs);
1670  case arith::CmpIPredicate::sle:
1671  return lhs.sle(rhs);
1672  case arith::CmpIPredicate::sgt:
1673  return lhs.sgt(rhs);
1674  case arith::CmpIPredicate::sge:
1675  return lhs.sge(rhs);
1676  case arith::CmpIPredicate::ult:
1677  return lhs.ult(rhs);
1678  case arith::CmpIPredicate::ule:
1679  return lhs.ule(rhs);
1680  case arith::CmpIPredicate::ugt:
1681  return lhs.ugt(rhs);
1682  case arith::CmpIPredicate::uge:
1683  return lhs.uge(rhs);
1684  }
1685  llvm_unreachable("unknown cmpi predicate kind");
1686 }
1687 
1688 /// Returns true if the predicate is true for two equal operands.
1689 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1690  switch (predicate) {
1691  case arith::CmpIPredicate::eq:
1692  case arith::CmpIPredicate::sle:
1693  case arith::CmpIPredicate::sge:
1694  case arith::CmpIPredicate::ule:
1695  case arith::CmpIPredicate::uge:
1696  return true;
1697  case arith::CmpIPredicate::ne:
1698  case arith::CmpIPredicate::slt:
1699  case arith::CmpIPredicate::sgt:
1700  case arith::CmpIPredicate::ult:
1701  case arith::CmpIPredicate::ugt:
1702  return false;
1703  }
1704  llvm_unreachable("unknown cmpi predicate kind");
1705 }
1706 
1707 static std::optional<int64_t> getIntegerWidth(Type t) {
1708  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1709  return intType.getWidth();
1710  }
1711  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1712  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1713  }
1714  return std::nullopt;
1715 }
1716 
/// Folds `cmpi` in the following cases:
///  - identical operands fold to the predicate's result on equal values;
///  - `cmpi ne, ext(%x : i1), 0` folds to `%x` (for both extsi and extui);
///  - a constant LHS with a non-constant RHS is canonicalized in place by
///    swapping the operands and mirroring the predicate;
///  - two constant operands are evaluated with applyCmpPredicate.
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), val);
  }

  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
      // extsi(%x : i1 -> iN) != 0 -> %x
      // (a sign-extended i1 is 0 or -1, so it is non-zero iff %x is 1)
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
      // extui(%x : i1 -> iN) != 0 -> %x
      // (a zero-extended i1 is 0 or 1, so it is non-zero iff %x is 1)
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
  }

  // Move constant to the right side.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    // Swapping the operands requires the *mirrored* predicate (slt <-> sgt,
    // sle <-> sge, ...); eq and ne are symmetric and map to themselves.
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        // Capture both operands before reassigning either of them.
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
        // Returning this op's own result reports an in-place update to the
        // folder.
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // Constants are moved to the right-hand side above, so if the LHS is a
  // constant here, the RHS is guaranteed to be a constant as well.
  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), getI1SameShape(lhs.getType()),
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1,
                       static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
        });
  }

  return {};
}
1780 
1781 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1782  MLIRContext *context) {
1783  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1784 }
1785 
1786 //===----------------------------------------------------------------------===//
1787 // CmpFOp
1788 //===----------------------------------------------------------------------===//
1789 
1790 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1791 /// comparison predicates.
1792 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1793  const APFloat &lhs, const APFloat &rhs) {
1794  auto cmpResult = lhs.compare(rhs);
1795  switch (predicate) {
1796  case arith::CmpFPredicate::AlwaysFalse:
1797  return false;
1798  case arith::CmpFPredicate::OEQ:
1799  return cmpResult == APFloat::cmpEqual;
1800  case arith::CmpFPredicate::OGT:
1801  return cmpResult == APFloat::cmpGreaterThan;
1802  case arith::CmpFPredicate::OGE:
1803  return cmpResult == APFloat::cmpGreaterThan ||
1804  cmpResult == APFloat::cmpEqual;
1805  case arith::CmpFPredicate::OLT:
1806  return cmpResult == APFloat::cmpLessThan;
1807  case arith::CmpFPredicate::OLE:
1808  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1809  case arith::CmpFPredicate::ONE:
1810  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1811  case arith::CmpFPredicate::ORD:
1812  return cmpResult != APFloat::cmpUnordered;
1813  case arith::CmpFPredicate::UEQ:
1814  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1815  case arith::CmpFPredicate::UGT:
1816  return cmpResult == APFloat::cmpUnordered ||
1817  cmpResult == APFloat::cmpGreaterThan;
1818  case arith::CmpFPredicate::UGE:
1819  return cmpResult == APFloat::cmpUnordered ||
1820  cmpResult == APFloat::cmpGreaterThan ||
1821  cmpResult == APFloat::cmpEqual;
1822  case arith::CmpFPredicate::ULT:
1823  return cmpResult == APFloat::cmpUnordered ||
1824  cmpResult == APFloat::cmpLessThan;
1825  case arith::CmpFPredicate::ULE:
1826  return cmpResult == APFloat::cmpUnordered ||
1827  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1828  case arith::CmpFPredicate::UNE:
1829  return cmpResult != APFloat::cmpEqual;
1830  case arith::CmpFPredicate::UNO:
1831  return cmpResult == APFloat::cmpUnordered;
1832  case arith::CmpFPredicate::AlwaysTrue:
1833  return true;
1834  }
1835  llvm_unreachable("unknown cmpf predicate kind");
1836 }
1837 
1838 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1839  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1840  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1841 
1842  // If one operand is NaN, making them both NaN does not change the result.
1843  if (lhs && lhs.getValue().isNaN())
1844  rhs = lhs;
1845  if (rhs && rhs.getValue().isNaN())
1846  lhs = rhs;
1847 
1848  if (!lhs || !rhs)
1849  return {};
1850 
1851  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1852  return BoolAttr::get(getContext(), val);
1853 }
1854 
1855 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1856 public:
1858 
1859  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1860  bool isUnsigned) {
1861  using namespace arith;
1862  switch (pred) {
1863  case CmpFPredicate::UEQ:
1864  case CmpFPredicate::OEQ:
1865  return CmpIPredicate::eq;
1866  case CmpFPredicate::UGT:
1867  case CmpFPredicate::OGT:
1868  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1869  case CmpFPredicate::UGE:
1870  case CmpFPredicate::OGE:
1871  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1872  case CmpFPredicate::ULT:
1873  case CmpFPredicate::OLT:
1874  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1875  case CmpFPredicate::ULE:
1876  case CmpFPredicate::OLE:
1877  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1878  case CmpFPredicate::UNE:
1879  case CmpFPredicate::ONE:
1880  return CmpIPredicate::ne;
1881  default:
1882  llvm_unreachable("Unexpected predicate!");
1883  }
1884  }
1885 
1887  PatternRewriter &rewriter) const override {
1888  FloatAttr flt;
1889  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1890  return failure();
1891 
1892  const APFloat &rhs = flt.getValue();
1893 
1894  // Don't attempt to fold a nan.
1895  if (rhs.isNaN())
1896  return failure();
1897 
1898  // Get the width of the mantissa. We don't want to hack on conversions that
1899  // might lose information from the integer, e.g. "i64 -> float"
1900  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1901  int mantissaWidth = floatTy.getFPMantissaWidth();
1902  if (mantissaWidth <= 0)
1903  return failure();
1904 
1905  bool isUnsigned;
1906  Value intVal;
1907 
1908  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1909  isUnsigned = false;
1910  intVal = si.getIn();
1911  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1912  isUnsigned = true;
1913  intVal = ui.getIn();
1914  } else {
1915  return failure();
1916  }
1917 
1918  // Check to see that the input is converted from an integer type that is
1919  // small enough that preserves all bits.
1920  auto intTy = llvm::cast<IntegerType>(intVal.getType());
1921  auto intWidth = intTy.getWidth();
1922 
1923  // Number of bits representing values, as opposed to the sign
1924  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1925 
1926  // Following test does NOT adjust intWidth downwards for signed inputs,
1927  // because the most negative value still requires all the mantissa bits
1928  // to distinguish it from one less than that value.
1929  if ((int)intWidth > mantissaWidth) {
1930  // Conversion would lose accuracy. Check if loss can impact comparison.
1931  int exponent = ilogb(rhs);
1932  if (exponent == APFloat::IEK_Inf) {
1933  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1934  if (maxExponent < (int)valueBits) {
1935  // Conversion could create infinity.
1936  return failure();
1937  }
1938  } else {
1939  // Note that if rhs is zero or NaN, then Exp is negative
1940  // and first condition is trivially false.
1941  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1942  // Conversion could affect comparison.
1943  return failure();
1944  }
1945  }
1946  }
1947 
1948  // Convert to equivalent cmpi predicate
1949  CmpIPredicate pred;
1950  switch (op.getPredicate()) {
1951  case CmpFPredicate::ORD:
1952  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1953  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1954  /*width=*/1);
1955  return success();
1956  case CmpFPredicate::UNO:
1957  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1958  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1959  /*width=*/1);
1960  return success();
1961  default:
1962  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1963  break;
1964  }
1965 
1966  if (!isUnsigned) {
1967  // If the rhs value is > SignedMax, fold the comparison. This handles
1968  // +INF and large values.
1969  APFloat signedMax(rhs.getSemantics());
1970  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1971  APFloat::rmNearestTiesToEven);
1972  if (signedMax < rhs) { // smax < 13123.0
1973  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1974  pred == CmpIPredicate::sle)
1975  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1976  /*width=*/1);
1977  else
1978  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1979  /*width=*/1);
1980  return success();
1981  }
1982  } else {
1983  // If the rhs value is > UnsignedMax, fold the comparison. This handles
1984  // +INF and large values.
1985  APFloat unsignedMax(rhs.getSemantics());
1986  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1987  APFloat::rmNearestTiesToEven);
1988  if (unsignedMax < rhs) { // umax < 13123.0
1989  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1990  pred == CmpIPredicate::ule)
1991  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1992  /*width=*/1);
1993  else
1994  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1995  /*width=*/1);
1996  return success();
1997  }
1998  }
1999 
2000  if (!isUnsigned) {
2001  // See if the rhs value is < SignedMin.
2002  APFloat signedMin(rhs.getSemantics());
2003  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2004  APFloat::rmNearestTiesToEven);
2005  if (signedMin > rhs) { // smin > 12312.0
2006  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2007  pred == CmpIPredicate::sge)
2008  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2009  /*width=*/1);
2010  else
2011  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2012  /*width=*/1);
2013  return success();
2014  }
2015  } else {
2016  // See if the rhs value is < UnsignedMin.
2017  APFloat unsignedMin(rhs.getSemantics());
2018  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2019  APFloat::rmNearestTiesToEven);
2020  if (unsignedMin > rhs) { // umin > 12312.0
2021  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2022  pred == CmpIPredicate::uge)
2023  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2024  /*width=*/1);
2025  else
2026  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2027  /*width=*/1);
2028  return success();
2029  }
2030  }
2031 
2032  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2033  // [0, UMAX], but it may still be fractional. See if it is fractional by
2034  // casting the FP value to the integer value and back, checking for
2035  // equality. Don't do this for zero, because -0.0 is not fractional.
2036  bool ignored;
2037  APSInt rhsInt(intWidth, isUnsigned);
2038  if (APFloat::opInvalidOp ==
2039  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2040  // Undefined behavior invoked - the destination type can't represent
2041  // the input constant.
2042  return failure();
2043  }
2044 
2045  if (!rhs.isZero()) {
2046  APFloat apf(floatTy.getFloatSemantics(),
2047  APInt::getZero(floatTy.getWidth()));
2048  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2049 
2050  bool equal = apf == rhs;
2051  if (!equal) {
2052  // If we had a comparison against a fractional value, we have to adjust
2053  // the compare predicate and sometimes the value. rhsInt is rounded
2054  // towards zero at this point.
2055  switch (pred) {
2056  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2057  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2058  /*width=*/1);
2059  return success();
2060  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2061  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2062  /*width=*/1);
2063  return success();
2064  case CmpIPredicate::ule:
2065  // (float)int <= 4.4 --> int <= 4
2066  // (float)int <= -4.4 --> false
2067  if (rhs.isNegative()) {
2068  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2069  /*width=*/1);
2070  return success();
2071  }
2072  break;
2073  case CmpIPredicate::sle:
2074  // (float)int <= 4.4 --> int <= 4
2075  // (float)int <= -4.4 --> int < -4
2076  if (rhs.isNegative())
2077  pred = CmpIPredicate::slt;
2078  break;
2079  case CmpIPredicate::ult:
2080  // (float)int < -4.4 --> false
2081  // (float)int < 4.4 --> int <= 4
2082  if (rhs.isNegative()) {
2083  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2084  /*width=*/1);
2085  return success();
2086  }
2087  pred = CmpIPredicate::ule;
2088  break;
2089  case CmpIPredicate::slt:
2090  // (float)int < -4.4 --> int < -4
2091  // (float)int < 4.4 --> int <= 4
2092  if (!rhs.isNegative())
2093  pred = CmpIPredicate::sle;
2094  break;
2095  case CmpIPredicate::ugt:
2096  // (float)int > 4.4 --> int > 4
2097  // (float)int > -4.4 --> true
2098  if (rhs.isNegative()) {
2099  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2100  /*width=*/1);
2101  return success();
2102  }
2103  break;
2104  case CmpIPredicate::sgt:
2105  // (float)int > 4.4 --> int > 4
2106  // (float)int > -4.4 --> int >= -4
2107  if (rhs.isNegative())
2108  pred = CmpIPredicate::sge;
2109  break;
2110  case CmpIPredicate::uge:
2111  // (float)int >= -4.4 --> true
2112  // (float)int >= 4.4 --> int > 4
2113  if (rhs.isNegative()) {
2114  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2115  /*width=*/1);
2116  return success();
2117  }
2118  pred = CmpIPredicate::ugt;
2119  break;
2120  case CmpIPredicate::sge:
2121  // (float)int >= -4.4 --> int >= -4
2122  // (float)int >= 4.4 --> int > 4
2123  if (!rhs.isNegative())
2124  pred = CmpIPredicate::sgt;
2125  break;
2126  }
2127  }
2128  }
2129 
2130  // Lower this FP comparison into an appropriate integer version of the
2131  // comparison.
2132  rewriter.replaceOpWithNewOp<CmpIOp>(
2133  op, pred, intVal,
2134  rewriter.create<ConstantOp>(
2135  op.getLoc(), intVal.getType(),
2136  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2137  return success();
2138  }
2139 };
2140 
2141 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2142  MLIRContext *context) {
2143  patterns.insert<CmpFIntToFPConst>(context);
2144 }
2145 
2146 //===----------------------------------------------------------------------===//
2147 // SelectOp
2148 //===----------------------------------------------------------------------===//
2149 
2150 // Transforms a select of a boolean to arithmetic operations
2151 //
2152 // arith.select %arg, %x, %y : i1
2153 //
2154 // becomes
2155 //
2156 // and(%arg, %x) or and(!%arg, %y)
2157 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
2159 
2160  LogicalResult matchAndRewrite(arith::SelectOp op,
2161  PatternRewriter &rewriter) const override {
2162  if (!op.getType().isInteger(1))
2163  return failure();
2164 
2165  Value falseConstant =
2166  rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
2167  Value notCondition = rewriter.create<arith::XOrIOp>(
2168  op.getLoc(), op.getCondition(), falseConstant);
2169 
2170  Value trueVal = rewriter.create<arith::AndIOp>(
2171  op.getLoc(), op.getCondition(), op.getTrueValue());
2172  Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
2173  op.getFalseValue());
2174  rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
2175  return success();
2176  }
2177 };
2178 
2179 // select %arg, %c1, %c0 => extui %arg
2180 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2182 
2183  LogicalResult matchAndRewrite(arith::SelectOp op,
2184  PatternRewriter &rewriter) const override {
2185  // Cannot extui i1 to i1, or i1 to f32
2186  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2187  return failure();
2188 
2189  // select %x, c1, %c0 => extui %arg
2190  if (matchPattern(op.getTrueValue(), m_One()) &&
2191  matchPattern(op.getFalseValue(), m_Zero())) {
2192  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2193  op.getCondition());
2194  return success();
2195  }
2196 
2197  // select %x, c0, %c1 => extui (xor %arg, true)
2198  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2199  matchPattern(op.getFalseValue(), m_One())) {
2200  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2201  op, op.getType(),
2202  rewriter.create<arith::XOrIOp>(
2203  op.getLoc(), op.getCondition(),
2204  rewriter.create<arith::ConstantIntOp>(
2205  op.getLoc(), 1, op.getCondition().getType())));
2206  return success();
2207  }
2208 
2209  return failure();
2210  }
2211 };
2212 
2213 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2214  MLIRContext *context) {
2215  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
2216  SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
2217  SelectNotCond, SelectToExtUI>(context);
2218 }
2219 
2220 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2221  Value trueVal = getTrueValue();
2222  Value falseVal = getFalseValue();
2223  if (trueVal == falseVal)
2224  return trueVal;
2225 
2226  Value condition = getCondition();
2227 
2228  // select true, %0, %1 => %0
2229  if (matchPattern(adaptor.getCondition(), m_One()))
2230  return trueVal;
2231 
2232  // select false, %0, %1 => %1
2233  if (matchPattern(adaptor.getCondition(), m_Zero()))
2234  return falseVal;
2235 
2236  // If either operand is fully poisoned, return the other.
2237  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2238  return falseVal;
2239 
2240  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2241  return trueVal;
2242 
2243  // select %x, true, false => %x
2244  if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
2245  matchPattern(adaptor.getFalseValue(), m_Zero()))
2246  return condition;
2247 
2248  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2249  auto pred = cmp.getPredicate();
2250  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2251  auto cmpLhs = cmp.getLhs();
2252  auto cmpRhs = cmp.getRhs();
2253 
2254  // %0 = arith.cmpi eq, %arg0, %arg1
2255  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2256 
2257  // %0 = arith.cmpi ne, %arg0, %arg1
2258  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2259 
2260  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2261  (cmpRhs == trueVal && cmpLhs == falseVal))
2262  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2263  }
2264  }
2265 
2266  // Constant-fold constant operands over non-splat constant condition.
2267  // select %cst_vec, %cst0, %cst1 => %cst2
2268  if (auto cond =
2269  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2270  if (auto lhs =
2271  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2272  if (auto rhs =
2273  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2274  SmallVector<Attribute> results;
2275  results.reserve(static_cast<size_t>(cond.getNumElements()));
2276  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2277  cond.value_end<BoolAttr>());
2278  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2279  lhs.value_end<Attribute>());
2280  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2281  rhs.value_end<Attribute>());
2282 
2283  for (auto [condVal, lhsVal, rhsVal] :
2284  llvm::zip_equal(condVals, lhsVals, rhsVals))
2285  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2286 
2287  return DenseElementsAttr::get(lhs.getType(), results);
2288  }
2289  }
2290  }
2291 
2292  return nullptr;
2293 }
2294 
2296  Type conditionType, resultType;
2298  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2299  parser.parseOptionalAttrDict(result.attributes) ||
2300  parser.parseColonType(resultType))
2301  return failure();
2302 
2303  // Check for the explicit condition type if this is a masked tensor or vector.
2304  if (succeeded(parser.parseOptionalComma())) {
2305  conditionType = resultType;
2306  if (parser.parseType(resultType))
2307  return failure();
2308  } else {
2309  conditionType = parser.getBuilder().getI1Type();
2310  }
2311 
2312  result.addTypes(resultType);
2313  return parser.resolveOperands(operands,
2314  {conditionType, resultType, resultType},
2315  parser.getNameLoc(), result.operands);
2316 }
2317 
2319  p << " " << getOperands();
2320  p.printOptionalAttrDict((*this)->getAttrs());
2321  p << " : ";
2322  if (ShapedType condType =
2323  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2324  p << condType << ", ";
2325  p << getType();
2326 }
2327 
2329  Type conditionType = getCondition().getType();
2330  if (conditionType.isSignlessInteger(1))
2331  return success();
2332 
2333  // If the result type is a vector or tensor, the type can be a mask with the
2334  // same elements.
2335  Type resultType = getType();
2336  if (!llvm::isa<TensorType, VectorType>(resultType))
2337  return emitOpError() << "expected condition to be a signless i1, but got "
2338  << conditionType;
2339  Type shapedConditionType = getI1SameShape(resultType);
2340  if (conditionType != shapedConditionType) {
2341  return emitOpError() << "expected condition type to have the same shape "
2342  "as the result type, expected "
2343  << shapedConditionType << ", but got "
2344  << conditionType;
2345  }
2346  return success();
2347 }
2348 //===----------------------------------------------------------------------===//
2349 // ShLIOp
2350 //===----------------------------------------------------------------------===//
2351 
2352 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2353  // shli(x, 0) -> x
2354  if (matchPattern(adaptor.getRhs(), m_Zero()))
2355  return getLhs();
2356  // Don't fold if shifting more than the bit width.
2357  bool bounded = false;
2358  auto result = constFoldBinaryOp<IntegerAttr>(
2359  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2360  bounded = b.ule(b.getBitWidth());
2361  return a.shl(b);
2362  });
2363  return bounded ? result : Attribute();
2364 }
2365 
2366 //===----------------------------------------------------------------------===//
2367 // ShRUIOp
2368 //===----------------------------------------------------------------------===//
2369 
2370 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2371  // shrui(x, 0) -> x
2372  if (matchPattern(adaptor.getRhs(), m_Zero()))
2373  return getLhs();
2374  // Don't fold if shifting more than the bit width.
2375  bool bounded = false;
2376  auto result = constFoldBinaryOp<IntegerAttr>(
2377  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2378  bounded = b.ule(b.getBitWidth());
2379  return a.lshr(b);
2380  });
2381  return bounded ? result : Attribute();
2382 }
2383 
2384 //===----------------------------------------------------------------------===//
2385 // ShRSIOp
2386 //===----------------------------------------------------------------------===//
2387 
2388 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2389  // shrsi(x, 0) -> x
2390  if (matchPattern(adaptor.getRhs(), m_Zero()))
2391  return getLhs();
2392  // Don't fold if shifting more than the bit width.
2393  bool bounded = false;
2394  auto result = constFoldBinaryOp<IntegerAttr>(
2395  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2396  bounded = b.ule(b.getBitWidth());
2397  return a.ashr(b);
2398  });
2399  return bounded ? result : Attribute();
2400 }
2401 
2402 //===----------------------------------------------------------------------===//
2403 // Atomic Enum
2404 //===----------------------------------------------------------------------===//
2405 
2406 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2407 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2408  OpBuilder &builder, Location loc,
2409  bool useOnlyFiniteValue) {
2410  switch (kind) {
2411  case AtomicRMWKind::maximumf: {
2412  const llvm::fltSemantics &semantic =
2413  llvm::cast<FloatType>(resultType).getFloatSemantics();
2414  APFloat identity = useOnlyFiniteValue
2415  ? APFloat::getLargest(semantic, /*Negative=*/true)
2416  : APFloat::getInf(semantic, /*Negative=*/true);
2417  return builder.getFloatAttr(resultType, identity);
2418  }
2419  case AtomicRMWKind::addf:
2420  case AtomicRMWKind::addi:
2421  case AtomicRMWKind::maxu:
2422  case AtomicRMWKind::ori:
2423  return builder.getZeroAttr(resultType);
2424  case AtomicRMWKind::andi:
2425  return builder.getIntegerAttr(
2426  resultType,
2427  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2428  case AtomicRMWKind::maxs:
2429  return builder.getIntegerAttr(
2430  resultType, APInt::getSignedMinValue(
2431  llvm::cast<IntegerType>(resultType).getWidth()));
2432  case AtomicRMWKind::minimumf: {
2433  const llvm::fltSemantics &semantic =
2434  llvm::cast<FloatType>(resultType).getFloatSemantics();
2435  APFloat identity = useOnlyFiniteValue
2436  ? APFloat::getLargest(semantic, /*Negative=*/false)
2437  : APFloat::getInf(semantic, /*Negative=*/false);
2438 
2439  return builder.getFloatAttr(resultType, identity);
2440  }
2441  case AtomicRMWKind::mins:
2442  return builder.getIntegerAttr(
2443  resultType, APInt::getSignedMaxValue(
2444  llvm::cast<IntegerType>(resultType).getWidth()));
2445  case AtomicRMWKind::minu:
2446  return builder.getIntegerAttr(
2447  resultType,
2448  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2449  case AtomicRMWKind::muli:
2450  return builder.getIntegerAttr(resultType, 1);
2451  case AtomicRMWKind::mulf:
2452  return builder.getFloatAttr(resultType, 1);
2453  // TODO: Add remaining reduction operations.
2454  default:
2455  (void)emitOptionalError(loc, "Reduction operation type not supported");
2456  break;
2457  }
2458  return nullptr;
2459 }
2460 
2461 /// Return the identity numeric value associated with the given op.
2462 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2463  std::optional<AtomicRMWKind> maybeKind =
2465  // Floating-point operations.
2466  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2467  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2468  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2469  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2470  // Integer operations.
2471  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2472  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2473  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2474  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2475  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2476  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2477  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2478  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2479  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2480  .Default([](Operation *op) { return std::nullopt; });
2481  if (!maybeKind) {
2482  op->emitError() << "Unknown neutral element for: " << *op;
2483  return std::nullopt;
2484  }
2485 
2486  bool useOnlyFiniteValue = false;
2487  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2488  if (fmfOpInterface) {
2489  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2490  useOnlyFiniteValue =
2491  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2492  }
2493 
2494  // Builder only used as helper for attribute creation.
2495  OpBuilder b(op->getContext());
2496  Type resultType = op->getResult(0).getType();
2497 
2498  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2499  useOnlyFiniteValue);
2500 }
2501 
2502 /// Returns the identity value associated with an AtomicRMWKind op.
2503 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2504  OpBuilder &builder, Location loc,
2505  bool useOnlyFiniteValue) {
2506  auto attr =
2507  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2508  return builder.create<arith::ConstantOp>(loc, attr);
2509 }
2510 
2511 /// Return the value obtained by applying the reduction operation kind
2512 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2513 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2514  Location loc, Value lhs, Value rhs) {
2515  switch (op) {
2516  case AtomicRMWKind::addf:
2517  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2518  case AtomicRMWKind::addi:
2519  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2520  case AtomicRMWKind::mulf:
2521  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2522  case AtomicRMWKind::muli:
2523  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2524  case AtomicRMWKind::maximumf:
2525  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2526  case AtomicRMWKind::minimumf:
2527  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2528  case AtomicRMWKind::maxnumf:
2529  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2530  case AtomicRMWKind::minnumf:
2531  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2532  case AtomicRMWKind::maxs:
2533  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2534  case AtomicRMWKind::mins:
2535  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2536  case AtomicRMWKind::maxu:
2537  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2538  case AtomicRMWKind::minu:
2539  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2540  case AtomicRMWKind::ori:
2541  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2542  case AtomicRMWKind::andi:
2543  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2544  // TODO: Add remaining reduction operations.
2545  default:
2546  (void)emitOptionalError(loc, "Reduction operation type not supported");
2547  break;
2548  }
2549  return nullptr;
2550 }
2551 
2552 //===----------------------------------------------------------------------===//
2553 // TableGen'd op method definitions
2554 //===----------------------------------------------------------------------===//
2555 
2556 #define GET_OP_CLASSES
2557 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2558 
2559 //===----------------------------------------------------------------------===//
2560 // TableGen'd enum attribute definitions
2561 //===----------------------------------------------------------------------===//
2562 
2563 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1246
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:59
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1182
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1689
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1196
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1168
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:578
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:108
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:49
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:116
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1545
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1444
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1218
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:54
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:96
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:779
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1189
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:137
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1707
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1204
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1162
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:40
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:303
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1232
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1886
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1859
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:65
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:236
static bool classof(Operation *op)
Definition: ArithOps.cpp:242
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:248
static bool classof(Operation *op)
Definition: ArithOps.cpp:254
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:53
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:215
static bool classof(Operation *op)
Definition: ArithOps.cpp:230
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
Definition: ArithOps.cpp:2462
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1661
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2407
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2513
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2503
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:65
auto m_Val(Value v)
Definition: Matchers.h:450
MPInt ceil(const Fraction &f)
Definition: Fraction.h:76
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:345
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:384
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:371
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:350
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:363
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:355
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2160
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2183
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes