MLIR  17.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <utility>
12 
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 
21 #include "llvm/ADT/APSInt.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::arith;
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern helpers
30 //===----------------------------------------------------------------------===//
31 
32 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
33  Attribute lhs, Attribute rhs) {
34  return builder.getIntegerAttr(res.getType(),
35  lhs.cast<IntegerAttr>().getInt() +
36  rhs.cast<IntegerAttr>().getInt());
37 }
38 
39 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
40  Attribute lhs, Attribute rhs) {
41  return builder.getIntegerAttr(res.getType(),
42  lhs.cast<IntegerAttr>().getInt() -
43  rhs.cast<IntegerAttr>().getInt());
44 }
45 
46 /// Invert an integer comparison predicate.
47 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
48  switch (pred) {
49  case arith::CmpIPredicate::eq:
50  return arith::CmpIPredicate::ne;
51  case arith::CmpIPredicate::ne:
52  return arith::CmpIPredicate::eq;
53  case arith::CmpIPredicate::slt:
54  return arith::CmpIPredicate::sge;
55  case arith::CmpIPredicate::sle:
56  return arith::CmpIPredicate::sgt;
57  case arith::CmpIPredicate::sgt:
58  return arith::CmpIPredicate::sle;
59  case arith::CmpIPredicate::sge:
60  return arith::CmpIPredicate::slt;
61  case arith::CmpIPredicate::ult:
62  return arith::CmpIPredicate::uge;
63  case arith::CmpIPredicate::ule:
64  return arith::CmpIPredicate::ugt;
65  case arith::CmpIPredicate::ugt:
66  return arith::CmpIPredicate::ule;
67  case arith::CmpIPredicate::uge:
68  return arith::CmpIPredicate::ult;
69  }
70  llvm_unreachable("unknown cmpi predicate kind");
71 }
72 
73 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
74  return arith::CmpIPredicateAttr::get(pred.getContext(),
75  invertPredicate(pred.getValue()));
76 }
77 
78 static int64_t getScalarOrElementWidth(Type type) {
79  Type elemTy = getElementTypeOrSelf(type);
80  if (elemTy.isIntOrFloat())
81  return elemTy.getIntOrFloatBitWidth();
82 
83  return -1;
84 }
85 
86 static int64_t getScalarOrElementWidth(Value value) {
87  return getScalarOrElementWidth(value.getType());
88 }
89 
91  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
92  return intAttr.getValue();
93 
94  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>())
95  if (splatAttr.getElementType().isa<IntegerType>())
96  return splatAttr.getSplatValue<APInt>();
97 
98  return failure();
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // TableGen'd canonicalization patterns
103 //===----------------------------------------------------------------------===//
104 
105 namespace {
106 #include "ArithCanonicalization.inc"
107 } // namespace
108 
109 //===----------------------------------------------------------------------===//
110 // ConstantOp
111 //===----------------------------------------------------------------------===//
112 
113 void arith::ConstantOp::getAsmResultNames(
114  function_ref<void(Value, StringRef)> setNameFn) {
115  auto type = getType();
116  if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
117  auto intType = type.dyn_cast<IntegerType>();
118 
119  // Sugar i1 constants with 'true' and 'false'.
120  if (intType && intType.getWidth() == 1)
121  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
122 
123  // Otherwise, build a complex name with the value and type.
124  SmallString<32> specialNameBuffer;
125  llvm::raw_svector_ostream specialName(specialNameBuffer);
126  specialName << 'c' << intCst.getValue();
127  if (intType)
128  specialName << '_' << type;
129  setNameFn(getResult(), specialName.str());
130  } else {
131  setNameFn(getResult(), "cst");
132  }
133 }
134 
135 /// TODO: disallow arith.constant to return anything other than signless integer
136 /// or float like.
138  auto type = getType();
139  // The value's type must match the return type.
140  if (getValue().getType() != type) {
141  return emitOpError() << "value type " << getValue().getType()
142  << " must match return type: " << type;
143  }
144  // Integer values must be signless.
145  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
146  return emitOpError("integer return type must be signless");
147  // Any float or elements attribute are acceptable.
148  if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
149  return emitOpError(
150  "value must be an integer, float, or elements attribute");
151  }
152  return success();
153 }
154 
155 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
156  // The value's type must be the same as the provided type.
157  auto typedAttr = value.dyn_cast<TypedAttr>();
158  if (!typedAttr || typedAttr.getType() != type)
159  return false;
160  // Integer values must be signless.
161  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
162  return false;
163  // Integer, float, and element attributes are buildable.
164  return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
165 }
166 
167 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
168 
170  int64_t value, unsigned width) {
171  auto type = builder.getIntegerType(width);
172  arith::ConstantOp::build(builder, result, type,
173  builder.getIntegerAttr(type, value));
174 }
175 
177  int64_t value, Type type) {
178  assert(type.isSignlessInteger() &&
179  "ConstantIntOp can only have signless integer type values");
180  arith::ConstantOp::build(builder, result, type,
181  builder.getIntegerAttr(type, value));
182 }
183 
185  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
186  return constOp.getType().isSignlessInteger();
187  return false;
188 }
189 
191  const APFloat &value, FloatType type) {
192  arith::ConstantOp::build(builder, result, type,
193  builder.getFloatAttr(type, value));
194 }
195 
197  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
198  return constOp.getType().isa<FloatType>();
199  return false;
200 }
201 
203  int64_t value) {
204  arith::ConstantOp::build(builder, result, builder.getIndexType(),
205  builder.getIndexAttr(value));
206 }
207 
209  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
210  return constOp.getType().isIndex();
211  return false;
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // AddIOp
216 //===----------------------------------------------------------------------===//
217 
218 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
219  // addi(x, 0) -> x
220  if (matchPattern(getRhs(), m_Zero()))
221  return getLhs();
222 
223  // addi(subi(a, b), b) -> a
224  if (auto sub = getLhs().getDefiningOp<SubIOp>())
225  if (getRhs() == sub.getRhs())
226  return sub.getLhs();
227 
228  // addi(b, subi(a, b)) -> a
229  if (auto sub = getRhs().getDefiningOp<SubIOp>())
230  if (getLhs() == sub.getRhs())
231  return sub.getLhs();
232 
233  return constFoldBinaryOp<IntegerAttr>(
234  adaptor.getOperands(),
235  [](APInt a, const APInt &b) { return std::move(a) + b; });
236 }
237 
238 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
239  MLIRContext *context) {
240  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
241  context);
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // AddUIExtendedOp
246 //===----------------------------------------------------------------------===//
247 
248 std::optional<SmallVector<int64_t, 4>>
249 arith::AddUIExtendedOp::getShapeForUnroll() {
250  if (auto vt = getType(0).dyn_cast<VectorType>())
251  return llvm::to_vector<4>(vt.getShape());
252  return std::nullopt;
253 }
254 
255 // Returns the overflow bit, assuming that `sum` is the result of unsigned
256 // addition of `operand` and another number.
257 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
258  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
259 }
260 
262 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
264  Type overflowTy = getOverflow().getType();
265  // addui_extended(x, 0) -> x, false
266  if (matchPattern(getRhs(), m_Zero())) {
267  Builder builder(getContext());
268  auto falseValue = builder.getZeroAttr(overflowTy);
269 
270  results.push_back(getLhs());
271  results.push_back(falseValue);
272  return success();
273  }
274 
275  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
276  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
277  // operands. If that succeeds, calculate the overflow bit based on the sum
278  // and the first (constant) operand, `lhs`. Note that we cannot simply call
279  // `constFoldBinaryOp` again to calculate the overflow bit because the
280  // constructed attribute is of the same element type as both operands.
281  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
282  adaptor.getOperands(),
283  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
284  Attribute overflowAttr;
285  if (auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
286  // Both arguments are scalars, calculate the scalar overflow value.
287  auto sum = sumAttr.cast<IntegerAttr>();
288  overflowAttr = IntegerAttr::get(
289  overflowTy,
290  calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
291  } else if (auto lhs = adaptor.getLhs().dyn_cast<SplatElementsAttr>()) {
292  // Both arguments are splats, calculate the splat overflow value.
293  auto sum = sumAttr.cast<SplatElementsAttr>();
294  APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
295  lhs.getSplatValue<APInt>());
296  overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
297  } else if (auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
298  // Othwerwise calculate element-wise overflow values.
299  auto sum = sumAttr.cast<ElementsAttr>();
300  const auto numElems = static_cast<size_t>(sum.getNumElements());
301  SmallVector<APInt> overflowValues;
302  overflowValues.reserve(numElems);
303 
304  auto sumIt = sum.value_begin<APInt>();
305  auto lhsIt = lhs.value_begin<APInt>();
306  for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
307  overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt));
308 
309  overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues);
310  } else {
311  return failure();
312  }
313 
314  results.push_back(sumAttr);
315  results.push_back(overflowAttr);
316  return success();
317  }
318 
319  return failure();
320 }
321 
322 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
323  RewritePatternSet &patterns, MLIRContext *context) {
324  patterns.add<AddUIExtendedToAddI>(context);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // SubIOp
329 //===----------------------------------------------------------------------===//
330 
331 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
332  // subi(x,x) -> 0
333  if (getOperand(0) == getOperand(1))
334  return Builder(getContext()).getZeroAttr(getType());
335  // subi(x,0) -> x
336  if (matchPattern(getRhs(), m_Zero()))
337  return getLhs();
338 
339  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
340  // subi(addi(a, b), b) -> a
341  if (getRhs() == add.getRhs())
342  return add.getLhs();
343  // subi(addi(a, b), a) -> b
344  if (getRhs() == add.getLhs())
345  return add.getRhs();
346  }
347 
348  return constFoldBinaryOp<IntegerAttr>(
349  adaptor.getOperands(),
350  [](APInt a, const APInt &b) { return std::move(a) - b; });
351 }
352 
353 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
354  MLIRContext *context) {
355  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
356  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
357  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // MulIOp
362 //===----------------------------------------------------------------------===//
363 
364 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
365  // muli(x, 0) -> 0
366  if (matchPattern(getRhs(), m_Zero()))
367  return getRhs();
368  // muli(x, 1) -> x
369  if (matchPattern(getRhs(), m_One()))
370  return getOperand(0);
371  // TODO: Handle the overflow case.
372 
373  // default folder
374  return constFoldBinaryOp<IntegerAttr>(
375  adaptor.getOperands(),
376  [](const APInt &a, const APInt &b) { return a * b; });
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // MulSIExtendedOp
381 //===----------------------------------------------------------------------===//
382 
383 std::optional<SmallVector<int64_t, 4>>
384 arith::MulSIExtendedOp::getShapeForUnroll() {
385  if (auto vt = getType(0).dyn_cast<VectorType>())
386  return llvm::to_vector<4>(vt.getShape());
387  return std::nullopt;
388 }
389 
391 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
393  // mulsi_extended(x, 0) -> 0, 0
394  if (matchPattern(getRhs(), m_Zero())) {
395  Attribute zero = adaptor.getRhs();
396  results.push_back(zero);
397  results.push_back(zero);
398  return success();
399  }
400 
401  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
402  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
403  adaptor.getOperands(),
404  [](const APInt &a, const APInt &b) { return a * b; })) {
405  // Invoke the constant fold helper again to calculate the 'high' result.
406  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
407  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
408  unsigned bitWidth = a.getBitWidth();
409  APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
410  return fullProduct.extractBits(bitWidth, bitWidth);
411  });
412  assert(highAttr && "Unexpected constant-folding failure");
413 
414  results.push_back(lowAttr);
415  results.push_back(highAttr);
416  return success();
417  }
418 
419  return failure();
420 }
421 
422 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
423  RewritePatternSet &patterns, MLIRContext *context) {
424  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // MulUIExtendedOp
429 //===----------------------------------------------------------------------===//
430 
431 std::optional<SmallVector<int64_t, 4>>
432 arith::MulUIExtendedOp::getShapeForUnroll() {
433  if (auto vt = getType(0).dyn_cast<VectorType>())
434  return llvm::to_vector<4>(vt.getShape());
435  return std::nullopt;
436 }
437 
439 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
441  // mului_extended(x, 0) -> 0, 0
442  if (matchPattern(getRhs(), m_Zero())) {
443  Attribute zero = adaptor.getRhs();
444  results.push_back(zero);
445  results.push_back(zero);
446  return success();
447  }
448 
449  // mului_extended(x, 1) -> x, 0
450  if (matchPattern(getRhs(), m_One())) {
451  Builder builder(getContext());
452  Attribute zero = builder.getZeroAttr(getLhs().getType());
453  results.push_back(getLhs());
454  results.push_back(zero);
455  return success();
456  }
457 
458  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
459  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
460  adaptor.getOperands(),
461  [](const APInt &a, const APInt &b) { return a * b; })) {
462  // Invoke the constant fold helper again to calculate the 'high' result.
463  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
464  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
465  unsigned bitWidth = a.getBitWidth();
466  APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
467  return fullProduct.extractBits(bitWidth, bitWidth);
468  });
469  assert(highAttr && "Unexpected constant-folding failure");
470 
471  results.push_back(lowAttr);
472  results.push_back(highAttr);
473  return success();
474  }
475 
476  return failure();
477 }
478 
479 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
480  RewritePatternSet &patterns, MLIRContext *context) {
481  patterns.add<MulUIExtendedToMulI>(context);
482 }
483 
484 //===----------------------------------------------------------------------===//
485 // DivUIOp
486 //===----------------------------------------------------------------------===//
487 
488 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
489  // divui (x, 1) -> x.
490  if (matchPattern(getRhs(), m_One()))
491  return getLhs();
492 
493  // Don't fold if it would require a division by zero.
494  bool div0 = false;
495  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
496  [&](APInt a, const APInt &b) {
497  if (div0 || !b) {
498  div0 = true;
499  return a;
500  }
501  return a.udiv(b);
502  });
503 
504  return div0 ? Attribute() : result;
505 }
506 
507 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
508  // X / 0 => UB
509  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // DivSIOp
515 //===----------------------------------------------------------------------===//
516 
517 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
518  // divsi (x, 1) -> x.
519  if (matchPattern(getRhs(), m_One()))
520  return getLhs();
521 
522  // Don't fold if it would overflow or if it requires a division by zero.
523  bool overflowOrDiv0 = false;
524  auto result = constFoldBinaryOp<IntegerAttr>(
525  adaptor.getOperands(), [&](APInt a, const APInt &b) {
526  if (overflowOrDiv0 || !b) {
527  overflowOrDiv0 = true;
528  return a;
529  }
530  return a.sdiv_ov(b, overflowOrDiv0);
531  });
532 
533  return overflowOrDiv0 ? Attribute() : result;
534 }
535 
536 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
537  bool mayHaveUB = true;
538 
539  APInt constRHS;
540  // X / 0 => UB
541  // INT_MIN / -1 => UB
542  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
543  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
544 
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // Ceil and floor division folding helpers
550 //===----------------------------------------------------------------------===//
551 
552 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
553  bool &overflow) {
554  // Returns (a-1)/b + 1
555  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
556  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
557  return val.sadd_ov(one, overflow);
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // CeilDivUIOp
562 //===----------------------------------------------------------------------===//
563 
564 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
565  // ceildivui (x, 1) -> x.
566  if (matchPattern(getRhs(), m_One()))
567  return getLhs();
568 
569  bool overflowOrDiv0 = false;
570  auto result = constFoldBinaryOp<IntegerAttr>(
571  adaptor.getOperands(), [&](APInt a, const APInt &b) {
572  if (overflowOrDiv0 || !b) {
573  overflowOrDiv0 = true;
574  return a;
575  }
576  APInt quotient = a.udiv(b);
577  if (!a.urem(b))
578  return quotient;
579  APInt one(a.getBitWidth(), 1, true);
580  return quotient.uadd_ov(one, overflowOrDiv0);
581  });
582 
583  return overflowOrDiv0 ? Attribute() : result;
584 }
585 
586 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
587  // X / 0 => UB
588  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // CeilDivSIOp
594 //===----------------------------------------------------------------------===//
595 
596 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
597  // ceildivsi (x, 1) -> x.
598  if (matchPattern(getRhs(), m_One()))
599  return getLhs();
600 
601  // Don't fold if it would overflow or if it requires a division by zero.
602  bool overflowOrDiv0 = false;
603  auto result = constFoldBinaryOp<IntegerAttr>(
604  adaptor.getOperands(), [&](APInt a, const APInt &b) {
605  if (overflowOrDiv0 || !b) {
606  overflowOrDiv0 = true;
607  return a;
608  }
609  if (!a)
610  return a;
611  // After this point we know that neither a or b are zero.
612  unsigned bits = a.getBitWidth();
613  APInt zero = APInt::getZero(bits);
614  bool aGtZero = a.sgt(zero);
615  bool bGtZero = b.sgt(zero);
616  if (aGtZero && bGtZero) {
617  // Both positive, return ceil(a, b).
618  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
619  }
620  if (!aGtZero && !bGtZero) {
621  // Both negative, return ceil(-a, -b).
622  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
623  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
624  return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
625  }
626  if (!aGtZero && bGtZero) {
627  // A is negative, b is positive, return - ( -a / b).
628  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
629  APInt div = posA.sdiv_ov(b, overflowOrDiv0);
630  return zero.ssub_ov(div, overflowOrDiv0);
631  }
632  // A is positive, b is negative, return - (a / -b).
633  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
634  APInt div = a.sdiv_ov(posB, overflowOrDiv0);
635  return zero.ssub_ov(div, overflowOrDiv0);
636  });
637 
638  return overflowOrDiv0 ? Attribute() : result;
639 }
640 
641 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
642  bool mayHaveUB = true;
643 
644  APInt constRHS;
645  // X / 0 => UB
646  // INT_MIN / -1 => UB
647  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
648  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
649 
651 }
652 
653 //===----------------------------------------------------------------------===//
654 // FloorDivSIOp
655 //===----------------------------------------------------------------------===//
656 
657 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
658  // floordivsi (x, 1) -> x.
659  if (matchPattern(getRhs(), m_One()))
660  return getLhs();
661 
662  // Don't fold if it would overflow or if it requires a division by zero.
663  bool overflowOrDiv0 = false;
664  auto result = constFoldBinaryOp<IntegerAttr>(
665  adaptor.getOperands(), [&](APInt a, const APInt &b) {
666  if (overflowOrDiv0 || !b) {
667  overflowOrDiv0 = true;
668  return a;
669  }
670  if (!a)
671  return a;
672  // After this point we know that neither a or b are zero.
673  unsigned bits = a.getBitWidth();
674  APInt zero = APInt::getZero(bits);
675  bool aGtZero = a.sgt(zero);
676  bool bGtZero = b.sgt(zero);
677  if (aGtZero && bGtZero) {
678  // Both positive, return a / b.
679  return a.sdiv_ov(b, overflowOrDiv0);
680  }
681  if (!aGtZero && !bGtZero) {
682  // Both negative, return -a / -b.
683  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
684  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
685  return posA.sdiv_ov(posB, overflowOrDiv0);
686  }
687  if (!aGtZero && bGtZero) {
688  // A is negative, b is positive, return - ceil(-a, b).
689  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
690  APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
691  return zero.ssub_ov(ceil, overflowOrDiv0);
692  }
693  // A is positive, b is negative, return - ceil(a, -b).
694  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
695  APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
696  return zero.ssub_ov(ceil, overflowOrDiv0);
697  });
698 
699  return overflowOrDiv0 ? Attribute() : result;
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // RemUIOp
704 //===----------------------------------------------------------------------===//
705 
706 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
707  // remui (x, 1) -> 0.
708  if (matchPattern(getRhs(), m_One()))
709  return Builder(getContext()).getZeroAttr(getType());
710 
711  // Don't fold if it would require a division by zero.
712  bool div0 = false;
713  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
714  [&](APInt a, const APInt &b) {
715  if (div0 || b.isNullValue()) {
716  div0 = true;
717  return a;
718  }
719  return a.urem(b);
720  });
721 
722  return div0 ? Attribute() : result;
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // RemSIOp
727 //===----------------------------------------------------------------------===//
728 
729 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
730  // remsi (x, 1) -> 0.
731  if (matchPattern(getRhs(), m_One()))
732  return Builder(getContext()).getZeroAttr(getType());
733 
734  // Don't fold if it would require a division by zero.
735  bool div0 = false;
736  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
737  [&](APInt a, const APInt &b) {
738  if (div0 || b.isNullValue()) {
739  div0 = true;
740  return a;
741  }
742  return a.srem(b);
743  });
744 
745  return div0 ? Attribute() : result;
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // AndIOp
750 //===----------------------------------------------------------------------===//
751 
752 /// Fold `and(a, and(a, b))` to `and(a, b)`
753 static Value foldAndIofAndI(arith::AndIOp op) {
754  for (bool reversePrev : {false, true}) {
755  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
756  .getDefiningOp<arith::AndIOp>();
757  if (!prev)
758  continue;
759 
760  Value other = (reversePrev ? op.getLhs() : op.getRhs());
761  if (other != prev.getLhs() && other != prev.getRhs())
762  continue;
763 
764  return prev.getResult();
765  }
766  return {};
767 }
768 
769 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
770  /// and(x, 0) -> 0
771  if (matchPattern(getRhs(), m_Zero()))
772  return getRhs();
773  /// and(x, allOnes) -> x
774  APInt intValue;
775  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
776  return getLhs();
777  /// and(x, not(x)) -> 0
778  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
779  m_ConstantInt(&intValue))) &&
780  intValue.isAllOnes())
781  return Builder(getContext()).getZeroAttr(getType());
782  /// and(not(x), x) -> 0
783  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
784  m_ConstantInt(&intValue))) &&
785  intValue.isAllOnes())
786  return Builder(getContext()).getZeroAttr(getType());
787 
788  /// and(a, and(a, b)) -> and(a, b)
789  if (Value result = foldAndIofAndI(*this))
790  return result;
791 
792  return constFoldBinaryOp<IntegerAttr>(
793  adaptor.getOperands(),
794  [](APInt a, const APInt &b) { return std::move(a) & b; });
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // OrIOp
799 //===----------------------------------------------------------------------===//
800 
801 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
802  /// or(x, 0) -> x
803  if (matchPattern(getRhs(), m_Zero()))
804  return getLhs();
805  /// or(x, <all ones>) -> <all ones>
806  if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
807  if (rhsAttr.getValue().isAllOnes())
808  return rhsAttr;
809 
810  return constFoldBinaryOp<IntegerAttr>(
811  adaptor.getOperands(),
812  [](APInt a, const APInt &b) { return std::move(a) | b; });
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // XOrIOp
817 //===----------------------------------------------------------------------===//
818 
819 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
820  /// xor(x, 0) -> x
821  if (matchPattern(getRhs(), m_Zero()))
822  return getLhs();
823  /// xor(x, x) -> 0
824  if (getLhs() == getRhs())
825  return Builder(getContext()).getZeroAttr(getType());
826  /// xor(xor(x, a), a) -> x
827  /// xor(xor(a, x), a) -> x
828  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
829  if (prev.getRhs() == getRhs())
830  return prev.getLhs();
831  if (prev.getLhs() == getRhs())
832  return prev.getRhs();
833  }
834  /// xor(a, xor(x, a)) -> x
835  /// xor(a, xor(a, x)) -> x
836  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
837  if (prev.getRhs() == getLhs())
838  return prev.getLhs();
839  if (prev.getLhs() == getLhs())
840  return prev.getRhs();
841  }
842 
843  return constFoldBinaryOp<IntegerAttr>(
844  adaptor.getOperands(),
845  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
846 }
847 
848 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
849  MLIRContext *context) {
850  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // NegFOp
855 //===----------------------------------------------------------------------===//
856 
857 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
858  /// negf(negf(x)) -> x
859  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
860  return op.getOperand();
861  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
862  [](const APFloat &a) { return -a; });
863 }
864 
865 //===----------------------------------------------------------------------===//
866 // AddFOp
867 //===----------------------------------------------------------------------===//
868 
869 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
870  // addf(x, -0) -> x
871  if (matchPattern(getRhs(), m_NegZeroFloat()))
872  return getLhs();
873 
874  return constFoldBinaryOp<FloatAttr>(
875  adaptor.getOperands(),
876  [](const APFloat &a, const APFloat &b) { return a + b; });
877 }
878 
879 //===----------------------------------------------------------------------===//
880 // SubFOp
881 //===----------------------------------------------------------------------===//
882 
883 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
884  // subf(x, +0) -> x
885  if (matchPattern(getRhs(), m_PosZeroFloat()))
886  return getLhs();
887 
888  return constFoldBinaryOp<FloatAttr>(
889  adaptor.getOperands(),
890  [](const APFloat &a, const APFloat &b) { return a - b; });
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // MaxFOp
895 //===----------------------------------------------------------------------===//
896 
897 OpFoldResult arith::MaxFOp::fold(FoldAdaptor adaptor) {
898  // maxf(x,x) -> x
899  if (getLhs() == getRhs())
900  return getRhs();
901 
902  // maxf(x, -inf) -> x
903  if (matchPattern(getRhs(), m_NegInfFloat()))
904  return getLhs();
905 
906  return constFoldBinaryOp<FloatAttr>(
907  adaptor.getOperands(),
908  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
909 }
910 
911 //===----------------------------------------------------------------------===//
912 // MaxSIOp
913 //===----------------------------------------------------------------------===//
914 
915 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
916  // maxsi(x,x) -> x
917  if (getLhs() == getRhs())
918  return getRhs();
919 
920  APInt intValue;
921  // maxsi(x,MAX_INT) -> MAX_INT
922  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
923  intValue.isMaxSignedValue())
924  return getRhs();
925 
926  // maxsi(x, MIN_INT) -> x
927  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
928  intValue.isMinSignedValue())
929  return getLhs();
930 
931  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
932  [](const APInt &a, const APInt &b) {
933  return llvm::APIntOps::smax(a, b);
934  });
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // MaxUIOp
939 //===----------------------------------------------------------------------===//
940 
941 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
942  // maxui(x,x) -> x
943  if (getLhs() == getRhs())
944  return getRhs();
945 
946  APInt intValue;
947  // maxui(x,MAX_INT) -> MAX_INT
948  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
949  return getRhs();
950 
951  // maxui(x, MIN_INT) -> x
952  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
953  return getLhs();
954 
955  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
956  [](const APInt &a, const APInt &b) {
957  return llvm::APIntOps::umax(a, b);
958  });
959 }
960 
961 //===----------------------------------------------------------------------===//
962 // MinFOp
963 //===----------------------------------------------------------------------===//
964 
965 OpFoldResult arith::MinFOp::fold(FoldAdaptor adaptor) {
966  // minf(x,x) -> x
967  if (getLhs() == getRhs())
968  return getRhs();
969 
970  // minf(x, +inf) -> x
971  if (matchPattern(getRhs(), m_PosInfFloat()))
972  return getLhs();
973 
974  return constFoldBinaryOp<FloatAttr>(
975  adaptor.getOperands(),
976  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
977 }
978 
979 //===----------------------------------------------------------------------===//
980 // MinSIOp
981 //===----------------------------------------------------------------------===//
982 
983 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
984  // minsi(x,x) -> x
985  if (getLhs() == getRhs())
986  return getRhs();
987 
988  APInt intValue;
989  // minsi(x,MIN_INT) -> MIN_INT
990  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
991  intValue.isMinSignedValue())
992  return getRhs();
993 
994  // minsi(x, MAX_INT) -> x
995  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
996  intValue.isMaxSignedValue())
997  return getLhs();
998 
999  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1000  [](const APInt &a, const APInt &b) {
1001  return llvm::APIntOps::smin(a, b);
1002  });
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // MinUIOp
1007 //===----------------------------------------------------------------------===//
1008 
1009 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1010  // minui(x,x) -> x
1011  if (getLhs() == getRhs())
1012  return getRhs();
1013 
1014  APInt intValue;
1015  // minui(x,MIN_INT) -> MIN_INT
1016  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
1017  return getRhs();
1018 
1019  // minui(x, MAX_INT) -> x
1020  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
1021  return getLhs();
1022 
1023  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1024  [](const APInt &a, const APInt &b) {
1025  return llvm::APIntOps::umin(a, b);
1026  });
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // MulFOp
1031 //===----------------------------------------------------------------------===//
1032 
1033 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1034  // mulf(x, 1) -> x
1035  if (matchPattern(getRhs(), m_OneFloat()))
1036  return getLhs();
1037 
1038  return constFoldBinaryOp<FloatAttr>(
1039  adaptor.getOperands(),
1040  [](const APFloat &a, const APFloat &b) { return a * b; });
1041 }
1042 
1043 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1044  MLIRContext *context) {
1045  patterns.add<MulFOfNegF>(context);
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // DivFOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1053  // divf(x, 1) -> x
1054  if (matchPattern(getRhs(), m_OneFloat()))
1055  return getLhs();
1056 
1057  return constFoldBinaryOp<FloatAttr>(
1058  adaptor.getOperands(),
1059  [](const APFloat &a, const APFloat &b) { return a / b; });
1060 }
1061 
1062 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1063  MLIRContext *context) {
1064  patterns.add<DivFOfNegF>(context);
1065 }
1066 
1067 //===----------------------------------------------------------------------===//
1068 // RemFOp
1069 //===----------------------------------------------------------------------===//
1070 
1071 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1072  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1073  [](const APFloat &a, const APFloat &b) {
1074  APFloat result(a);
1075  (void)result.remainder(b);
1076  return result;
1077  });
1078 }
1079 
1080 //===----------------------------------------------------------------------===//
1081 // Utility functions for verifying cast ops
1082 //===----------------------------------------------------------------------===//
1083 
1084 template <typename... Types>
1085 using type_list = std::tuple<Types...> *;
1086 
1087 /// Returns a non-null type only if the provided type is one of the allowed
1088 /// types or one of the allowed shaped types of the allowed types. Returns the
1089 /// element type if a valid shaped type is provided.
1090 template <typename... ShapedTypes, typename... ElementTypes>
1093  if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
1094  return {};
1095 
1096  auto underlyingType = getElementTypeOrSelf(type);
1097  if (!underlyingType.isa<ElementTypes...>())
1098  return {};
1099 
1100  return underlyingType;
1101 }
1102 
1103 /// Get allowed underlying types for vectors and tensors.
1104 template <typename... ElementTypes>
1105 static Type getTypeIfLike(Type type) {
1108 }
1109 
1110 /// Get allowed underlying types for vectors, tensors, and memrefs.
1111 template <typename... ElementTypes>
1113  return getUnderlyingType(type,
1116 }
1117 
1118 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1119  return inputs.size() == 1 && outputs.size() == 1 &&
1120  succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1121 }
1122 
1123 //===----------------------------------------------------------------------===//
1124 // Verifiers for integer and floating point extension/truncation ops
1125 //===----------------------------------------------------------------------===//
1126 
1127 // Extend ops can only extend to a wider type.
1128 template <typename ValType, typename Op>
1130  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1131  Type dstType = getElementTypeOrSelf(op.getType());
1132 
1133  if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
1134  return op.emitError("result type ")
1135  << dstType << " must be wider than operand type " << srcType;
1136 
1137  return success();
1138 }
1139 
1140 // Truncate ops can only truncate to a shorter type.
1141 template <typename ValType, typename Op>
1143  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1144  Type dstType = getElementTypeOrSelf(op.getType());
1145 
1146  if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
1147  return op.emitError("result type ")
1148  << dstType << " must be shorter than operand type " << srcType;
1149 
1150  return success();
1151 }
1152 
1153 /// Validate a cast that changes the width of a type.
1154 template <template <typename> class WidthComparator, typename... ElementTypes>
1155 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1156  if (!areValidCastInputsAndOutputs(inputs, outputs))
1157  return false;
1158 
1159  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1160  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1161  if (!srcType || !dstType)
1162  return false;
1163 
1164  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1165  srcType.getIntOrFloatBitWidth());
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // ExtUIOp
1170 //===----------------------------------------------------------------------===//
1171 
1172 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1173  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1174  getInMutable().assign(lhs.getIn());
1175  return getResult();
1176  }
1177 
1178  Type resType = getElementTypeOrSelf(getType());
1179  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
1180  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1181  adaptor.getOperands(), getType(),
1182  [bitWidth](const APInt &a, bool &castStatus) {
1183  return a.zext(bitWidth);
1184  });
1185 }
1186 
1187 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1188  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1189 }
1190 
1192  return verifyExtOp<IntegerType>(*this);
1193 }
1194 
1195 //===----------------------------------------------------------------------===//
1196 // ExtSIOp
1197 //===----------------------------------------------------------------------===//
1198 
1199 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1200  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1201  getInMutable().assign(lhs.getIn());
1202  return getResult();
1203  }
1204 
1205  Type resType = getElementTypeOrSelf(getType());
1206  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
1207  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1208  adaptor.getOperands(), getType(),
1209  [bitWidth](const APInt &a, bool &castStatus) {
1210  return a.sext(bitWidth);
1211  });
1212 }
1213 
1214 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1215  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1216 }
1217 
1218 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1219  MLIRContext *context) {
1220  patterns.add<ExtSIOfExtUI>(context);
1221 }
1222 
1224  return verifyExtOp<IntegerType>(*this);
1225 }
1226 
1227 //===----------------------------------------------------------------------===//
1228 // ExtFOp
1229 //===----------------------------------------------------------------------===//
1230 
1231 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1232  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1233 }
1234 
1235 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // TruncIOp
1239 //===----------------------------------------------------------------------===//
1240 
1241 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1242  // trunci(zexti(a)) -> a
1243  // trunci(sexti(a)) -> a
1244  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1245  matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
1246  return getOperand().getDefiningOp()->getOperand(0);
1247 
1248  // trunci(trunci(a)) -> trunci(a))
1249  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1250  setOperand(getOperand().getDefiningOp()->getOperand(0));
1251  return getResult();
1252  }
1253 
1254  Type resType = getElementTypeOrSelf(getType());
1255  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
1256  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1257  adaptor.getOperands(), getType(),
1258  [bitWidth](const APInt &a, bool &castStatus) {
1259  return a.trunc(bitWidth);
1260  });
1261 }
1262 
1263 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1264  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1265 }
1266 
1267 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1268  MLIRContext *context) {
1269  patterns.add<TruncIShrSIToTrunciShrUI, TruncIShrUIMulIToMulSIExtended,
1270  TruncIShrUIMulIToMulUIExtended>(context);
1271 }
1272 
1274  return verifyTruncateOp<IntegerType>(*this);
1275 }
1276 
1277 //===----------------------------------------------------------------------===//
1278 // TruncFOp
1279 //===----------------------------------------------------------------------===//
1280 
1281 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
1282 /// can be represented without precision loss or rounding.
1283 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1284  auto constOperand = adaptor.getIn();
1285  if (!constOperand || !constOperand.isa<FloatAttr>())
1286  return {};
1287 
1288  // Convert to target type via 'double'.
1289  double sourceValue =
1290  constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1291  auto targetAttr = FloatAttr::get(getType(), sourceValue);
1292 
1293  // Propagate if constant's value does not change after truncation.
1294  if (sourceValue == targetAttr.getValue().convertToDouble())
1295  return targetAttr;
1296 
1297  return {};
1298 }
1299 
1300 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1301  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1302 }
1303 
1305  return verifyTruncateOp<FloatType>(*this);
1306 }
1307 
1308 //===----------------------------------------------------------------------===//
1309 // AndIOp
1310 //===----------------------------------------------------------------------===//
1311 
1312 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1313  MLIRContext *context) {
1314  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1315 }
1316 
1317 //===----------------------------------------------------------------------===//
1318 // OrIOp
1319 //===----------------------------------------------------------------------===//
1320 
1321 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1322  MLIRContext *context) {
1323  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // Verifiers for casts between integers and floats.
1328 //===----------------------------------------------------------------------===//
1329 
1330 template <typename From, typename To>
1331 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1332  if (!areValidCastInputsAndOutputs(inputs, outputs))
1333  return false;
1334 
1335  auto srcType = getTypeIfLike<From>(inputs.front());
1336  auto dstType = getTypeIfLike<To>(outputs.back());
1337 
1338  return srcType && dstType;
1339 }
1340 
1341 //===----------------------------------------------------------------------===//
1342 // UIToFPOp
1343 //===----------------------------------------------------------------------===//
1344 
1345 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1346  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1347 }
1348 
1349 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1350  Type resEleType = getElementTypeOrSelf(getType());
1351  return constFoldCastOp<IntegerAttr, FloatAttr>(
1352  adaptor.getOperands(), getType(),
1353  [&resEleType](const APInt &a, bool &castStatus) {
1354  FloatType floatTy = resEleType.cast<FloatType>();
1355  APFloat apf(floatTy.getFloatSemantics(),
1356  APInt::getZero(floatTy.getWidth()));
1357  apf.convertFromAPInt(a, /*IsSigned=*/false,
1358  APFloat::rmNearestTiesToEven);
1359  return apf;
1360  });
1361 }
1362 
1363 //===----------------------------------------------------------------------===//
1364 // SIToFPOp
1365 //===----------------------------------------------------------------------===//
1366 
1367 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1368  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1369 }
1370 
1371 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1372  Type resEleType = getElementTypeOrSelf(getType());
1373  return constFoldCastOp<IntegerAttr, FloatAttr>(
1374  adaptor.getOperands(), getType(),
1375  [&resEleType](const APInt &a, bool &castStatus) {
1376  FloatType floatTy = resEleType.cast<FloatType>();
1377  APFloat apf(floatTy.getFloatSemantics(),
1378  APInt::getZero(floatTy.getWidth()));
1379  apf.convertFromAPInt(a, /*IsSigned=*/true,
1380  APFloat::rmNearestTiesToEven);
1381  return apf;
1382  });
1383 }
1384 //===----------------------------------------------------------------------===//
1385 // FPToUIOp
1386 //===----------------------------------------------------------------------===//
1387 
1388 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1389  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1390 }
1391 
1392 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1393  Type resType = getElementTypeOrSelf(getType());
1394  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
1395  return constFoldCastOp<FloatAttr, IntegerAttr>(
1396  adaptor.getOperands(), getType(),
1397  [&bitWidth](const APFloat &a, bool &castStatus) {
1398  bool ignored;
1399  APSInt api(bitWidth, /*isUnsigned=*/true);
1400  castStatus = APFloat::opInvalidOp !=
1401  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1402  return api;
1403  });
1404 }
1405 
1406 //===----------------------------------------------------------------------===//
1407 // FPToSIOp
1408 //===----------------------------------------------------------------------===//
1409 
1410 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1411  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1412 }
1413 
1414 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1415  Type resType = getElementTypeOrSelf(getType());
1416  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
1417  return constFoldCastOp<FloatAttr, IntegerAttr>(
1418  adaptor.getOperands(), getType(),
1419  [&bitWidth](const APFloat &a, bool &castStatus) {
1420  bool ignored;
1421  APSInt api(bitWidth, /*isUnsigned=*/false);
1422  castStatus = APFloat::opInvalidOp !=
1423  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1424  return api;
1425  });
1426 }
1427 
1428 //===----------------------------------------------------------------------===//
1429 // IndexCastOp
1430 //===----------------------------------------------------------------------===//
1431 
1432 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1433  if (!areValidCastInputsAndOutputs(inputs, outputs))
1434  return false;
1435 
1436  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1437  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1438  if (!srcType || !dstType)
1439  return false;
1440 
1441  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1442  (srcType.isSignlessInteger() && dstType.isIndex());
1443 }
1444 
1445 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1446  TypeRange outputs) {
1447  return areIndexCastCompatible(inputs, outputs);
1448 }
1449 
1450 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1451  // index_cast(constant) -> constant
1452  // A little hack because we go through int. Otherwise, the size of the
1453  // constant might need to change.
1454  if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
1455  return IntegerAttr::get(getType(), value.getInt());
1456 
1457  return {};
1458 }
1459 
1460 void arith::IndexCastOp::getCanonicalizationPatterns(
1461  RewritePatternSet &patterns, MLIRContext *context) {
1462  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1463 }
1464 
1465 //===----------------------------------------------------------------------===//
1466 // IndexCastUIOp
1467 //===----------------------------------------------------------------------===//
1468 
1469 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1470  TypeRange outputs) {
1471  return areIndexCastCompatible(inputs, outputs);
1472 }
1473 
1474 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1475  // index_castui(constant) -> constant
1476  // A little hack because we go through int. Otherwise, the size of the
1477  // constant might need to change.
1478  if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
1479  return IntegerAttr::get(getType(), value.getValue().getZExtValue());
1480 
1481  return {};
1482 }
1483 
1484 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1485  RewritePatternSet &patterns, MLIRContext *context) {
1486  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1487 }
1488 
1489 //===----------------------------------------------------------------------===//
1490 // BitcastOp
1491 //===----------------------------------------------------------------------===//
1492 
1493 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1494  if (!areValidCastInputsAndOutputs(inputs, outputs))
1495  return false;
1496 
1497  auto srcType =
1498  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1499  auto dstType =
1500  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1501  if (!srcType || !dstType)
1502  return false;
1503 
1504  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1505 }
1506 
1507 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1508  auto resType = getType();
1509  auto operand = adaptor.getIn();
1510  if (!operand)
1511  return {};
1512 
1513  /// Bitcast dense elements.
1514  if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1515  return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1516  /// Other shaped types unhandled.
1517  if (resType.isa<ShapedType>())
1518  return {};
1519 
1520  /// Bitcast integer or float to integer or float.
1521  APInt bits = operand.isa<FloatAttr>()
1522  ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1523  : operand.cast<IntegerAttr>().getValue();
1524 
1525  if (auto resFloatType = resType.dyn_cast<FloatType>())
1526  return FloatAttr::get(resType,
1527  APFloat(resFloatType.getFloatSemantics(), bits));
1528  return IntegerAttr::get(resType, bits);
1529 }
1530 
1531 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1532  MLIRContext *context) {
1533  patterns.add<BitcastOfBitcast>(context);
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // Helpers for compare ops
1538 //===----------------------------------------------------------------------===//
1539 
1540 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1541 static Type getI1SameShape(Type type) {
1542  auto i1Type = IntegerType::get(type.getContext(), 1);
1543  if (auto tensorType = type.dyn_cast<RankedTensorType>())
1544  return RankedTensorType::get(tensorType.getShape(), i1Type);
1545  if (type.isa<UnrankedTensorType>())
1546  return UnrankedTensorType::get(i1Type);
1547  if (auto vectorType = type.dyn_cast<VectorType>())
1548  return VectorType::get(vectorType.getShape(), i1Type,
1549  vectorType.getNumScalableDims());
1550  return i1Type;
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // CmpIOp
1555 //===----------------------------------------------------------------------===//
1556 
1557 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1558 /// comparison predicates.
1559 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1560  const APInt &lhs, const APInt &rhs) {
1561  switch (predicate) {
1562  case arith::CmpIPredicate::eq:
1563  return lhs.eq(rhs);
1564  case arith::CmpIPredicate::ne:
1565  return lhs.ne(rhs);
1566  case arith::CmpIPredicate::slt:
1567  return lhs.slt(rhs);
1568  case arith::CmpIPredicate::sle:
1569  return lhs.sle(rhs);
1570  case arith::CmpIPredicate::sgt:
1571  return lhs.sgt(rhs);
1572  case arith::CmpIPredicate::sge:
1573  return lhs.sge(rhs);
1574  case arith::CmpIPredicate::ult:
1575  return lhs.ult(rhs);
1576  case arith::CmpIPredicate::ule:
1577  return lhs.ule(rhs);
1578  case arith::CmpIPredicate::ugt:
1579  return lhs.ugt(rhs);
1580  case arith::CmpIPredicate::uge:
1581  return lhs.uge(rhs);
1582  }
1583  llvm_unreachable("unknown cmpi predicate kind");
1584 }
1585 
1586 /// Returns true if the predicate is true for two equal operands.
1587 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1588  switch (predicate) {
1589  case arith::CmpIPredicate::eq:
1590  case arith::CmpIPredicate::sle:
1591  case arith::CmpIPredicate::sge:
1592  case arith::CmpIPredicate::ule:
1593  case arith::CmpIPredicate::uge:
1594  return true;
1595  case arith::CmpIPredicate::ne:
1596  case arith::CmpIPredicate::slt:
1597  case arith::CmpIPredicate::sgt:
1598  case arith::CmpIPredicate::ult:
1599  case arith::CmpIPredicate::ugt:
1600  return false;
1601  }
1602  llvm_unreachable("unknown cmpi predicate kind");
1603 }
1604 
1605 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1606  auto boolAttr = BoolAttr::get(ctx, value);
1607  ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1608  if (!shapedType)
1609  return boolAttr;
1610  return DenseElementsAttr::get(shapedType, boolAttr);
1611 }
1612 
1613 static std::optional<int64_t> getIntegerWidth(Type t) {
1614  if (auto intType = t.dyn_cast<IntegerType>()) {
1615  return intType.getWidth();
1616  }
1617  if (auto vectorIntType = t.dyn_cast<VectorType>()) {
1618  return vectorIntType.getElementType().cast<IntegerType>().getWidth();
1619  }
1620  return std::nullopt;
1621 }
1622 
1623 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1624  // cmpi(pred, x, x)
1625  if (getLhs() == getRhs()) {
1626  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1627  return getBoolAttribute(getType(), getContext(), val);
1628  }
1629 
1630  if (matchPattern(getRhs(), m_Zero())) {
1631  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1632  // extsi(%x : i1 -> iN) != 0 -> %x
1633  std::optional<int64_t> integerWidth =
1634  getIntegerWidth(extOp.getOperand().getType());
1635  if (integerWidth && integerWidth.value() == 1 &&
1636  getPredicate() == arith::CmpIPredicate::ne)
1637  return extOp.getOperand();
1638  }
1639  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1640  // extui(%x : i1 -> iN) != 0 -> %x
1641  std::optional<int64_t> integerWidth =
1642  getIntegerWidth(extOp.getOperand().getType());
1643  if (integerWidth && integerWidth.value() == 1 &&
1644  getPredicate() == arith::CmpIPredicate::ne)
1645  return extOp.getOperand();
1646  }
1647  }
1648 
1649  // Move constant to the right side.
1650  if (adaptor.getLhs() && !adaptor.getRhs()) {
1651  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1652  using Pred = CmpIPredicate;
1653  const std::pair<Pred, Pred> invPreds[] = {
1654  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1655  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1656  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1657  {Pred::ne, Pred::ne},
1658  };
1659  Pred origPred = getPredicate();
1660  for (auto pred : invPreds) {
1661  if (origPred == pred.first) {
1662  setPredicate(pred.second);
1663  Value lhs = getLhs();
1664  Value rhs = getRhs();
1665  getLhsMutable().assign(rhs);
1666  getRhsMutable().assign(lhs);
1667  return getResult();
1668  }
1669  }
1670  llvm_unreachable("unknown cmpi predicate kind");
1671  }
1672 
1673  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
1674  if (!lhs)
1675  return {};
1676 
1677  // We are moving constants to the right side; So if lhs is constant rhs is
1678  // guaranteed to be a constant.
1679  auto rhs = adaptor.getRhs().cast<IntegerAttr>();
1680 
1681  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1682  return BoolAttr::get(getContext(), val);
1683 }
1684 
1685 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1686  MLIRContext *context) {
1687  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1688 }
1689 
1690 //===----------------------------------------------------------------------===//
1691 // CmpFOp
1692 //===----------------------------------------------------------------------===//
1693 
1694 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1695 /// comparison predicates.
1696 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1697  const APFloat &lhs, const APFloat &rhs) {
1698  auto cmpResult = lhs.compare(rhs);
1699  switch (predicate) {
1700  case arith::CmpFPredicate::AlwaysFalse:
1701  return false;
1702  case arith::CmpFPredicate::OEQ:
1703  return cmpResult == APFloat::cmpEqual;
1704  case arith::CmpFPredicate::OGT:
1705  return cmpResult == APFloat::cmpGreaterThan;
1706  case arith::CmpFPredicate::OGE:
1707  return cmpResult == APFloat::cmpGreaterThan ||
1708  cmpResult == APFloat::cmpEqual;
1709  case arith::CmpFPredicate::OLT:
1710  return cmpResult == APFloat::cmpLessThan;
1711  case arith::CmpFPredicate::OLE:
1712  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1713  case arith::CmpFPredicate::ONE:
1714  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1715  case arith::CmpFPredicate::ORD:
1716  return cmpResult != APFloat::cmpUnordered;
1717  case arith::CmpFPredicate::UEQ:
1718  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1719  case arith::CmpFPredicate::UGT:
1720  return cmpResult == APFloat::cmpUnordered ||
1721  cmpResult == APFloat::cmpGreaterThan;
1722  case arith::CmpFPredicate::UGE:
1723  return cmpResult == APFloat::cmpUnordered ||
1724  cmpResult == APFloat::cmpGreaterThan ||
1725  cmpResult == APFloat::cmpEqual;
1726  case arith::CmpFPredicate::ULT:
1727  return cmpResult == APFloat::cmpUnordered ||
1728  cmpResult == APFloat::cmpLessThan;
1729  case arith::CmpFPredicate::ULE:
1730  return cmpResult == APFloat::cmpUnordered ||
1731  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1732  case arith::CmpFPredicate::UNE:
1733  return cmpResult != APFloat::cmpEqual;
1734  case arith::CmpFPredicate::UNO:
1735  return cmpResult == APFloat::cmpUnordered;
1736  case arith::CmpFPredicate::AlwaysTrue:
1737  return true;
1738  }
1739  llvm_unreachable("unknown cmpf predicate kind");
1740 }
1741 
1742 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1743  auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
1744  auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
1745 
1746  // If one operand is NaN, making them both NaN does not change the result.
1747  if (lhs && lhs.getValue().isNaN())
1748  rhs = lhs;
1749  if (rhs && rhs.getValue().isNaN())
1750  lhs = rhs;
1751 
1752  if (!lhs || !rhs)
1753  return {};
1754 
1755  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1756  return BoolAttr::get(getContext(), val);
1757 }
1758 
1759 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1760 public:
1762 
1763  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1764  bool isUnsigned) {
1765  using namespace arith;
1766  switch (pred) {
1767  case CmpFPredicate::UEQ:
1768  case CmpFPredicate::OEQ:
1769  return CmpIPredicate::eq;
1770  case CmpFPredicate::UGT:
1771  case CmpFPredicate::OGT:
1772  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1773  case CmpFPredicate::UGE:
1774  case CmpFPredicate::OGE:
1775  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1776  case CmpFPredicate::ULT:
1777  case CmpFPredicate::OLT:
1778  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1779  case CmpFPredicate::ULE:
1780  case CmpFPredicate::OLE:
1781  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1782  case CmpFPredicate::UNE:
1783  case CmpFPredicate::ONE:
1784  return CmpIPredicate::ne;
1785  default:
1786  llvm_unreachable("Unexpected predicate!");
1787  }
1788  }
1789 
1791  PatternRewriter &rewriter) const override {
1792  FloatAttr flt;
1793  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1794  return failure();
1795 
1796  const APFloat &rhs = flt.getValue();
1797 
1798  // Don't attempt to fold a nan.
1799  if (rhs.isNaN())
1800  return failure();
1801 
1802  // Get the width of the mantissa. We don't want to hack on conversions that
1803  // might lose information from the integer, e.g. "i64 -> float"
1804  FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1805  int mantissaWidth = floatTy.getFPMantissaWidth();
1806  if (mantissaWidth <= 0)
1807  return failure();
1808 
1809  bool isUnsigned;
1810  Value intVal;
1811 
1812  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1813  isUnsigned = false;
1814  intVal = si.getIn();
1815  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1816  isUnsigned = true;
1817  intVal = ui.getIn();
1818  } else {
1819  return failure();
1820  }
1821 
1822  // Check to see that the input is converted from an integer type that is
1823  // small enough that preserves all bits.
1824  auto intTy = intVal.getType().cast<IntegerType>();
1825  auto intWidth = intTy.getWidth();
1826 
1827  // Number of bits representing values, as opposed to the sign
1828  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1829 
1830  // Following test does NOT adjust intWidth downwards for signed inputs,
1831  // because the most negative value still requires all the mantissa bits
1832  // to distinguish it from one less than that value.
1833  if ((int)intWidth > mantissaWidth) {
1834  // Conversion would lose accuracy. Check if loss can impact comparison.
1835  int exponent = ilogb(rhs);
1836  if (exponent == APFloat::IEK_Inf) {
1837  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1838  if (maxExponent < (int)valueBits) {
1839  // Conversion could create infinity.
1840  return failure();
1841  }
1842  } else {
1843  // Note that if rhs is zero or NaN, then Exp is negative
1844  // and first condition is trivially false.
1845  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1846  // Conversion could affect comparison.
1847  return failure();
1848  }
1849  }
1850  }
1851 
1852  // Convert to equivalent cmpi predicate
1853  CmpIPredicate pred;
1854  switch (op.getPredicate()) {
1855  case CmpFPredicate::ORD:
1856  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1857  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1858  /*width=*/1);
1859  return success();
1860  case CmpFPredicate::UNO:
1861  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1862  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1863  /*width=*/1);
1864  return success();
1865  default:
1866  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1867  break;
1868  }
1869 
1870  if (!isUnsigned) {
1871  // If the rhs value is > SignedMax, fold the comparison. This handles
1872  // +INF and large values.
1873  APFloat signedMax(rhs.getSemantics());
1874  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1875  APFloat::rmNearestTiesToEven);
1876  if (signedMax < rhs) { // smax < 13123.0
1877  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1878  pred == CmpIPredicate::sle)
1879  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1880  /*width=*/1);
1881  else
1882  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1883  /*width=*/1);
1884  return success();
1885  }
1886  } else {
1887  // If the rhs value is > UnsignedMax, fold the comparison. This handles
1888  // +INF and large values.
1889  APFloat unsignedMax(rhs.getSemantics());
1890  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1891  APFloat::rmNearestTiesToEven);
1892  if (unsignedMax < rhs) { // umax < 13123.0
1893  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1894  pred == CmpIPredicate::ule)
1895  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1896  /*width=*/1);
1897  else
1898  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1899  /*width=*/1);
1900  return success();
1901  }
1902  }
1903 
1904  if (!isUnsigned) {
1905  // See if the rhs value is < SignedMin.
1906  APFloat signedMin(rhs.getSemantics());
1907  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1908  APFloat::rmNearestTiesToEven);
1909  if (signedMin > rhs) { // smin > 12312.0
1910  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1911  pred == CmpIPredicate::sge)
1912  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1913  /*width=*/1);
1914  else
1915  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1916  /*width=*/1);
1917  return success();
1918  }
1919  } else {
1920  // See if the rhs value is < UnsignedMin.
1921  APFloat unsignedMin(rhs.getSemantics());
1922  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1923  APFloat::rmNearestTiesToEven);
1924  if (unsignedMin > rhs) { // umin > 12312.0
1925  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1926  pred == CmpIPredicate::uge)
1927  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1928  /*width=*/1);
1929  else
1930  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1931  /*width=*/1);
1932  return success();
1933  }
1934  }
1935 
1936  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1937  // [0, UMAX], but it may still be fractional. See if it is fractional by
1938  // casting the FP value to the integer value and back, checking for
1939  // equality. Don't do this for zero, because -0.0 is not fractional.
1940  bool ignored;
1941  APSInt rhsInt(intWidth, isUnsigned);
1942  if (APFloat::opInvalidOp ==
1943  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1944  // Undefined behavior invoked - the destination type can't represent
1945  // the input constant.
1946  return failure();
1947  }
1948 
1949  if (!rhs.isZero()) {
1950  APFloat apf(floatTy.getFloatSemantics(),
1951  APInt::getZero(floatTy.getWidth()));
1952  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1953 
1954  bool equal = apf == rhs;
1955  if (!equal) {
1956  // If we had a comparison against a fractional value, we have to adjust
1957  // the compare predicate and sometimes the value. rhsInt is rounded
1958  // towards zero at this point.
1959  switch (pred) {
1960  case CmpIPredicate::ne: // (float)int != 4.4 --> true
1961  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1962  /*width=*/1);
1963  return success();
1964  case CmpIPredicate::eq: // (float)int == 4.4 --> false
1965  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1966  /*width=*/1);
1967  return success();
1968  case CmpIPredicate::ule:
1969  // (float)int <= 4.4 --> int <= 4
1970  // (float)int <= -4.4 --> false
1971  if (rhs.isNegative()) {
1972  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1973  /*width=*/1);
1974  return success();
1975  }
1976  break;
1977  case CmpIPredicate::sle:
1978  // (float)int <= 4.4 --> int <= 4
1979  // (float)int <= -4.4 --> int < -4
1980  if (rhs.isNegative())
1981  pred = CmpIPredicate::slt;
1982  break;
1983  case CmpIPredicate::ult:
1984  // (float)int < -4.4 --> false
1985  // (float)int < 4.4 --> int <= 4
1986  if (rhs.isNegative()) {
1987  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1988  /*width=*/1);
1989  return success();
1990  }
1991  pred = CmpIPredicate::ule;
1992  break;
1993  case CmpIPredicate::slt:
1994  // (float)int < -4.4 --> int < -4
1995  // (float)int < 4.4 --> int <= 4
1996  if (!rhs.isNegative())
1997  pred = CmpIPredicate::sle;
1998  break;
1999  case CmpIPredicate::ugt:
2000  // (float)int > 4.4 --> int > 4
2001  // (float)int > -4.4 --> true
2002  if (rhs.isNegative()) {
2003  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2004  /*width=*/1);
2005  return success();
2006  }
2007  break;
2008  case CmpIPredicate::sgt:
2009  // (float)int > 4.4 --> int > 4
2010  // (float)int > -4.4 --> int >= -4
2011  if (rhs.isNegative())
2012  pred = CmpIPredicate::sge;
2013  break;
2014  case CmpIPredicate::uge:
2015  // (float)int >= -4.4 --> true
2016  // (float)int >= 4.4 --> int > 4
2017  if (rhs.isNegative()) {
2018  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2019  /*width=*/1);
2020  return success();
2021  }
2022  pred = CmpIPredicate::ugt;
2023  break;
2024  case CmpIPredicate::sge:
2025  // (float)int >= -4.4 --> int >= -4
2026  // (float)int >= 4.4 --> int > 4
2027  if (!rhs.isNegative())
2028  pred = CmpIPredicate::sgt;
2029  break;
2030  }
2031  }
2032  }
2033 
2034  // Lower this FP comparison into an appropriate integer version of the
2035  // comparison.
2036  rewriter.replaceOpWithNewOp<CmpIOp>(
2037  op, pred, intVal,
2038  rewriter.create<ConstantOp>(
2039  op.getLoc(), intVal.getType(),
2040  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2041  return success();
2042  }
2043 };
2044 
2045 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2046  MLIRContext *context) {
2047  patterns.insert<CmpFIntToFPConst>(context);
2048 }
2049 
2050 //===----------------------------------------------------------------------===//
2051 // SelectOp
2052 //===----------------------------------------------------------------------===//
2053 
2054 // Transforms a select of a boolean to arithmetic operations
2055 //
2056 // arith.select %arg, %x, %y : i1
2057 //
2058 // becomes
2059 //
2060 // and(%arg, %x) or and(!%arg, %y)
2061 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
2063 
2064  LogicalResult matchAndRewrite(arith::SelectOp op,
2065  PatternRewriter &rewriter) const override {
2066  if (!op.getType().isInteger(1))
2067  return failure();
2068 
2069  Value falseConstant =
2070  rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
2071  Value notCondition = rewriter.create<arith::XOrIOp>(
2072  op.getLoc(), op.getCondition(), falseConstant);
2073 
2074  Value trueVal = rewriter.create<arith::AndIOp>(
2075  op.getLoc(), op.getCondition(), op.getTrueValue());
2076  Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
2077  op.getFalseValue());
2078  rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
2079  return success();
2080  }
2081 };
2082 
2083 // select %arg, %c1, %c0 => extui %arg
2084 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2086 
2087  LogicalResult matchAndRewrite(arith::SelectOp op,
2088  PatternRewriter &rewriter) const override {
2089  // Cannot extui i1 to i1, or i1 to f32
2090  if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
2091  return failure();
2092 
2093  // select %x, c1, %c0 => extui %arg
2094  if (matchPattern(op.getTrueValue(), m_One()) &&
2095  matchPattern(op.getFalseValue(), m_Zero())) {
2096  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2097  op.getCondition());
2098  return success();
2099  }
2100 
2101  // select %x, c0, %c1 => extui (xor %arg, true)
2102  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2103  matchPattern(op.getFalseValue(), m_One())) {
2104  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2105  op, op.getType(),
2106  rewriter.create<arith::XOrIOp>(
2107  op.getLoc(), op.getCondition(),
2108  rewriter.create<arith::ConstantIntOp>(
2109  op.getLoc(), 1, op.getCondition().getType())));
2110  return success();
2111  }
2112 
2113  return failure();
2114  }
2115 };
2116 
2117 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2118  MLIRContext *context) {
2119  results.add<SelectI1Simplify, SelectToExtUI>(context);
2120 }
2121 
2122 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2123  Value trueVal = getTrueValue();
2124  Value falseVal = getFalseValue();
2125  if (trueVal == falseVal)
2126  return trueVal;
2127 
2128  Value condition = getCondition();
2129 
2130  // select true, %0, %1 => %0
2131  if (matchPattern(condition, m_One()))
2132  return trueVal;
2133 
2134  // select false, %0, %1 => %1
2135  if (matchPattern(condition, m_Zero()))
2136  return falseVal;
2137 
2138  // select %x, true, false => %x
2139  if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
2140  matchPattern(getFalseValue(), m_Zero()))
2141  return condition;
2142 
2143  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2144  auto pred = cmp.getPredicate();
2145  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2146  auto cmpLhs = cmp.getLhs();
2147  auto cmpRhs = cmp.getRhs();
2148 
2149  // %0 = arith.cmpi eq, %arg0, %arg1
2150  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2151 
2152  // %0 = arith.cmpi ne, %arg0, %arg1
2153  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2154 
2155  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2156  (cmpRhs == trueVal && cmpLhs == falseVal))
2157  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2158  }
2159  }
2160  return nullptr;
2161 }
2162 
2163 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2164  Type conditionType, resultType;
2166  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2167  parser.parseOptionalAttrDict(result.attributes) ||
2168  parser.parseColonType(resultType))
2169  return failure();
2170 
2171  // Check for the explicit condition type if this is a masked tensor or vector.
2172  if (succeeded(parser.parseOptionalComma())) {
2173  conditionType = resultType;
2174  if (parser.parseType(resultType))
2175  return failure();
2176  } else {
2177  conditionType = parser.getBuilder().getI1Type();
2178  }
2179 
2180  result.addTypes(resultType);
2181  return parser.resolveOperands(operands,
2182  {conditionType, resultType, resultType},
2183  parser.getNameLoc(), result.operands);
2184 }
2185 
2187  p << " " << getOperands();
2188  p.printOptionalAttrDict((*this)->getAttrs());
2189  p << " : ";
2190  if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
2191  p << condType << ", ";
2192  p << getType();
2193 }
2194 
2196  Type conditionType = getCondition().getType();
2197  if (conditionType.isSignlessInteger(1))
2198  return success();
2199 
2200  // If the result type is a vector or tensor, the type can be a mask with the
2201  // same elements.
2202  Type resultType = getType();
2203  if (!resultType.isa<TensorType, VectorType>())
2204  return emitOpError() << "expected condition to be a signless i1, but got "
2205  << conditionType;
2206  Type shapedConditionType = getI1SameShape(resultType);
2207  if (conditionType != shapedConditionType) {
2208  return emitOpError() << "expected condition type to have the same shape "
2209  "as the result type, expected "
2210  << shapedConditionType << ", but got "
2211  << conditionType;
2212  }
2213  return success();
2214 }
2215 //===----------------------------------------------------------------------===//
2216 // ShLIOp
2217 //===----------------------------------------------------------------------===//
2218 
2219 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2220  // shli(x, 0) -> x
2221  if (matchPattern(getRhs(), m_Zero()))
2222  return getLhs();
2223  // Don't fold if shifting more than the bit width.
2224  bool bounded = false;
2225  auto result = constFoldBinaryOp<IntegerAttr>(
2226  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2227  bounded = b.ule(b.getBitWidth());
2228  return a.shl(b);
2229  });
2230  return bounded ? result : Attribute();
2231 }
2232 
2233 //===----------------------------------------------------------------------===//
2234 // ShRUIOp
2235 //===----------------------------------------------------------------------===//
2236 
2237 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2238  // shrui(x, 0) -> x
2239  if (matchPattern(getRhs(), m_Zero()))
2240  return getLhs();
2241  // Don't fold if shifting more than the bit width.
2242  bool bounded = false;
2243  auto result = constFoldBinaryOp<IntegerAttr>(
2244  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2245  bounded = b.ule(b.getBitWidth());
2246  return a.lshr(b);
2247  });
2248  return bounded ? result : Attribute();
2249 }
2250 
2251 //===----------------------------------------------------------------------===//
2252 // ShRSIOp
2253 //===----------------------------------------------------------------------===//
2254 
2255 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2256  // shrsi(x, 0) -> x
2257  if (matchPattern(getRhs(), m_Zero()))
2258  return getLhs();
2259  // Don't fold if shifting more than the bit width.
2260  bool bounded = false;
2261  auto result = constFoldBinaryOp<IntegerAttr>(
2262  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2263  bounded = b.ule(b.getBitWidth());
2264  return a.ashr(b);
2265  });
2266  return bounded ? result : Attribute();
2267 }
2268 
2269 //===----------------------------------------------------------------------===//
2270 // Atomic Enum
2271 //===----------------------------------------------------------------------===//
2272 
2273 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2274 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2275  OpBuilder &builder, Location loc) {
2276  switch (kind) {
2277  case AtomicRMWKind::maxf:
2278  return builder.getFloatAttr(
2279  resultType,
2280  APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
2281  /*Negative=*/true));
2282  case AtomicRMWKind::addf:
2283  case AtomicRMWKind::addi:
2284  case AtomicRMWKind::maxu:
2285  case AtomicRMWKind::ori:
2286  return builder.getZeroAttr(resultType);
2287  case AtomicRMWKind::andi:
2288  return builder.getIntegerAttr(
2289  resultType,
2290  APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
2291  case AtomicRMWKind::maxs:
2292  return builder.getIntegerAttr(
2293  resultType,
2294  APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
2295  case AtomicRMWKind::minf:
2296  return builder.getFloatAttr(
2297  resultType,
2298  APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
2299  /*Negative=*/false));
2300  case AtomicRMWKind::mins:
2301  return builder.getIntegerAttr(
2302  resultType,
2303  APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
2304  case AtomicRMWKind::minu:
2305  return builder.getIntegerAttr(
2306  resultType,
2307  APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
2308  case AtomicRMWKind::muli:
2309  return builder.getIntegerAttr(resultType, 1);
2310  case AtomicRMWKind::mulf:
2311  return builder.getFloatAttr(resultType, 1);
2312  // TODO: Add remaining reduction operations.
2313  default:
2314  (void)emitOptionalError(loc, "Reduction operation type not supported");
2315  break;
2316  }
2317  return nullptr;
2318 }
2319 
2320 /// Returns the identity value associated with an AtomicRMWKind op.
2321 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2322  OpBuilder &builder, Location loc) {
2323  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
2324  return builder.create<arith::ConstantOp>(loc, attr);
2325 }
2326 
2327 /// Return the value obtained by applying the reduction operation kind
2328 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2329 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2330  Location loc, Value lhs, Value rhs) {
2331  switch (op) {
2332  case AtomicRMWKind::addf:
2333  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2334  case AtomicRMWKind::addi:
2335  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2336  case AtomicRMWKind::mulf:
2337  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2338  case AtomicRMWKind::muli:
2339  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2340  case AtomicRMWKind::maxf:
2341  return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2342  case AtomicRMWKind::minf:
2343  return builder.create<arith::MinFOp>(loc, lhs, rhs);
2344  case AtomicRMWKind::maxs:
2345  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2346  case AtomicRMWKind::mins:
2347  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2348  case AtomicRMWKind::maxu:
2349  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2350  case AtomicRMWKind::minu:
2351  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2352  case AtomicRMWKind::ori:
2353  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2354  case AtomicRMWKind::andi:
2355  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2356  // TODO: Add remaining reduction operations.
2357  default:
2358  (void)emitOptionalError(loc, "Reduction operation type not supported");
2359  break;
2360  }
2361  return nullptr;
2362 }
2363 
2364 //===----------------------------------------------------------------------===//
2365 // TableGen'd op method definitions
2366 //===----------------------------------------------------------------------===//
2367 
2368 #define GET_OP_CLASSES
2369 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2370 
2371 //===----------------------------------------------------------------------===//
2372 // TableGen'd enum attribute definitions
2373 //===----------------------------------------------------------------------===//
2374 
2375 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1155
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1105
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1587
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1091
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:552
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:90
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:32
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1432
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1331
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1129
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:39
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:78
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:753
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1112
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:1541
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1613
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1118
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
Definition: ArithOps.cpp:1605
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1085
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:257
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1142
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1790
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1763
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:169
U cast() const
Definition: Attributes.h:179
bool isa() const
Casting utility functions.
Definition: Attributes.h:159
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:235
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:306
IntegerType getI1Type()
Definition: Builders.cpp:58
IndexType getIndexType()
Definition: Builders.cpp:56
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:199
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:631
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:77
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast_or_null() const
Definition: Types.h:316
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:55
U dyn_cast() const
Definition: Types.h:311
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:105
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:109
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:190
static bool classof(Operation *op)
Definition: ArithOps.cpp:196
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:202
static bool classof(Operation *op)
Definition: ArithOps.cpp:208
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:52
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:169
static bool classof(Operation *op)
Definition: ArithOps.cpp:184
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2321
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1559
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2274
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2329
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:47
auto m_Val(Value v)
Definition: Matchers.h:357
MPInt ceil(const Fraction &f)
Definition: Fraction.h:70
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:322
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:266
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:299
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:310
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:345
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:305
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:292
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:271
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:248
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:284
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:374
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:276
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2064
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2087
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes