MLIR  16.0.0git
ArithmeticOps.cpp
Go to the documentation of this file.
1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 #include "llvm/ADT/APSInt.h"
21 
22 using namespace mlir;
23 using namespace mlir::arith;
24 
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28 
29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30  Attribute lhs, Attribute rhs) {
31  return builder.getIntegerAttr(res.getType(),
32  lhs.cast<IntegerAttr>().getInt() +
33  rhs.cast<IntegerAttr>().getInt());
34 }
35 
36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37  Attribute lhs, Attribute rhs) {
38  return builder.getIntegerAttr(res.getType(),
39  lhs.cast<IntegerAttr>().getInt() -
40  rhs.cast<IntegerAttr>().getInt());
41 }
42 
43 /// Invert an integer comparison predicate.
44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45  switch (pred) {
46  case arith::CmpIPredicate::eq:
47  return arith::CmpIPredicate::ne;
48  case arith::CmpIPredicate::ne:
49  return arith::CmpIPredicate::eq;
50  case arith::CmpIPredicate::slt:
51  return arith::CmpIPredicate::sge;
52  case arith::CmpIPredicate::sle:
53  return arith::CmpIPredicate::sgt;
54  case arith::CmpIPredicate::sgt:
55  return arith::CmpIPredicate::sle;
56  case arith::CmpIPredicate::sge:
57  return arith::CmpIPredicate::slt;
58  case arith::CmpIPredicate::ult:
59  return arith::CmpIPredicate::uge;
60  case arith::CmpIPredicate::ule:
61  return arith::CmpIPredicate::ugt;
62  case arith::CmpIPredicate::ugt:
63  return arith::CmpIPredicate::ule;
64  case arith::CmpIPredicate::uge:
65  return arith::CmpIPredicate::ult;
66  }
67  llvm_unreachable("unknown cmpi predicate kind");
68 }
69 
70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71  return arith::CmpIPredicateAttr::get(pred.getContext(),
72  invertPredicate(pred.getValue()));
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78 
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::ConstantOp::getAsmResultNames(
88  function_ref<void(Value, StringRef)> setNameFn) {
89  auto type = getType();
90  if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91  auto intType = type.dyn_cast<IntegerType>();
92 
93  // Sugar i1 constants with 'true' and 'false'.
94  if (intType && intType.getWidth() == 1)
95  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96 
97  // Otherwise, build a complex name with the value and type.
98  SmallString<32> specialNameBuffer;
99  llvm::raw_svector_ostream specialName(specialNameBuffer);
100  specialName << 'c' << intCst.getValue();
101  if (intType)
102  specialName << '_' << type;
103  setNameFn(getResult(), specialName.str());
104  } else {
105  setNameFn(getResult(), "cst");
106  }
107 }
108 
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
112  auto type = getType();
113  // The value's type must match the return type.
114  if (getValue().getType() != type) {
115  return emitOpError() << "value type " << getValue().getType()
116  << " must match return type: " << type;
117  }
118  // Integer values must be signless.
119  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120  return emitOpError("integer return type must be signless");
121  // Any float or elements attribute are acceptable.
122  if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123  return emitOpError(
124  "value must be an integer, float, or elements attribute");
125  }
126  return success();
127 }
128 
129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130  // The value's type must be the same as the provided type.
131  auto typedAttr = value.dyn_cast<TypedAttr>();
132  if (!typedAttr || typedAttr.getType() != type)
133  return false;
134  // Integer values must be signless.
135  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
136  return false;
137  // Integer, float, and element attributes are buildable.
138  return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
139 }
140 
141 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
142  return getValue();
143 }
144 
146  int64_t value, unsigned width) {
147  auto type = builder.getIntegerType(width);
148  arith::ConstantOp::build(builder, result, type,
149  builder.getIntegerAttr(type, value));
150 }
151 
153  int64_t value, Type type) {
154  assert(type.isSignlessInteger() &&
155  "ConstantIntOp can only have signless integer type values");
156  arith::ConstantOp::build(builder, result, type,
157  builder.getIntegerAttr(type, value));
158 }
159 
161  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
162  return constOp.getType().isSignlessInteger();
163  return false;
164 }
165 
167  const APFloat &value, FloatType type) {
168  arith::ConstantOp::build(builder, result, type,
169  builder.getFloatAttr(type, value));
170 }
171 
173  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
174  return constOp.getType().isa<FloatType>();
175  return false;
176 }
177 
179  int64_t value) {
180  arith::ConstantOp::build(builder, result, builder.getIndexType(),
181  builder.getIndexAttr(value));
182 }
183 
185  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
186  return constOp.getType().isIndex();
187  return false;
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // AddIOp
192 //===----------------------------------------------------------------------===//
193 
194 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
195  // addi(x, 0) -> x
196  if (matchPattern(getRhs(), m_Zero()))
197  return getLhs();
198 
199  // addi(subi(a, b), b) -> a
200  if (auto sub = getLhs().getDefiningOp<SubIOp>())
201  if (getRhs() == sub.getRhs())
202  return sub.getLhs();
203 
204  // addi(b, subi(a, b)) -> a
205  if (auto sub = getRhs().getDefiningOp<SubIOp>())
206  if (getLhs() == sub.getRhs())
207  return sub.getLhs();
208 
209  return constFoldBinaryOp<IntegerAttr>(
210  operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
211 }
212 
213 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
214  MLIRContext *context) {
215  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
216  context);
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // SubIOp
221 //===----------------------------------------------------------------------===//
222 
223 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
224  // subi(x,x) -> 0
225  if (getOperand(0) == getOperand(1))
226  return Builder(getContext()).getZeroAttr(getType());
227  // subi(x,0) -> x
228  if (matchPattern(getRhs(), m_Zero()))
229  return getLhs();
230 
231  return constFoldBinaryOp<IntegerAttr>(
232  operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
233 }
234 
235 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
236  MLIRContext *context) {
237  patterns
238  .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
239  SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
240  context);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // MulIOp
245 //===----------------------------------------------------------------------===//
246 
247 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
248  // muli(x, 0) -> 0
249  if (matchPattern(getRhs(), m_Zero()))
250  return getRhs();
251  // muli(x, 1) -> x
252  if (matchPattern(getRhs(), m_One()))
253  return getOperand(0);
254  // TODO: Handle the overflow case.
255 
256  // default folder
257  return constFoldBinaryOp<IntegerAttr>(
258  operands, [](const APInt &a, const APInt &b) { return a * b; });
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // DivUIOp
263 //===----------------------------------------------------------------------===//
264 
265 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
266  // divui (x, 1) -> x.
267  if (matchPattern(getRhs(), m_One()))
268  return getLhs();
269 
270  // Don't fold if it would require a division by zero.
271  bool div0 = false;
272  auto result =
273  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
274  if (div0 || !b) {
275  div0 = true;
276  return a;
277  }
278  return a.udiv(b);
279  });
280 
281  return div0 ? Attribute() : result;
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // DivSIOp
286 //===----------------------------------------------------------------------===//
287 
288 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
289  // divsi (x, 1) -> x.
290  if (matchPattern(getRhs(), m_One()))
291  return getLhs();
292 
293  // Don't fold if it would overflow or if it requires a division by zero.
294  bool overflowOrDiv0 = false;
295  auto result =
296  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
297  if (overflowOrDiv0 || !b) {
298  overflowOrDiv0 = true;
299  return a;
300  }
301  return a.sdiv_ov(b, overflowOrDiv0);
302  });
303 
304  return overflowOrDiv0 ? Attribute() : result;
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Ceil and floor division folding helpers
309 //===----------------------------------------------------------------------===//
310 
311 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
312  bool &overflow) {
313  // Returns (a-1)/b + 1
314  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
315  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
316  return val.sadd_ov(one, overflow);
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // CeilDivUIOp
321 //===----------------------------------------------------------------------===//
322 
323 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
324  // ceildivui (x, 1) -> x.
325  if (matchPattern(getRhs(), m_One()))
326  return getLhs();
327 
328  bool overflowOrDiv0 = false;
329  auto result =
330  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
331  if (overflowOrDiv0 || !b) {
332  overflowOrDiv0 = true;
333  return a;
334  }
335  APInt quotient = a.udiv(b);
336  if (!a.urem(b))
337  return quotient;
338  APInt one(a.getBitWidth(), 1, true);
339  return quotient.uadd_ov(one, overflowOrDiv0);
340  });
341 
342  return overflowOrDiv0 ? Attribute() : result;
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // CeilDivSIOp
347 //===----------------------------------------------------------------------===//
348 
349 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
350  // ceildivsi (x, 1) -> x.
351  if (matchPattern(getRhs(), m_One()))
352  return getLhs();
353 
354  // Don't fold if it would overflow or if it requires a division by zero.
355  bool overflowOrDiv0 = false;
356  auto result =
357  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
358  if (overflowOrDiv0 || !b) {
359  overflowOrDiv0 = true;
360  return a;
361  }
362  if (!a)
363  return a;
364  // After this point we know that neither a or b are zero.
365  unsigned bits = a.getBitWidth();
366  APInt zero = APInt::getZero(bits);
367  bool aGtZero = a.sgt(zero);
368  bool bGtZero = b.sgt(zero);
369  if (aGtZero && bGtZero) {
370  // Both positive, return ceil(a, b).
371  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
372  }
373  if (!aGtZero && !bGtZero) {
374  // Both negative, return ceil(-a, -b).
375  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
376  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
377  return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
378  }
379  if (!aGtZero && bGtZero) {
380  // A is negative, b is positive, return - ( -a / b).
381  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
382  APInt div = posA.sdiv_ov(b, overflowOrDiv0);
383  return zero.ssub_ov(div, overflowOrDiv0);
384  }
385  // A is positive, b is negative, return - (a / -b).
386  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
387  APInt div = a.sdiv_ov(posB, overflowOrDiv0);
388  return zero.ssub_ov(div, overflowOrDiv0);
389  });
390 
391  return overflowOrDiv0 ? Attribute() : result;
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // FloorDivSIOp
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
399  // floordivsi (x, 1) -> x.
400  if (matchPattern(getRhs(), m_One()))
401  return getLhs();
402 
403  // Don't fold if it would overflow or if it requires a division by zero.
404  bool overflowOrDiv0 = false;
405  auto result =
406  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
407  if (overflowOrDiv0 || !b) {
408  overflowOrDiv0 = true;
409  return a;
410  }
411  if (!a)
412  return a;
413  // After this point we know that neither a or b are zero.
414  unsigned bits = a.getBitWidth();
415  APInt zero = APInt::getZero(bits);
416  bool aGtZero = a.sgt(zero);
417  bool bGtZero = b.sgt(zero);
418  if (aGtZero && bGtZero) {
419  // Both positive, return a / b.
420  return a.sdiv_ov(b, overflowOrDiv0);
421  }
422  if (!aGtZero && !bGtZero) {
423  // Both negative, return -a / -b.
424  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
425  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
426  return posA.sdiv_ov(posB, overflowOrDiv0);
427  }
428  if (!aGtZero && bGtZero) {
429  // A is negative, b is positive, return - ceil(-a, b).
430  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
431  APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
432  return zero.ssub_ov(ceil, overflowOrDiv0);
433  }
434  // A is positive, b is negative, return - ceil(a, -b).
435  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
436  APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
437  return zero.ssub_ov(ceil, overflowOrDiv0);
438  });
439 
440  return overflowOrDiv0 ? Attribute() : result;
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // RemUIOp
445 //===----------------------------------------------------------------------===//
446 
447 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
448  // remui (x, 1) -> 0.
449  if (matchPattern(getRhs(), m_One()))
450  return Builder(getContext()).getZeroAttr(getType());
451 
452  // Don't fold if it would require a division by zero.
453  bool div0 = false;
454  auto result =
455  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
456  if (div0 || b.isNullValue()) {
457  div0 = true;
458  return a;
459  }
460  return a.urem(b);
461  });
462 
463  return div0 ? Attribute() : result;
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // RemSIOp
468 //===----------------------------------------------------------------------===//
469 
470 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
471  // remsi (x, 1) -> 0.
472  if (matchPattern(getRhs(), m_One()))
473  return Builder(getContext()).getZeroAttr(getType());
474 
475  // Don't fold if it would require a division by zero.
476  bool div0 = false;
477  auto result =
478  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
479  if (div0 || b.isNullValue()) {
480  div0 = true;
481  return a;
482  }
483  return a.srem(b);
484  });
485 
486  return div0 ? Attribute() : result;
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // AndIOp
491 //===----------------------------------------------------------------------===//
492 
493 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
494  /// and(x, 0) -> 0
495  if (matchPattern(getRhs(), m_Zero()))
496  return getRhs();
497  /// and(x, allOnes) -> x
498  APInt intValue;
499  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
500  return getLhs();
501 
502  return constFoldBinaryOp<IntegerAttr>(
503  operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // OrIOp
508 //===----------------------------------------------------------------------===//
509 
510 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
511  /// or(x, 0) -> x
512  if (matchPattern(getRhs(), m_Zero()))
513  return getLhs();
514  /// or(x, <all ones>) -> <all ones>
515  if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
516  if (rhsAttr.getValue().isAllOnes())
517  return rhsAttr;
518 
519  return constFoldBinaryOp<IntegerAttr>(
520  operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
521 }
522 
523 //===----------------------------------------------------------------------===//
524 // XOrIOp
525 //===----------------------------------------------------------------------===//
526 
527 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
528  /// xor(x, 0) -> x
529  if (matchPattern(getRhs(), m_Zero()))
530  return getLhs();
531  /// xor(x, x) -> 0
532  if (getLhs() == getRhs())
533  return Builder(getContext()).getZeroAttr(getType());
534  /// xor(xor(x, a), a) -> x
535  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
536  if (prev.getRhs() == getRhs())
537  return prev.getLhs();
538 
539  return constFoldBinaryOp<IntegerAttr>(
540  operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
541 }
542 
543 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
544  MLIRContext *context) {
545  patterns.add<XOrINotCmpI>(context);
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // NegFOp
550 //===----------------------------------------------------------------------===//
551 
552 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
553  /// negf(negf(x)) -> x
554  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
555  return op.getOperand();
556  return constFoldUnaryOp<FloatAttr>(operands,
557  [](const APFloat &a) { return -a; });
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // AddFOp
562 //===----------------------------------------------------------------------===//
563 
564 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
565  // addf(x, -0) -> x
566  if (matchPattern(getRhs(), m_NegZeroFloat()))
567  return getLhs();
568 
569  return constFoldBinaryOp<FloatAttr>(
570  operands, [](const APFloat &a, const APFloat &b) { return a + b; });
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // SubFOp
575 //===----------------------------------------------------------------------===//
576 
577 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
578  // subf(x, +0) -> x
579  if (matchPattern(getRhs(), m_PosZeroFloat()))
580  return getLhs();
581 
582  return constFoldBinaryOp<FloatAttr>(
583  operands, [](const APFloat &a, const APFloat &b) { return a - b; });
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // MaxFOp
588 //===----------------------------------------------------------------------===//
589 
590 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
591  assert(operands.size() == 2 && "maxf takes two operands");
592 
593  // maxf(x,x) -> x
594  if (getLhs() == getRhs())
595  return getRhs();
596 
597  // maxf(x, -inf) -> x
598  if (matchPattern(getRhs(), m_NegInfFloat()))
599  return getLhs();
600 
601  return constFoldBinaryOp<FloatAttr>(
602  operands,
603  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // MaxSIOp
608 //===----------------------------------------------------------------------===//
609 
610 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
611  assert(operands.size() == 2 && "binary operation takes two operands");
612 
613  // maxsi(x,x) -> x
614  if (getLhs() == getRhs())
615  return getRhs();
616 
617  APInt intValue;
618  // maxsi(x,MAX_INT) -> MAX_INT
619  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
620  intValue.isMaxSignedValue())
621  return getRhs();
622 
623  // maxsi(x, MIN_INT) -> x
624  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
625  intValue.isMinSignedValue())
626  return getLhs();
627 
628  return constFoldBinaryOp<IntegerAttr>(operands,
629  [](const APInt &a, const APInt &b) {
630  return llvm::APIntOps::smax(a, b);
631  });
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // MaxUIOp
636 //===----------------------------------------------------------------------===//
637 
638 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
639  assert(operands.size() == 2 && "binary operation takes two operands");
640 
641  // maxui(x,x) -> x
642  if (getLhs() == getRhs())
643  return getRhs();
644 
645  APInt intValue;
646  // maxui(x,MAX_INT) -> MAX_INT
647  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
648  return getRhs();
649 
650  // maxui(x, MIN_INT) -> x
651  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
652  return getLhs();
653 
654  return constFoldBinaryOp<IntegerAttr>(operands,
655  [](const APInt &a, const APInt &b) {
656  return llvm::APIntOps::umax(a, b);
657  });
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // MinFOp
662 //===----------------------------------------------------------------------===//
663 
664 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
665  assert(operands.size() == 2 && "minf takes two operands");
666 
667  // minf(x,x) -> x
668  if (getLhs() == getRhs())
669  return getRhs();
670 
671  // minf(x, +inf) -> x
672  if (matchPattern(getRhs(), m_PosInfFloat()))
673  return getLhs();
674 
675  return constFoldBinaryOp<FloatAttr>(
676  operands,
677  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
678 }
679 
680 //===----------------------------------------------------------------------===//
681 // MinSIOp
682 //===----------------------------------------------------------------------===//
683 
684 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
685  assert(operands.size() == 2 && "binary operation takes two operands");
686 
687  // minsi(x,x) -> x
688  if (getLhs() == getRhs())
689  return getRhs();
690 
691  APInt intValue;
692  // minsi(x,MIN_INT) -> MIN_INT
693  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
694  intValue.isMinSignedValue())
695  return getRhs();
696 
697  // minsi(x, MAX_INT) -> x
698  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
699  intValue.isMaxSignedValue())
700  return getLhs();
701 
702  return constFoldBinaryOp<IntegerAttr>(operands,
703  [](const APInt &a, const APInt &b) {
704  return llvm::APIntOps::smin(a, b);
705  });
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // MinUIOp
710 //===----------------------------------------------------------------------===//
711 
712 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
713  assert(operands.size() == 2 && "binary operation takes two operands");
714 
715  // minui(x,x) -> x
716  if (getLhs() == getRhs())
717  return getRhs();
718 
719  APInt intValue;
720  // minui(x,MIN_INT) -> MIN_INT
721  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
722  return getRhs();
723 
724  // minui(x, MAX_INT) -> x
725  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
726  return getLhs();
727 
728  return constFoldBinaryOp<IntegerAttr>(operands,
729  [](const APInt &a, const APInt &b) {
730  return llvm::APIntOps::umin(a, b);
731  });
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // MulFOp
736 //===----------------------------------------------------------------------===//
737 
738 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
739  // mulf(x, 1) -> x
740  if (matchPattern(getRhs(), m_OneFloat()))
741  return getLhs();
742 
743  return constFoldBinaryOp<FloatAttr>(
744  operands, [](const APFloat &a, const APFloat &b) { return a * b; });
745 }
746 
747 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
748  MLIRContext *context) {
749  patterns.add<MulFOfNegF>(context);
750 }
751 
752 //===----------------------------------------------------------------------===//
753 // DivFOp
754 //===----------------------------------------------------------------------===//
755 
756 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
757  // divf(x, 1) -> x
758  if (matchPattern(getRhs(), m_OneFloat()))
759  return getLhs();
760 
761  return constFoldBinaryOp<FloatAttr>(
762  operands, [](const APFloat &a, const APFloat &b) { return a / b; });
763 }
764 
765 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
766  MLIRContext *context) {
767  patterns.add<DivFOfNegF>(context);
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // RemFOp
772 //===----------------------------------------------------------------------===//
773 
774 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
775  return constFoldBinaryOp<FloatAttr>(operands,
776  [](const APFloat &a, const APFloat &b) {
777  APFloat result(a);
778  (void)result.remainder(b);
779  return result;
780  });
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Utility functions for verifying cast ops
785 //===----------------------------------------------------------------------===//
786 
/// A pointer to a std::tuple of the given types. It presumably serves as a tag
/// argument, letting a type pack be passed to (and deduced by) an ordinary
/// function such as getUnderlyingType.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
789 
790 /// Returns a non-null type only if the provided type is one of the allowed
791 /// types or one of the allowed shaped types of the allowed types. Returns the
792 /// element type if a valid shaped type is provided.
793 template <typename... ShapedTypes, typename... ElementTypes>
796  if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
797  return {};
798 
799  auto underlyingType = getElementTypeOrSelf(type);
800  if (!underlyingType.isa<ElementTypes...>())
801  return {};
802 
803  return underlyingType;
804 }
805 
806 /// Get allowed underlying types for vectors and tensors.
807 template <typename... ElementTypes>
808 static Type getTypeIfLike(Type type) {
811 }
812 
813 /// Get allowed underlying types for vectors, tensors, and memrefs.
814 template <typename... ElementTypes>
816  return getUnderlyingType(type,
819 }
820 
821 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
822  return inputs.size() == 1 && outputs.size() == 1 &&
823  succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // Verifiers for integer and floating point extension/truncation ops
828 //===----------------------------------------------------------------------===//
829 
830 // Extend ops can only extend to a wider type.
831 template <typename ValType, typename Op>
833  Type srcType = getElementTypeOrSelf(op.getIn().getType());
834  Type dstType = getElementTypeOrSelf(op.getType());
835 
836  if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
837  return op.emitError("result type ")
838  << dstType << " must be wider than operand type " << srcType;
839 
840  return success();
841 }
842 
843 // Truncate ops can only truncate to a shorter type.
844 template <typename ValType, typename Op>
846  Type srcType = getElementTypeOrSelf(op.getIn().getType());
847  Type dstType = getElementTypeOrSelf(op.getType());
848 
849  if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
850  return op.emitError("result type ")
851  << dstType << " must be shorter than operand type " << srcType;
852 
853  return success();
854 }
855 
856 /// Validate a cast that changes the width of a type.
857 template <template <typename> class WidthComparator, typename... ElementTypes>
858 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
859  if (!areValidCastInputsAndOutputs(inputs, outputs))
860  return false;
861 
862  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
863  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
864  if (!srcType || !dstType)
865  return false;
866 
867  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
868  srcType.getIntOrFloatBitWidth());
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // ExtUIOp
873 //===----------------------------------------------------------------------===//
874 
875 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
876  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
877  getInMutable().assign(lhs.getIn());
878  return getResult();
879  }
880  Type resType = getType();
881  unsigned bitWidth;
882  if (auto shapedType = resType.dyn_cast<ShapedType>())
883  bitWidth = shapedType.getElementTypeBitWidth();
884  else
885  bitWidth = resType.getIntOrFloatBitWidth();
886  return constFoldCastOp<IntegerAttr, IntegerAttr>(
887  operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
888  return a.zext(bitWidth);
889  });
890 }
891 
892 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
893  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
894 }
895 
897  return verifyExtOp<IntegerType>(*this);
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // ExtSIOp
902 //===----------------------------------------------------------------------===//
903 
904 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
905  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
906  getInMutable().assign(lhs.getIn());
907  return getResult();
908  }
909  Type resType = getType();
910  unsigned bitWidth;
911  if (auto shapedType = resType.dyn_cast<ShapedType>())
912  bitWidth = shapedType.getElementTypeBitWidth();
913  else
914  bitWidth = resType.getIntOrFloatBitWidth();
915  return constFoldCastOp<IntegerAttr, IntegerAttr>(
916  operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
917  return a.sext(bitWidth);
918  });
919 }
920 
921 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
922  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
923 }
924 
925 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
926  MLIRContext *context) {
927  patterns.add<ExtSIOfExtUI>(context);
928 }
929 
931  return verifyExtOp<IntegerType>(*this);
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // ExtFOp
936 //===----------------------------------------------------------------------===//
937 
938 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
939  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
940 }
941 
942 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
943 
944 //===----------------------------------------------------------------------===//
945 // TruncIOp
946 //===----------------------------------------------------------------------===//
947 
948 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
949  assert(operands.size() == 1 && "unary operation takes one operand");
950 
951  // trunci(zexti(a)) -> a
952  // trunci(sexti(a)) -> a
953  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
954  matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
955  return getOperand().getDefiningOp()->getOperand(0);
956 
957  // trunci(trunci(a)) -> trunci(a))
958  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
959  setOperand(getOperand().getDefiningOp()->getOperand(0));
960  return getResult();
961  }
962 
963  Type resType = getType();
964  unsigned bitWidth;
965  if (auto shapedType = resType.dyn_cast<ShapedType>())
966  bitWidth = shapedType.getElementTypeBitWidth();
967  else
968  bitWidth = resType.getIntOrFloatBitWidth();
969 
970  return constFoldCastOp<IntegerAttr, IntegerAttr>(
971  operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
972  return a.trunc(bitWidth);
973  });
974 }
975 
976 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
977  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
978 }
979 
981  return verifyTruncateOp<IntegerType>(*this);
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // TruncFOp
986 //===----------------------------------------------------------------------===//
987 
988 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
989 /// can be represented without precision loss or rounding.
990 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
991  assert(operands.size() == 1 && "unary operation takes one operand");
992 
993  auto constOperand = operands.front();
994  if (!constOperand || !constOperand.isa<FloatAttr>())
995  return {};
996 
997  // Convert to target type via 'double'.
998  double sourceValue =
999  constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1000  auto targetAttr = FloatAttr::get(getType(), sourceValue);
1001 
1002  // Propagate if constant's value does not change after truncation.
1003  if (sourceValue == targetAttr.getValue().convertToDouble())
1004  return targetAttr;
1005 
1006  return {};
1007 }
1008 
1009 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1010  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1011 }
1012 
1014  return verifyTruncateOp<FloatType>(*this);
1015 }
1016 
1017 //===----------------------------------------------------------------------===//
1018 // AndIOp
1019 //===----------------------------------------------------------------------===//
1020 
1021 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1022  MLIRContext *context) {
1023  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // OrIOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1031  MLIRContext *context) {
1032  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // Verifiers for casts between integers and floats.
1037 //===----------------------------------------------------------------------===//
1038 
1039 template <typename From, typename To>
1040 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1041  if (!areValidCastInputsAndOutputs(inputs, outputs))
1042  return false;
1043 
1044  auto srcType = getTypeIfLike<From>(inputs.front());
1045  auto dstType = getTypeIfLike<To>(outputs.back());
1046 
1047  return srcType && dstType;
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // UIToFPOp
1052 //===----------------------------------------------------------------------===//
1053 
1054 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1055  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1056 }
1057 
1058 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1059  Type resType = getType();
1060  Type resEleType;
1061  if (auto shapedType = resType.dyn_cast<ShapedType>())
1062  resEleType = shapedType.getElementType();
1063  else
1064  resEleType = resType;
1065  return constFoldCastOp<IntegerAttr, FloatAttr>(
1066  operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1067  FloatType floatTy = resEleType.cast<FloatType>();
1068  APFloat apf(floatTy.getFloatSemantics(),
1069  APInt::getZero(floatTy.getWidth()));
1070  apf.convertFromAPInt(a, /*IsSigned=*/false,
1071  APFloat::rmNearestTiesToEven);
1072  return apf;
1073  });
1074 }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // SIToFPOp
1078 //===----------------------------------------------------------------------===//
1079 
1080 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1081  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1082 }
1083 
1084 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1085  Type resType = getType();
1086  Type resEleType;
1087  if (auto shapedType = resType.dyn_cast<ShapedType>())
1088  resEleType = shapedType.getElementType();
1089  else
1090  resEleType = resType;
1091  return constFoldCastOp<IntegerAttr, FloatAttr>(
1092  operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1093  FloatType floatTy = resEleType.cast<FloatType>();
1094  APFloat apf(floatTy.getFloatSemantics(),
1095  APInt::getZero(floatTy.getWidth()));
1096  apf.convertFromAPInt(a, /*IsSigned=*/true,
1097  APFloat::rmNearestTiesToEven);
1098  return apf;
1099  });
1100 }
1101 //===----------------------------------------------------------------------===//
1102 // FPToUIOp
1103 //===----------------------------------------------------------------------===//
1104 
1105 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1106  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1107 }
1108 
1109 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1110  Type resType = getType();
1111  Type resEleType;
1112  if (auto shapedType = resType.dyn_cast<ShapedType>())
1113  resEleType = shapedType.getElementType();
1114  else
1115  resEleType = resType;
1116  return constFoldCastOp<FloatAttr, IntegerAttr>(
1117  operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1118  IntegerType intTy = resEleType.cast<IntegerType>();
1119  bool ignored;
1120  APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1121  castStatus = APFloat::opInvalidOp !=
1122  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1123  return api;
1124  });
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // FPToSIOp
1129 //===----------------------------------------------------------------------===//
1130 
1131 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1132  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1133 }
1134 
1135 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1136  Type resType = getType();
1137  Type resEleType;
1138  if (auto shapedType = resType.dyn_cast<ShapedType>())
1139  resEleType = shapedType.getElementType();
1140  else
1141  resEleType = resType;
1142  return constFoldCastOp<FloatAttr, IntegerAttr>(
1143  operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1144  IntegerType intTy = resEleType.cast<IntegerType>();
1145  bool ignored;
1146  APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1147  castStatus = APFloat::opInvalidOp !=
1148  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1149  return api;
1150  });
1151 }
1152 
1153 //===----------------------------------------------------------------------===//
1154 // IndexCastOp
1155 //===----------------------------------------------------------------------===//
1156 
1157 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1158  TypeRange outputs) {
1159  if (!areValidCastInputsAndOutputs(inputs, outputs))
1160  return false;
1161 
1162  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1163  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1164  if (!srcType || !dstType)
1165  return false;
1166 
1167  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1168  (srcType.isSignlessInteger() && dstType.isIndex());
1169 }
1170 
1171 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1172  // index_cast(constant) -> constant
1173  // A little hack because we go through int. Otherwise, the size of the
1174  // constant might need to change.
1175  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1176  return IntegerAttr::get(getType(), value.getInt());
1177 
1178  return {};
1179 }
1180 
1181 void arith::IndexCastOp::getCanonicalizationPatterns(
1182  RewritePatternSet &patterns, MLIRContext *context) {
1183  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1184 }
1185 
1186 //===----------------------------------------------------------------------===//
1187 // BitcastOp
1188 //===----------------------------------------------------------------------===//
1189 
1190 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1191  if (!areValidCastInputsAndOutputs(inputs, outputs))
1192  return false;
1193 
1194  auto srcType =
1195  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1196  auto dstType =
1197  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1198  if (!srcType || !dstType)
1199  return false;
1200 
1201  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1202 }
1203 
1204 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1205  assert(operands.size() == 1 && "bitcast op expects 1 operand");
1206 
1207  auto resType = getType();
1208  auto operand = operands[0];
1209  if (!operand)
1210  return {};
1211 
1212  /// Bitcast dense elements.
1213  if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1214  return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1215  /// Other shaped types unhandled.
1216  if (resType.isa<ShapedType>())
1217  return {};
1218 
1219  /// Bitcast integer or float to integer or float.
1220  APInt bits = operand.isa<FloatAttr>()
1221  ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1222  : operand.cast<IntegerAttr>().getValue();
1223 
1224  if (auto resFloatType = resType.dyn_cast<FloatType>())
1225  return FloatAttr::get(resType,
1226  APFloat(resFloatType.getFloatSemantics(), bits));
1227  return IntegerAttr::get(resType, bits);
1228 }
1229 
1230 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1231  MLIRContext *context) {
1232  patterns.add<BitcastOfBitcast>(context);
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // Helpers for compare ops
1237 //===----------------------------------------------------------------------===//
1238 
1239 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1240 static Type getI1SameShape(Type type) {
1241  auto i1Type = IntegerType::get(type.getContext(), 1);
1242  if (auto tensorType = type.dyn_cast<RankedTensorType>())
1243  return RankedTensorType::get(tensorType.getShape(), i1Type);
1244  if (type.isa<UnrankedTensorType>())
1245  return UnrankedTensorType::get(i1Type);
1246  if (auto vectorType = type.dyn_cast<VectorType>())
1247  return VectorType::get(vectorType.getShape(), i1Type,
1248  vectorType.getNumScalableDims());
1249  return i1Type;
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // CmpIOp
1254 //===----------------------------------------------------------------------===//
1255 
1256 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1257 /// comparison predicates.
1258 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1259  const APInt &lhs, const APInt &rhs) {
1260  switch (predicate) {
1261  case arith::CmpIPredicate::eq:
1262  return lhs.eq(rhs);
1263  case arith::CmpIPredicate::ne:
1264  return lhs.ne(rhs);
1265  case arith::CmpIPredicate::slt:
1266  return lhs.slt(rhs);
1267  case arith::CmpIPredicate::sle:
1268  return lhs.sle(rhs);
1269  case arith::CmpIPredicate::sgt:
1270  return lhs.sgt(rhs);
1271  case arith::CmpIPredicate::sge:
1272  return lhs.sge(rhs);
1273  case arith::CmpIPredicate::ult:
1274  return lhs.ult(rhs);
1275  case arith::CmpIPredicate::ule:
1276  return lhs.ule(rhs);
1277  case arith::CmpIPredicate::ugt:
1278  return lhs.ugt(rhs);
1279  case arith::CmpIPredicate::uge:
1280  return lhs.uge(rhs);
1281  }
1282  llvm_unreachable("unknown cmpi predicate kind");
1283 }
1284 
1285 /// Returns true if the predicate is true for two equal operands.
1286 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1287  switch (predicate) {
1288  case arith::CmpIPredicate::eq:
1289  case arith::CmpIPredicate::sle:
1290  case arith::CmpIPredicate::sge:
1291  case arith::CmpIPredicate::ule:
1292  case arith::CmpIPredicate::uge:
1293  return true;
1294  case arith::CmpIPredicate::ne:
1295  case arith::CmpIPredicate::slt:
1296  case arith::CmpIPredicate::sgt:
1297  case arith::CmpIPredicate::ult:
1298  case arith::CmpIPredicate::ugt:
1299  return false;
1300  }
1301  llvm_unreachable("unknown cmpi predicate kind");
1302 }
1303 
1304 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1305  auto boolAttr = BoolAttr::get(ctx, value);
1306  ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1307  if (!shapedType)
1308  return boolAttr;
1309  return DenseElementsAttr::get(shapedType, boolAttr);
1310 }
1311 
1312 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1313  assert(operands.size() == 2 && "cmpi takes two operands");
1314 
1315  // cmpi(pred, x, x)
1316  if (getLhs() == getRhs()) {
1317  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1318  return getBoolAttribute(getType(), getContext(), val);
1319  }
1320 
1321  if (matchPattern(getRhs(), m_Zero())) {
1322  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1323  // extsi(%x : i1 -> iN) != 0 -> %x
1324  if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1325  getPredicate() == arith::CmpIPredicate::ne)
1326  return extOp.getOperand();
1327  }
1328  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1329  // extui(%x : i1 -> iN) != 0 -> %x
1330  if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1331  getPredicate() == arith::CmpIPredicate::ne)
1332  return extOp.getOperand();
1333  }
1334  }
1335 
1336  // Move constant to the right side.
1337  if (operands[0] && !operands[1]) {
1338  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1339  using Pred = CmpIPredicate;
1340  const std::pair<Pred, Pred> invPreds[] = {
1341  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1342  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1343  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1344  {Pred::ne, Pred::ne},
1345  };
1346  Pred origPred = getPredicate();
1347  for (auto pred : invPreds) {
1348  if (origPred == pred.first) {
1349  setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
1350  Value lhs = getLhs();
1351  Value rhs = getRhs();
1352  getLhsMutable().assign(rhs);
1353  getRhsMutable().assign(lhs);
1354  return getResult();
1355  }
1356  }
1357  llvm_unreachable("unknown cmpi predicate kind");
1358  }
1359 
1360  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1361  if (!lhs)
1362  return {};
1363 
1364  // We are moving constants to the right side; So if lhs is constant rhs is
1365  // guaranteed to be a constant.
1366  auto rhs = operands.back().cast<IntegerAttr>();
1367 
1368  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1369  return BoolAttr::get(getContext(), val);
1370 }
1371 
1372 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1373  MLIRContext *context) {
1374  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1375 }
1376 
1377 //===----------------------------------------------------------------------===//
1378 // CmpFOp
1379 //===----------------------------------------------------------------------===//
1380 
1381 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1382 /// comparison predicates.
1383 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1384  const APFloat &lhs, const APFloat &rhs) {
1385  auto cmpResult = lhs.compare(rhs);
1386  switch (predicate) {
1387  case arith::CmpFPredicate::AlwaysFalse:
1388  return false;
1389  case arith::CmpFPredicate::OEQ:
1390  return cmpResult == APFloat::cmpEqual;
1391  case arith::CmpFPredicate::OGT:
1392  return cmpResult == APFloat::cmpGreaterThan;
1393  case arith::CmpFPredicate::OGE:
1394  return cmpResult == APFloat::cmpGreaterThan ||
1395  cmpResult == APFloat::cmpEqual;
1396  case arith::CmpFPredicate::OLT:
1397  return cmpResult == APFloat::cmpLessThan;
1398  case arith::CmpFPredicate::OLE:
1399  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1400  case arith::CmpFPredicate::ONE:
1401  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1402  case arith::CmpFPredicate::ORD:
1403  return cmpResult != APFloat::cmpUnordered;
1404  case arith::CmpFPredicate::UEQ:
1405  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1406  case arith::CmpFPredicate::UGT:
1407  return cmpResult == APFloat::cmpUnordered ||
1408  cmpResult == APFloat::cmpGreaterThan;
1409  case arith::CmpFPredicate::UGE:
1410  return cmpResult == APFloat::cmpUnordered ||
1411  cmpResult == APFloat::cmpGreaterThan ||
1412  cmpResult == APFloat::cmpEqual;
1413  case arith::CmpFPredicate::ULT:
1414  return cmpResult == APFloat::cmpUnordered ||
1415  cmpResult == APFloat::cmpLessThan;
1416  case arith::CmpFPredicate::ULE:
1417  return cmpResult == APFloat::cmpUnordered ||
1418  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1419  case arith::CmpFPredicate::UNE:
1420  return cmpResult != APFloat::cmpEqual;
1421  case arith::CmpFPredicate::UNO:
1422  return cmpResult == APFloat::cmpUnordered;
1423  case arith::CmpFPredicate::AlwaysTrue:
1424  return true;
1425  }
1426  llvm_unreachable("unknown cmpf predicate kind");
1427 }
1428 
1429 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1430  assert(operands.size() == 2 && "cmpf takes two operands");
1431 
1432  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1433  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1434 
1435  // If one operand is NaN, making them both NaN does not change the result.
1436  if (lhs && lhs.getValue().isNaN())
1437  rhs = lhs;
1438  if (rhs && rhs.getValue().isNaN())
1439  lhs = rhs;
1440 
1441  if (!lhs || !rhs)
1442  return {};
1443 
1444  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1445  return BoolAttr::get(getContext(), val);
1446 }
1447 
1448 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1449 public:
1451 
1452  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1453  bool isUnsigned) {
1454  using namespace arith;
1455  switch (pred) {
1456  case CmpFPredicate::UEQ:
1457  case CmpFPredicate::OEQ:
1458  return CmpIPredicate::eq;
1459  case CmpFPredicate::UGT:
1460  case CmpFPredicate::OGT:
1461  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1462  case CmpFPredicate::UGE:
1463  case CmpFPredicate::OGE:
1464  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1465  case CmpFPredicate::ULT:
1466  case CmpFPredicate::OLT:
1467  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1468  case CmpFPredicate::ULE:
1469  case CmpFPredicate::OLE:
1470  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1471  case CmpFPredicate::UNE:
1472  case CmpFPredicate::ONE:
1473  return CmpIPredicate::ne;
1474  default:
1475  llvm_unreachable("Unexpected predicate!");
1476  }
1477  }
1478 
1480  PatternRewriter &rewriter) const override {
1481  FloatAttr flt;
1482  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1483  return failure();
1484 
1485  const APFloat &rhs = flt.getValue();
1486 
1487  // Don't attempt to fold a nan.
1488  if (rhs.isNaN())
1489  return failure();
1490 
1491  // Get the width of the mantissa. We don't want to hack on conversions that
1492  // might lose information from the integer, e.g. "i64 -> float"
1493  FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1494  int mantissaWidth = floatTy.getFPMantissaWidth();
1495  if (mantissaWidth <= 0)
1496  return failure();
1497 
1498  bool isUnsigned;
1499  Value intVal;
1500 
1501  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1502  isUnsigned = false;
1503  intVal = si.getIn();
1504  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1505  isUnsigned = true;
1506  intVal = ui.getIn();
1507  } else {
1508  return failure();
1509  }
1510 
1511  // Check to see that the input is converted from an integer type that is
1512  // small enough that preserves all bits.
1513  auto intTy = intVal.getType().cast<IntegerType>();
1514  auto intWidth = intTy.getWidth();
1515 
1516  // Number of bits representing values, as opposed to the sign
1517  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1518 
1519  // Following test does NOT adjust intWidth downwards for signed inputs,
1520  // because the most negative value still requires all the mantissa bits
1521  // to distinguish it from one less than that value.
1522  if ((int)intWidth > mantissaWidth) {
1523  // Conversion would lose accuracy. Check if loss can impact comparison.
1524  int exponent = ilogb(rhs);
1525  if (exponent == APFloat::IEK_Inf) {
1526  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1527  if (maxExponent < (int)valueBits) {
1528  // Conversion could create infinity.
1529  return failure();
1530  }
1531  } else {
1532  // Note that if rhs is zero or NaN, then Exp is negative
1533  // and first condition is trivially false.
1534  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1535  // Conversion could affect comparison.
1536  return failure();
1537  }
1538  }
1539  }
1540 
1541  // Convert to equivalent cmpi predicate
1542  CmpIPredicate pred;
1543  switch (op.getPredicate()) {
1544  case CmpFPredicate::ORD:
1545  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1546  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1547  /*width=*/1);
1548  return success();
1549  case CmpFPredicate::UNO:
1550  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1551  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1552  /*width=*/1);
1553  return success();
1554  default:
1555  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1556  break;
1557  }
1558 
1559  if (!isUnsigned) {
1560  // If the rhs value is > SignedMax, fold the comparison. This handles
1561  // +INF and large values.
1562  APFloat signedMax(rhs.getSemantics());
1563  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1564  APFloat::rmNearestTiesToEven);
1565  if (signedMax < rhs) { // smax < 13123.0
1566  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1567  pred == CmpIPredicate::sle)
1568  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1569  /*width=*/1);
1570  else
1571  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1572  /*width=*/1);
1573  return success();
1574  }
1575  } else {
1576  // If the rhs value is > UnsignedMax, fold the comparison. This handles
1577  // +INF and large values.
1578  APFloat unsignedMax(rhs.getSemantics());
1579  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1580  APFloat::rmNearestTiesToEven);
1581  if (unsignedMax < rhs) { // umax < 13123.0
1582  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1583  pred == CmpIPredicate::ule)
1584  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1585  /*width=*/1);
1586  else
1587  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1588  /*width=*/1);
1589  return success();
1590  }
1591  }
1592 
1593  if (!isUnsigned) {
1594  // See if the rhs value is < SignedMin.
1595  APFloat signedMin(rhs.getSemantics());
1596  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1597  APFloat::rmNearestTiesToEven);
1598  if (signedMin > rhs) { // smin > 12312.0
1599  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1600  pred == CmpIPredicate::sge)
1601  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1602  /*width=*/1);
1603  else
1604  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1605  /*width=*/1);
1606  return success();
1607  }
1608  } else {
1609  // See if the rhs value is < UnsignedMin.
1610  APFloat unsignedMin(rhs.getSemantics());
1611  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1612  APFloat::rmNearestTiesToEven);
1613  if (unsignedMin > rhs) { // umin > 12312.0
1614  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1615  pred == CmpIPredicate::uge)
1616  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1617  /*width=*/1);
1618  else
1619  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1620  /*width=*/1);
1621  return success();
1622  }
1623  }
1624 
1625  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1626  // [0, UMAX], but it may still be fractional. See if it is fractional by
1627  // casting the FP value to the integer value and back, checking for
1628  // equality. Don't do this for zero, because -0.0 is not fractional.
1629  bool ignored;
1630  APSInt rhsInt(intWidth, isUnsigned);
1631  if (APFloat::opInvalidOp ==
1632  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1633  // Undefined behavior invoked - the destination type can't represent
1634  // the input constant.
1635  return failure();
1636  }
1637 
1638  if (!rhs.isZero()) {
1639  APFloat apf(floatTy.getFloatSemantics(),
1640  APInt::getZero(floatTy.getWidth()));
1641  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1642 
1643  bool equal = apf == rhs;
1644  if (!equal) {
1645  // If we had a comparison against a fractional value, we have to adjust
1646  // the compare predicate and sometimes the value. rhsInt is rounded
1647  // towards zero at this point.
1648  switch (pred) {
1649  case CmpIPredicate::ne: // (float)int != 4.4 --> true
1650  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1651  /*width=*/1);
1652  return success();
1653  case CmpIPredicate::eq: // (float)int == 4.4 --> false
1654  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1655  /*width=*/1);
1656  return success();
1657  case CmpIPredicate::ule:
1658  // (float)int <= 4.4 --> int <= 4
1659  // (float)int <= -4.4 --> false
1660  if (rhs.isNegative()) {
1661  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1662  /*width=*/1);
1663  return success();
1664  }
1665  break;
1666  case CmpIPredicate::sle:
1667  // (float)int <= 4.4 --> int <= 4
1668  // (float)int <= -4.4 --> int < -4
1669  if (rhs.isNegative())
1670  pred = CmpIPredicate::slt;
1671  break;
1672  case CmpIPredicate::ult:
1673  // (float)int < -4.4 --> false
1674  // (float)int < 4.4 --> int <= 4
1675  if (rhs.isNegative()) {
1676  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1677  /*width=*/1);
1678  return success();
1679  }
1680  pred = CmpIPredicate::ule;
1681  break;
1682  case CmpIPredicate::slt:
1683  // (float)int < -4.4 --> int < -4
1684  // (float)int < 4.4 --> int <= 4
1685  if (!rhs.isNegative())
1686  pred = CmpIPredicate::sle;
1687  break;
1688  case CmpIPredicate::ugt:
1689  // (float)int > 4.4 --> int > 4
1690  // (float)int > -4.4 --> true
1691  if (rhs.isNegative()) {
1692  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1693  /*width=*/1);
1694  return success();
1695  }
1696  break;
1697  case CmpIPredicate::sgt:
1698  // (float)int > 4.4 --> int > 4
1699  // (float)int > -4.4 --> int >= -4
1700  if (rhs.isNegative())
1701  pred = CmpIPredicate::sge;
1702  break;
1703  case CmpIPredicate::uge:
1704  // (float)int >= -4.4 --> true
1705  // (float)int >= 4.4 --> int > 4
1706  if (rhs.isNegative()) {
1707  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1708  /*width=*/1);
1709  return success();
1710  }
1711  pred = CmpIPredicate::ugt;
1712  break;
1713  case CmpIPredicate::sge:
1714  // (float)int >= -4.4 --> int >= -4
1715  // (float)int >= 4.4 --> int > 4
1716  if (!rhs.isNegative())
1717  pred = CmpIPredicate::sgt;
1718  break;
1719  }
1720  }
1721  }
1722 
1723  // Lower this FP comparison into an appropriate integer version of the
1724  // comparison.
1725  rewriter.replaceOpWithNewOp<CmpIOp>(
1726  op, pred, intVal,
1727  rewriter.create<ConstantOp>(
1728  op.getLoc(), intVal.getType(),
1729  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1730  return success();
1731  }
1732 };
1733 
1734 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1735  MLIRContext *context) {
1736  patterns.insert<CmpFIntToFPConst>(context);
1737 }
1738 
1739 //===----------------------------------------------------------------------===//
1740 // SelectOp
1741 //===----------------------------------------------------------------------===//
1742 
// Rewrites a select whose result type is i1 into pure boolean logic:
//
// arith.select %arg, %x, %y : i1
//
// becomes
//
// or(and(%arg, %x), and(xor(%arg, true), %y))
//
// i.e. and(%arg, %x) or and(!%arg, %y). Only scalar i1 results are matched.
struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Only scalar i1 result types qualify; anything else is left untouched.
    if (!op.getType().isInteger(1))
      return failure();

    // NOTE: despite its name, `falseConstant` is the i1 constant `true`;
    // xor-ing the condition with it produces the negated condition.
    Value falseConstant =
        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
    Value notCondition = rewriter.create<arith::XOrIOp>(
        op.getLoc(), op.getCondition(), falseConstant);

    // Keep %x where the condition holds and %y where it does not, then merge.
    Value trueVal = rewriter.create<arith::AndIOp>(
        op.getLoc(), op.getCondition(), op.getTrueValue());
    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
                                                    op.getFalseValue());
    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
    return success();
  }
};
1771 
// Rewrites a select between the constants 1 and 0 into a zero-extension of the
// (possibly negated) condition:
//   select %arg, %c1, %c0 => extui %arg
//   select %arg, %c0, %c1 => extui (xor %arg, true)
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32. Only scalar integer results wider
    // than i1 are handled; for such results the verifier requires the
    // condition to be a scalar i1.
    if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
      return failure();

    // select %arg, %c1, %c0 => extui %arg
    if (matchPattern(op.getTrueValue(), m_One()) &&
        matchPattern(op.getFalseValue(), m_Zero())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
                                                  op.getCondition());
      return success();
    }

    // select %arg, %c0, %c1 => extui (xor %arg, true)
    if (matchPattern(op.getTrueValue(), m_Zero()) &&
        matchPattern(op.getFalseValue(), m_One())) {
      // The xor with a constant 1 of the condition's type negates the
      // condition before it is zero-extended.
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, op.getType(),
          rewriter.create<arith::XOrIOp>(
              op.getLoc(), op.getCondition(),
              rewriter.create<arith::ConstantIntOp>(
                  op.getLoc(), 1, op.getCondition().getType())));
      return success();
    }

    return failure();
  }
};
1805 
1806 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1807  MLIRContext *context) {
1808  results.add<SelectI1Simplify, SelectToExtUI>(context);
1809 }
1810 
1811 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1812  Value trueVal = getTrueValue();
1813  Value falseVal = getFalseValue();
1814  if (trueVal == falseVal)
1815  return trueVal;
1816 
1817  Value condition = getCondition();
1818 
1819  // select true, %0, %1 => %0
1820  if (matchPattern(condition, m_One()))
1821  return trueVal;
1822 
1823  // select false, %0, %1 => %1
1824  if (matchPattern(condition, m_Zero()))
1825  return falseVal;
1826 
1827  // select %x, true, false => %x
1828  if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1829  matchPattern(getFalseValue(), m_Zero()))
1830  return condition;
1831 
1832  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1833  auto pred = cmp.getPredicate();
1834  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1835  auto cmpLhs = cmp.getLhs();
1836  auto cmpRhs = cmp.getRhs();
1837 
1838  // %0 = arith.cmpi eq, %arg0, %arg1
1839  // %1 = arith.select %0, %arg0, %arg1 => %arg1
1840 
1841  // %0 = arith.cmpi ne, %arg0, %arg1
1842  // %1 = arith.select %0, %arg0, %arg1 => %arg0
1843 
1844  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1845  (cmpRhs == trueVal && cmpLhs == falseVal))
1846  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1847  }
1848  }
1849  return nullptr;
1850 }
1851 
// Parses the custom assembly form of `arith.select`:
//   %cond, %true, %false attr-dict : result-type
//   %cond, %true, %false attr-dict : cond-type, result-type
// When only one type follows the colon, the condition is taken to be i1;
// otherwise the first type is the explicit (mask) condition type.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  // In that case the type parsed above was the condition type, and the
  // result type follows the comma.
  if (succeeded(parser.parseOptionalComma())) {
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  // Operands are (condition, true value, false value); both arms share the
  // result type.
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
1874 
  // Print the condition, true value and false value, then any attributes.
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  // A shaped (mask) condition type is printed explicitly before the result
  // type; a scalar i1 condition is implied and omitted.
  if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
    p << condType << ", ";
  p << getType();
}
1883 
  Type conditionType = getCondition().getType();
  // A scalar signless i1 condition is accepted for every result type.
  if (conditionType.isSignlessInteger(1))
    return success();

  // Otherwise the result must be a vector or tensor, and the condition must be
  // an i1 mask with the same shape as the result.
  Type resultType = getType();
  if (!resultType.isa<TensorType, VectorType>())
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
1904 //===----------------------------------------------------------------------===//
1905 // ShLIOp
1906 //===----------------------------------------------------------------------===//
1907 
1908 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1909  // Don't fold if shifting more than the bit width.
1910  bool bounded = false;
1911  auto result = constFoldBinaryOp<IntegerAttr>(
1912  operands, [&](const APInt &a, const APInt &b) {
1913  bounded = b.ule(b.getBitWidth());
1914  return a.shl(b);
1915  });
1916  return bounded ? result : Attribute();
1917 }
1918 
1919 //===----------------------------------------------------------------------===//
1920 // ShRUIOp
1921 //===----------------------------------------------------------------------===//
1922 
1923 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1924  // Don't fold if shifting more than the bit width.
1925  bool bounded = false;
1926  auto result = constFoldBinaryOp<IntegerAttr>(
1927  operands, [&](const APInt &a, const APInt &b) {
1928  bounded = b.ule(b.getBitWidth());
1929  return a.lshr(b);
1930  });
1931  return bounded ? result : Attribute();
1932 }
1933 
1934 //===----------------------------------------------------------------------===//
1935 // ShRSIOp
1936 //===----------------------------------------------------------------------===//
1937 
1938 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1939  // Don't fold if shifting more than the bit width.
1940  bool bounded = false;
1941  auto result = constFoldBinaryOp<IntegerAttr>(
1942  operands, [&](const APInt &a, const APInt &b) {
1943  bounded = b.ule(b.getBitWidth());
1944  return a.ashr(b);
1945  });
1946  return bounded ? result : Attribute();
1947 }
1948 
1949 //===----------------------------------------------------------------------===//
1950 // Atomic Enum
1951 //===----------------------------------------------------------------------===//
1952 
1953 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1954 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1955  OpBuilder &builder, Location loc) {
1956  switch (kind) {
1957  case AtomicRMWKind::maxf:
1958  return builder.getFloatAttr(
1959  resultType,
1960  APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1961  /*Negative=*/true));
1962  case AtomicRMWKind::addf:
1963  case AtomicRMWKind::addi:
1964  case AtomicRMWKind::maxu:
1965  case AtomicRMWKind::ori:
1966  return builder.getZeroAttr(resultType);
1967  case AtomicRMWKind::andi:
1968  return builder.getIntegerAttr(
1969  resultType,
1970  APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1971  case AtomicRMWKind::maxs:
1972  return builder.getIntegerAttr(
1973  resultType,
1974  APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1975  case AtomicRMWKind::minf:
1976  return builder.getFloatAttr(
1977  resultType,
1978  APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1979  /*Negative=*/false));
1980  case AtomicRMWKind::mins:
1981  return builder.getIntegerAttr(
1982  resultType,
1983  APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1984  case AtomicRMWKind::minu:
1985  return builder.getIntegerAttr(
1986  resultType,
1987  APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1988  case AtomicRMWKind::muli:
1989  return builder.getIntegerAttr(resultType, 1);
1990  case AtomicRMWKind::mulf:
1991  return builder.getFloatAttr(resultType, 1);
1992  // TODO: Add remaining reduction operations.
1993  default:
1994  (void)emitOptionalError(loc, "Reduction operation type not supported");
1995  break;
1996  }
1997  return nullptr;
1998 }
1999 
2000 /// Returns the identity value associated with an AtomicRMWKind op.
2001 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2002  OpBuilder &builder, Location loc) {
2003  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
2004  return builder.create<arith::ConstantOp>(loc, attr);
2005 }
2006 
2007 /// Return the value obtained by applying the reduction operation kind
2008 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2009 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2010  Location loc, Value lhs, Value rhs) {
2011  switch (op) {
2012  case AtomicRMWKind::addf:
2013  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2014  case AtomicRMWKind::addi:
2015  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2016  case AtomicRMWKind::mulf:
2017  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2018  case AtomicRMWKind::muli:
2019  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2020  case AtomicRMWKind::maxf:
2021  return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2022  case AtomicRMWKind::minf:
2023  return builder.create<arith::MinFOp>(loc, lhs, rhs);
2024  case AtomicRMWKind::maxs:
2025  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2026  case AtomicRMWKind::mins:
2027  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2028  case AtomicRMWKind::maxu:
2029  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2030  case AtomicRMWKind::minu:
2031  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2032  case AtomicRMWKind::ori:
2033  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2034  case AtomicRMWKind::andi:
2035  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2036  // TODO: Add remaining reduction operations.
2037  default:
2038  (void)emitOptionalError(loc, "Reduction operation type not supported");
2039  break;
2040  }
2041  return nullptr;
2042 }
2043 
2044 //===----------------------------------------------------------------------===//
2045 // TableGen'd op method definitions
2046 //===----------------------------------------------------------------------===//
2047 
2048 #define GET_OP_CLASSES
2049 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2050 
2051 //===----------------------------------------------------------------------===//
2052 // TableGen'd enum attribute definitions
2053 //===----------------------------------------------------------------------===//
2054 
2055 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
Include the generated interface declarations.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:302
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
U cast() const
Definition: Attributes.h:135
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:355
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:43
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
bool isa() const
Definition: Attributes.h:111
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:276
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:281
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
static constexpr const bool value
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:217
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value attribute associated with an AtomicRMWKind op.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static bool classof(Operation *op)
unsigned getWidth()
Return the bitwidth of this float type.
An attribute that represents a reference to a dense vector or tensor object.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:629
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
U dyn_cast_or_null() const
Definition: Types.h:274
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:294
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:194
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:286
U dyn_cast() const
Definition: Types.h:270
Attributes are known-constant values of operations.
Definition: Attributes.h:24
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult verifyTruncateOp(Op op)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:309
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This represents an operation in an abstracted form, suitable for use with the builder APIs...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:37
static bool classof(Operation *op)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:258
std::tuple< Types... > * type_list
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
LogicalResult emitOptionalError(Optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:489
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:320
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
int64_t ceil(Fraction f)
Definition: Fraction.h:65
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
U dyn_cast() const
Definition: Attributes.h:127
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:332
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
This provides public APIs that all operations should have.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
static BoolAttr get(MLIRContext *context, bool value)
bool isa() const
Definition: Types.h:254
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class helps build Operations.
Definition: Builders.h:192
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
static LogicalResult verifyExtOp(Op op)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
U cast() const
Definition: Types.h:278
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
static bool classof(Operation *op)