MLIR  14.0.0git
ArithmeticOps.cpp
Go to the documentation of this file.
1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #include "llvm/ADT/APSInt.h"
20 
21 using namespace mlir;
22 using namespace mlir::arith;
23 
24 //===----------------------------------------------------------------------===//
25 // Pattern helpers
26 //===----------------------------------------------------------------------===//
27 
28 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
29  Attribute lhs, Attribute rhs) {
30  return builder.getIntegerAttr(res.getType(),
31  lhs.cast<IntegerAttr>().getInt() +
32  rhs.cast<IntegerAttr>().getInt());
33 }
34 
35 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
36  Attribute lhs, Attribute rhs) {
37  return builder.getIntegerAttr(res.getType(),
38  lhs.cast<IntegerAttr>().getInt() -
39  rhs.cast<IntegerAttr>().getInt());
40 }
41 
42 /// Invert an integer comparison predicate.
43 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
44  switch (pred) {
45  case arith::CmpIPredicate::eq:
46  return arith::CmpIPredicate::ne;
47  case arith::CmpIPredicate::ne:
48  return arith::CmpIPredicate::eq;
49  case arith::CmpIPredicate::slt:
50  return arith::CmpIPredicate::sge;
51  case arith::CmpIPredicate::sle:
52  return arith::CmpIPredicate::sgt;
53  case arith::CmpIPredicate::sgt:
54  return arith::CmpIPredicate::sle;
55  case arith::CmpIPredicate::sge:
56  return arith::CmpIPredicate::slt;
57  case arith::CmpIPredicate::ult:
58  return arith::CmpIPredicate::uge;
59  case arith::CmpIPredicate::ule:
60  return arith::CmpIPredicate::ugt;
61  case arith::CmpIPredicate::ugt:
62  return arith::CmpIPredicate::ule;
63  case arith::CmpIPredicate::uge:
64  return arith::CmpIPredicate::ult;
65  }
66  llvm_unreachable("unknown cmpi predicate kind");
67 }
68 
69 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
70  return arith::CmpIPredicateAttr::get(pred.getContext(),
71  invertPredicate(pred.getValue()));
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // TableGen'd canonicalization patterns
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 #include "ArithmeticCanonicalization.inc"
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // ConstantOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::ConstantOp::getAsmResultNames(
87  function_ref<void(Value, StringRef)> setNameFn) {
88  auto type = getType();
89  if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
90  auto intType = type.dyn_cast<IntegerType>();
91 
92  // Sugar i1 constants with 'true' and 'false'.
93  if (intType && intType.getWidth() == 1)
94  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
95 
96  // Otherwise, build a compex name with the value and type.
97  SmallString<32> specialNameBuffer;
98  llvm::raw_svector_ostream specialName(specialNameBuffer);
99  specialName << 'c' << intCst.getInt();
100  if (intType)
101  specialName << '_' << type;
102  setNameFn(getResult(), specialName.str());
103  } else {
104  setNameFn(getResult(), "cst");
105  }
106 }
107 
108 /// TODO: disallow arith.constant to return anything other than signless integer
109 /// or float like.
110 static LogicalResult verify(arith::ConstantOp op) {
111  auto type = op.getType();
112  // The value's type must match the return type.
113  if (op.getValue().getType() != type) {
114  return op.emitOpError() << "value type " << op.getValue().getType()
115  << " must match return type: " << type;
116  }
117  // Integer values must be signless.
118  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
119  return op.emitOpError("integer return type must be signless");
120  // Any float or elements attribute are acceptable.
121  if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
122  return op.emitOpError(
123  "value must be an integer, float, or elements attribute");
124  }
125  return success();
126 }
127 
128 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
129  // The value's type must be the same as the provided type.
130  if (value.getType() != type)
131  return false;
132  // Integer values must be signless.
133  if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
134  return false;
135  // Integer, float, and element attributes are buildable.
136  return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
137 }
138 
139 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
140  return getValue();
141 }
142 
144  int64_t value, unsigned width) {
145  auto type = builder.getIntegerType(width);
146  arith::ConstantOp::build(builder, result, type,
147  builder.getIntegerAttr(type, value));
148 }
149 
151  int64_t value, Type type) {
152  assert(type.isSignlessInteger() &&
153  "ConstantIntOp can only have signless integer type values");
154  arith::ConstantOp::build(builder, result, type,
155  builder.getIntegerAttr(type, value));
156 }
157 
159  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
160  return constOp.getType().isSignlessInteger();
161  return false;
162 }
163 
165  const APFloat &value, FloatType type) {
166  arith::ConstantOp::build(builder, result, type,
167  builder.getFloatAttr(type, value));
168 }
169 
171  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
172  return constOp.getType().isa<FloatType>();
173  return false;
174 }
175 
177  int64_t value) {
178  arith::ConstantOp::build(builder, result, builder.getIndexType(),
179  builder.getIndexAttr(value));
180 }
181 
183  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
184  return constOp.getType().isIndex();
185  return false;
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // AddIOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
193  // addi(x, 0) -> x
194  if (matchPattern(getRhs(), m_Zero()))
195  return getLhs();
196 
197  // add(sub(a, b), b) -> a
198  if (auto sub = getLhs().getDefiningOp<SubIOp>())
199  if (getRhs() == sub.getRhs())
200  return sub.getLhs();
201 
202  // add(b, sub(a, b)) -> a
203  if (auto sub = getRhs().getDefiningOp<SubIOp>())
204  if (getLhs() == sub.getRhs())
205  return sub.getLhs();
206 
207  return constFoldBinaryOp<IntegerAttr>(
208  operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
209 }
210 
211 void arith::AddIOp::getCanonicalizationPatterns(
212  OwningRewritePatternList &patterns, MLIRContext *context) {
213  patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
214  context);
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // SubIOp
219 //===----------------------------------------------------------------------===//
220 
221 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
222  // subi(x,x) -> 0
223  if (getOperand(0) == getOperand(1))
224  return Builder(getContext()).getZeroAttr(getType());
225  // subi(x,0) -> x
226  if (matchPattern(getRhs(), m_Zero()))
227  return getLhs();
228 
229  return constFoldBinaryOp<IntegerAttr>(
230  operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
231 }
232 
233 void arith::SubIOp::getCanonicalizationPatterns(
234  OwningRewritePatternList &patterns, MLIRContext *context) {
235  patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
236  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
237  SubILHSSubConstantLHS>(context);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // MulIOp
242 //===----------------------------------------------------------------------===//
243 
244 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
245  // muli(x, 0) -> 0
246  if (matchPattern(getRhs(), m_Zero()))
247  return getRhs();
248  // muli(x, 1) -> x
249  if (matchPattern(getRhs(), m_One()))
250  return getOperand(0);
251  // TODO: Handle the overflow case.
252 
253  // default folder
254  return constFoldBinaryOp<IntegerAttr>(
255  operands, [](const APInt &a, const APInt &b) { return a * b; });
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // DivUIOp
260 //===----------------------------------------------------------------------===//
261 
262 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
263  // Don't fold if it would require a division by zero.
264  bool div0 = false;
265  auto result =
266  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
267  if (div0 || !b) {
268  div0 = true;
269  return a;
270  }
271  return a.udiv(b);
272  });
273 
274  // Fold out division by one. Assumes all tensors of all ones are splats.
275  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
276  if (rhs.getValue() == 1)
277  return getLhs();
278  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
279  if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
280  return getLhs();
281  }
282 
283  return div0 ? Attribute() : result;
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // DivSIOp
288 //===----------------------------------------------------------------------===//
289 
290 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
291  // Don't fold if it would overflow or if it requires a division by zero.
292  bool overflowOrDiv0 = false;
293  auto result =
294  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
295  if (overflowOrDiv0 || !b) {
296  overflowOrDiv0 = true;
297  return a;
298  }
299  return a.sdiv_ov(b, overflowOrDiv0);
300  });
301 
302  // Fold out division by one. Assumes all tensors of all ones are splats.
303  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
304  if (rhs.getValue() == 1)
305  return getLhs();
306  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
307  if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
308  return getLhs();
309  }
310 
311  return overflowOrDiv0 ? Attribute() : result;
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Ceil and floor division folding helpers
316 //===----------------------------------------------------------------------===//
317 
318 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
319  bool &overflow) {
320  // Returns (a-1)/b + 1
321  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
322  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
323  return val.sadd_ov(one, overflow);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // CeilDivUIOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
331  bool overflowOrDiv0 = false;
332  auto result =
333  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
334  if (overflowOrDiv0 || !b) {
335  overflowOrDiv0 = true;
336  return a;
337  }
338  APInt quotient = a.udiv(b);
339  if (!a.urem(b))
340  return quotient;
341  APInt one(a.getBitWidth(), 1, true);
342  return quotient.uadd_ov(one, overflowOrDiv0);
343  });
344  // Fold out ceil division by one. Assumes all tensors of all ones are
345  // splats.
346  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
347  if (rhs.getValue() == 1)
348  return getLhs();
349  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
350  if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
351  return getLhs();
352  }
353 
354  return overflowOrDiv0 ? Attribute() : result;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // CeilDivSIOp
359 //===----------------------------------------------------------------------===//
360 
361 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
362  // Don't fold if it would overflow or if it requires a division by zero.
363  bool overflowOrDiv0 = false;
364  auto result =
365  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
366  if (overflowOrDiv0 || !b) {
367  overflowOrDiv0 = true;
368  return a;
369  }
370  if (!a)
371  return a;
372  // After this point we know that neither a or b are zero.
373  unsigned bits = a.getBitWidth();
374  APInt zero = APInt::getZero(bits);
375  bool aGtZero = a.sgt(zero);
376  bool bGtZero = b.sgt(zero);
377  if (aGtZero && bGtZero) {
378  // Both positive, return ceil(a, b).
379  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
380  }
381  if (!aGtZero && !bGtZero) {
382  // Both negative, return ceil(-a, -b).
383  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
384  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
385  return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
386  }
387  if (!aGtZero && bGtZero) {
388  // A is negative, b is positive, return - ( -a / b).
389  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
390  APInt div = posA.sdiv_ov(b, overflowOrDiv0);
391  return zero.ssub_ov(div, overflowOrDiv0);
392  }
393  // A is positive, b is negative, return - (a / -b).
394  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
395  APInt div = a.sdiv_ov(posB, overflowOrDiv0);
396  return zero.ssub_ov(div, overflowOrDiv0);
397  });
398 
399  // Fold out ceil division by one. Assumes all tensors of all ones are
400  // splats.
401  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
402  if (rhs.getValue() == 1)
403  return getLhs();
404  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
405  if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
406  return getLhs();
407  }
408 
409  return overflowOrDiv0 ? Attribute() : result;
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // FloorDivSIOp
414 //===----------------------------------------------------------------------===//
415 
416 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
417  // Don't fold if it would overflow or if it requires a division by zero.
418  bool overflowOrDiv0 = false;
419  auto result =
420  constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
421  if (overflowOrDiv0 || !b) {
422  overflowOrDiv0 = true;
423  return a;
424  }
425  if (!a)
426  return a;
427  // After this point we know that neither a or b are zero.
428  unsigned bits = a.getBitWidth();
429  APInt zero = APInt::getZero(bits);
430  bool aGtZero = a.sgt(zero);
431  bool bGtZero = b.sgt(zero);
432  if (aGtZero && bGtZero) {
433  // Both positive, return a / b.
434  return a.sdiv_ov(b, overflowOrDiv0);
435  }
436  if (!aGtZero && !bGtZero) {
437  // Both negative, return -a / -b.
438  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
439  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
440  return posA.sdiv_ov(posB, overflowOrDiv0);
441  }
442  if (!aGtZero && bGtZero) {
443  // A is negative, b is positive, return - ceil(-a, b).
444  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
445  APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
446  return zero.ssub_ov(ceil, overflowOrDiv0);
447  }
448  // A is positive, b is negative, return - ceil(a, -b).
449  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
450  APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
451  return zero.ssub_ov(ceil, overflowOrDiv0);
452  });
453 
454  // Fold out floor division by one. Assumes all tensors of all ones are
455  // splats.
456  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
457  if (rhs.getValue() == 1)
458  return getLhs();
459  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
460  if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
461  return getLhs();
462  }
463 
464  return overflowOrDiv0 ? Attribute() : result;
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // RemUIOp
469 //===----------------------------------------------------------------------===//
470 
471 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
472  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
473  if (!rhs)
474  return {};
475  auto rhsValue = rhs.getValue();
476 
477  // x % 1 = 0
478  if (rhsValue.isOneValue())
479  return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
480 
481  // Don't fold if it requires division by zero.
482  if (rhsValue.isNullValue())
483  return {};
484 
485  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
486  if (!lhs)
487  return {};
488  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // RemSIOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
496  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
497  if (!rhs)
498  return {};
499  auto rhsValue = rhs.getValue();
500 
501  // x % 1 = 0
502  if (rhsValue.isOneValue())
503  return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
504 
505  // Don't fold if it requires division by zero.
506  if (rhsValue.isNullValue())
507  return {};
508 
509  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
510  if (!lhs)
511  return {};
512  return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AndIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
520  /// and(x, 0) -> 0
521  if (matchPattern(getRhs(), m_Zero()))
522  return getRhs();
523  /// and(x, allOnes) -> x
524  APInt intValue;
525  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
526  return getLhs();
527 
528  return constFoldBinaryOp<IntegerAttr>(
529  operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // OrIOp
534 //===----------------------------------------------------------------------===//
535 
536 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
537  /// or(x, 0) -> x
538  if (matchPattern(getRhs(), m_Zero()))
539  return getLhs();
540  /// or(x, <all ones>) -> <all ones>
541  if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
542  if (rhsAttr.getValue().isAllOnes())
543  return rhsAttr;
544 
545  return constFoldBinaryOp<IntegerAttr>(
546  operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOp
551 //===----------------------------------------------------------------------===//
552 
553 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
554  /// xor(x, 0) -> x
555  if (matchPattern(getRhs(), m_Zero()))
556  return getLhs();
557  /// xor(x, x) -> 0
558  if (getLhs() == getRhs())
559  return Builder(getContext()).getZeroAttr(getType());
560  /// xor(xor(x, a), a) -> x
561  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
562  if (prev.getRhs() == getRhs())
563  return prev.getLhs();
564 
565  return constFoldBinaryOp<IntegerAttr>(
566  operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
567 }
568 
569 void arith::XOrIOp::getCanonicalizationPatterns(
570  OwningRewritePatternList &patterns, MLIRContext *context) {
571  patterns.insert<XOrINotCmpI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // AddFOp
576 //===----------------------------------------------------------------------===//
577 
578 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
579  return constFoldBinaryOp<FloatAttr>(
580  operands, [](const APFloat &a, const APFloat &b) { return a + b; });
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // SubFOp
585 //===----------------------------------------------------------------------===//
586 
587 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
588  return constFoldBinaryOp<FloatAttr>(
589  operands, [](const APFloat &a, const APFloat &b) { return a - b; });
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // MaxSIOp
594 //===----------------------------------------------------------------------===//
595 
596 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
597  assert(operands.size() == 2 && "binary operation takes two operands");
598 
599  // maxsi(x,x) -> x
600  if (getLhs() == getRhs())
601  return getRhs();
602 
603  APInt intValue;
604  // maxsi(x,MAX_INT) -> MAX_INT
605  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
606  intValue.isMaxSignedValue())
607  return getRhs();
608 
609  // maxsi(x, MIN_INT) -> x
610  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
611  intValue.isMinSignedValue())
612  return getLhs();
613 
614  return constFoldBinaryOp<IntegerAttr>(operands,
615  [](const APInt &a, const APInt &b) {
616  return llvm::APIntOps::smax(a, b);
617  });
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // MaxUIOp
622 //===----------------------------------------------------------------------===//
623 
624 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
625  assert(operands.size() == 2 && "binary operation takes two operands");
626 
627  // maxui(x,x) -> x
628  if (getLhs() == getRhs())
629  return getRhs();
630 
631  APInt intValue;
632  // maxui(x,MAX_INT) -> MAX_INT
633  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
634  return getRhs();
635 
636  // maxui(x, MIN_INT) -> x
637  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
638  return getLhs();
639 
640  return constFoldBinaryOp<IntegerAttr>(operands,
641  [](const APInt &a, const APInt &b) {
642  return llvm::APIntOps::umax(a, b);
643  });
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // MinSIOp
648 //===----------------------------------------------------------------------===//
649 
650 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
651  assert(operands.size() == 2 && "binary operation takes two operands");
652 
653  // minsi(x,x) -> x
654  if (getLhs() == getRhs())
655  return getRhs();
656 
657  APInt intValue;
658  // minsi(x,MIN_INT) -> MIN_INT
659  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
660  intValue.isMinSignedValue())
661  return getRhs();
662 
663  // minsi(x, MAX_INT) -> x
664  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
665  intValue.isMaxSignedValue())
666  return getLhs();
667 
668  return constFoldBinaryOp<IntegerAttr>(operands,
669  [](const APInt &a, const APInt &b) {
670  return llvm::APIntOps::smin(a, b);
671  });
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // MinUIOp
676 //===----------------------------------------------------------------------===//
677 
678 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
679  assert(operands.size() == 2 && "binary operation takes two operands");
680 
681  // minui(x,x) -> x
682  if (getLhs() == getRhs())
683  return getRhs();
684 
685  APInt intValue;
686  // minui(x,MIN_INT) -> MIN_INT
687  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
688  return getRhs();
689 
690  // minui(x, MAX_INT) -> x
691  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
692  return getLhs();
693 
694  return constFoldBinaryOp<IntegerAttr>(operands,
695  [](const APInt &a, const APInt &b) {
696  return llvm::APIntOps::umin(a, b);
697  });
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // MulFOp
702 //===----------------------------------------------------------------------===//
703 
704 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
705  return constFoldBinaryOp<FloatAttr>(
706  operands, [](const APFloat &a, const APFloat &b) { return a * b; });
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // DivFOp
711 //===----------------------------------------------------------------------===//
712 
713 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
714  return constFoldBinaryOp<FloatAttr>(
715  operands, [](const APFloat &a, const APFloat &b) { return a / b; });
716 }
717 
718 //===----------------------------------------------------------------------===//
719 // Utility functions for verifying cast ops
720 //===----------------------------------------------------------------------===//
721 
/// A pointer-to-tuple alias that carries a pack of types as a single type.
/// Presumably used as a tag argument to pass type lists to the helpers below;
/// NOTE(review): confirm against its uses.
template <typename... Types>
using type_list = std::tuple<Types...> *;
724 
725 /// Returns a non-null type only if the provided type is one of the allowed
726 /// types or one of the allowed shaped types of the allowed types. Returns the
727 /// element type if a valid shaped type is provided.
728 template <typename... ShapedTypes, typename... ElementTypes>
731  if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
732  return {};
733 
734  auto underlyingType = getElementTypeOrSelf(type);
735  if (!underlyingType.isa<ElementTypes...>())
736  return {};
737 
738  return underlyingType;
739 }
740 
741 /// Get allowed underlying types for vectors and tensors.
742 template <typename... ElementTypes>
743 static Type getTypeIfLike(Type type) {
746 }
747 
748 /// Get allowed underlying types for vectors, tensors, and memrefs.
749 template <typename... ElementTypes>
751  return getUnderlyingType(type,
754 }
755 
756 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
757  return inputs.size() == 1 && outputs.size() == 1 &&
758  succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // Verifiers for integer and floating point extension/truncation ops
763 //===----------------------------------------------------------------------===//
764 
765 // Extend ops can only extend to a wider type.
766 template <typename ValType, typename Op>
768  Type srcType = getElementTypeOrSelf(op.getIn().getType());
769  Type dstType = getElementTypeOrSelf(op.getType());
770 
771  if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
772  return op.emitError("result type ")
773  << dstType << " must be wider than operand type " << srcType;
774 
775  return success();
776 }
777 
778 // Truncate ops can only truncate to a shorter type.
779 template <typename ValType, typename Op>
781  Type srcType = getElementTypeOrSelf(op.getIn().getType());
782  Type dstType = getElementTypeOrSelf(op.getType());
783 
784  if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
785  return op.emitError("result type ")
786  << dstType << " must be shorter than operand type " << srcType;
787 
788  return success();
789 }
790 
791 /// Validate a cast that changes the width of a type.
792 template <template <typename> class WidthComparator, typename... ElementTypes>
793 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
794  if (!areValidCastInputsAndOutputs(inputs, outputs))
795  return false;
796 
797  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
798  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
799  if (!srcType || !dstType)
800  return false;
801 
802  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
803  srcType.getIntOrFloatBitWidth());
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // ExtUIOp
808 //===----------------------------------------------------------------------===//
809 
810 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
811  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
812  return IntegerAttr::get(
813  getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
814 
815  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
816  getInMutable().assign(lhs.getIn());
817  return getResult();
818  }
819 
820  return {};
821 }
822 
823 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
824  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
825 }
826 
827 //===----------------------------------------------------------------------===//
828 // ExtSIOp
829 //===----------------------------------------------------------------------===//
830 
831 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
832  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
833  return IntegerAttr::get(
834  getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
835 
836  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
837  getInMutable().assign(lhs.getIn());
838  return getResult();
839  }
840 
841  return {};
842 }
843 
844 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
845  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
846 }
847 
848 void arith::ExtSIOp::getCanonicalizationPatterns(
849  OwningRewritePatternList &patterns, MLIRContext *context) {
850  patterns.insert<ExtSIOfExtUI>(context);
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // ExtFOp
855 //===----------------------------------------------------------------------===//
856 
857 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
858  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
859 }
860 
861 //===----------------------------------------------------------------------===//
862 // TruncIOp
863 //===----------------------------------------------------------------------===//
864 
865 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
866  assert(operands.size() == 1 && "unary operation takes one operand");
867 
868  // trunci(zexti(a)) -> a
869  // trunci(sexti(a)) -> a
870  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
871  matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
872  return getOperand().getDefiningOp()->getOperand(0);
873 
874  // trunci(trunci(a)) -> trunci(a))
875  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
876  setOperand(getOperand().getDefiningOp()->getOperand(0));
877  return getResult();
878  }
879 
880  if (!operands[0])
881  return {};
882 
883  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
884  return IntegerAttr::get(
885  getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
886  }
887 
888  return {};
889 }
890 
891 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
892  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // TruncFOp
897 //===----------------------------------------------------------------------===//
898 
899 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
900 /// can be represented without precision loss or rounding.
901 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
902  assert(operands.size() == 1 && "unary operation takes one operand");
903 
904  auto constOperand = operands.front();
905  if (!constOperand || !constOperand.isa<FloatAttr>())
906  return {};
907 
908  // Convert to target type via 'double'.
909  double sourceValue =
910  constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
911  auto targetAttr = FloatAttr::get(getType(), sourceValue);
912 
913  // Propagate if constant's value does not change after truncation.
914  if (sourceValue == targetAttr.getValue().convertToDouble())
915  return targetAttr;
916 
917  return {};
918 }
919 
920 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
921  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // AndIOp
926 //===----------------------------------------------------------------------===//
927 
928 void arith::AndIOp::getCanonicalizationPatterns(
929  OwningRewritePatternList &patterns, MLIRContext *context) {
930  patterns.insert<AndOfExtUI, AndOfExtSI>(context);
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // OrIOp
935 //===----------------------------------------------------------------------===//
936 
937 void arith::OrIOp::getCanonicalizationPatterns(
938  OwningRewritePatternList &patterns, MLIRContext *context) {
939  patterns.insert<OrOfExtUI, OrOfExtSI>(context);
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // Verifiers for casts between integers and floats.
944 //===----------------------------------------------------------------------===//
945 
946 template <typename From, typename To>
947 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
948  if (!areValidCastInputsAndOutputs(inputs, outputs))
949  return false;
950 
951  auto srcType = getTypeIfLike<From>(inputs.front());
952  auto dstType = getTypeIfLike<To>(outputs.back());
953 
954  return srcType && dstType;
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // UIToFPOp
959 //===----------------------------------------------------------------------===//
960 
961 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
962  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
963 }
964 
965 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
966  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
967  const APInt &api = lhs.getValue();
968  FloatType floatTy = getType().cast<FloatType>();
969  APFloat apf(floatTy.getFloatSemantics(),
970  APInt::getZero(floatTy.getWidth()));
971  apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
972  return FloatAttr::get(floatTy, apf);
973  }
974  return {};
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // SIToFPOp
979 //===----------------------------------------------------------------------===//
980 
981 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
982  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
983 }
984 
985 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
986  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
987  const APInt &api = lhs.getValue();
988  FloatType floatTy = getType().cast<FloatType>();
989  APFloat apf(floatTy.getFloatSemantics(),
990  APInt::getZero(floatTy.getWidth()));
991  apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
992  return FloatAttr::get(floatTy, apf);
993  }
994  return {};
995 }
996 //===----------------------------------------------------------------------===//
997 // FPToUIOp
998 //===----------------------------------------------------------------------===//
999 
1000 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1001  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1002 }
1003 
1004 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1005  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1006  const APFloat &apf = lhs.getValue();
1007  IntegerType intTy = getType().cast<IntegerType>();
1008  bool ignored;
1009  APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1010  if (APFloat::opInvalidOp ==
1011  apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1012  // Undefined behavior invoked - the destination type can't represent
1013  // the input constant.
1014  return {};
1015  }
1016  return IntegerAttr::get(getType(), api);
1017  }
1018 
1019  return {};
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // FPToSIOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1027  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1028 }
1029 
1030 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1031  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1032  const APFloat &apf = lhs.getValue();
1033  IntegerType intTy = getType().cast<IntegerType>();
1034  bool ignored;
1035  APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1036  if (APFloat::opInvalidOp ==
1037  apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1038  // Undefined behavior invoked - the destination type can't represent
1039  // the input constant.
1040  return {};
1041  }
1042  return IntegerAttr::get(getType(), api);
1043  }
1044 
1045  return {};
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // IndexCastOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1053  TypeRange outputs) {
1054  if (!areValidCastInputsAndOutputs(inputs, outputs))
1055  return false;
1056 
1057  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1058  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1059  if (!srcType || !dstType)
1060  return false;
1061 
1062  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1063  (srcType.isSignlessInteger() && dstType.isIndex());
1064 }
1065 
1066 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1067  // index_cast(constant) -> constant
1068  // A little hack because we go through int. Otherwise, the size of the
1069  // constant might need to change.
1070  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1071  return IntegerAttr::get(getType(), value.getInt());
1072 
1073  return {};
1074 }
1075 
1076 void arith::IndexCastOp::getCanonicalizationPatterns(
1077  OwningRewritePatternList &patterns, MLIRContext *context) {
1078  patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // BitcastOp
1083 //===----------------------------------------------------------------------===//
1084 
1085 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1086  if (!areValidCastInputsAndOutputs(inputs, outputs))
1087  return false;
1088 
1089  auto srcType =
1090  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1091  auto dstType =
1092  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1093  if (!srcType || !dstType)
1094  return false;
1095 
1096  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1097 }
1098 
1099 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1100  assert(operands.size() == 1 && "bitcast op expects 1 operand");
1101 
1102  auto resType = getType();
1103  auto operand = operands[0];
1104  if (!operand)
1105  return {};
1106 
1107  /// Bitcast dense elements.
1108  if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1109  return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1110  /// Other shaped types unhandled.
1111  if (resType.isa<ShapedType>())
1112  return {};
1113 
1114  /// Bitcast integer or float to integer or float.
1115  APInt bits = operand.isa<FloatAttr>()
1116  ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1117  : operand.cast<IntegerAttr>().getValue();
1118 
1119  if (auto resFloatType = resType.dyn_cast<FloatType>())
1120  return FloatAttr::get(resType,
1121  APFloat(resFloatType.getFloatSemantics(), bits));
1122  return IntegerAttr::get(resType, bits);
1123 }
1124 
1125 void arith::BitcastOp::getCanonicalizationPatterns(
1126  OwningRewritePatternList &patterns, MLIRContext *context) {
1127  patterns.insert<BitcastOfBitcast>(context);
1128 }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // Helpers for compare ops
1132 //===----------------------------------------------------------------------===//
1133 
1134 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
1135 static Type getI1SameShape(Type type) {
1136  auto i1Type = IntegerType::get(type.getContext(), 1);
1137  if (auto tensorType = type.dyn_cast<RankedTensorType>())
1138  return RankedTensorType::get(tensorType.getShape(), i1Type);
1139  if (type.isa<UnrankedTensorType>())
1140  return UnrankedTensorType::get(i1Type);
1141  if (auto vectorType = type.dyn_cast<VectorType>())
1142  return VectorType::get(vectorType.getShape(), i1Type,
1143  vectorType.getNumScalableDims());
1144  return i1Type;
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // CmpIOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1152 /// comparison predicates.
1153 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1154  const APInt &lhs, const APInt &rhs) {
1155  switch (predicate) {
1156  case arith::CmpIPredicate::eq:
1157  return lhs.eq(rhs);
1158  case arith::CmpIPredicate::ne:
1159  return lhs.ne(rhs);
1160  case arith::CmpIPredicate::slt:
1161  return lhs.slt(rhs);
1162  case arith::CmpIPredicate::sle:
1163  return lhs.sle(rhs);
1164  case arith::CmpIPredicate::sgt:
1165  return lhs.sgt(rhs);
1166  case arith::CmpIPredicate::sge:
1167  return lhs.sge(rhs);
1168  case arith::CmpIPredicate::ult:
1169  return lhs.ult(rhs);
1170  case arith::CmpIPredicate::ule:
1171  return lhs.ule(rhs);
1172  case arith::CmpIPredicate::ugt:
1173  return lhs.ugt(rhs);
1174  case arith::CmpIPredicate::uge:
1175  return lhs.uge(rhs);
1176  }
1177  llvm_unreachable("unknown cmpi predicate kind");
1178 }
1179 
1180 /// Returns true if the predicate is true for two equal operands.
1181 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1182  switch (predicate) {
1183  case arith::CmpIPredicate::eq:
1184  case arith::CmpIPredicate::sle:
1185  case arith::CmpIPredicate::sge:
1186  case arith::CmpIPredicate::ule:
1187  case arith::CmpIPredicate::uge:
1188  return true;
1189  case arith::CmpIPredicate::ne:
1190  case arith::CmpIPredicate::slt:
1191  case arith::CmpIPredicate::sgt:
1192  case arith::CmpIPredicate::ult:
1193  case arith::CmpIPredicate::ugt:
1194  return false;
1195  }
1196  llvm_unreachable("unknown cmpi predicate kind");
1197 }
1198 
1199 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1200  auto boolAttr = BoolAttr::get(ctx, value);
1201  ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1202  if (!shapedType)
1203  return boolAttr;
1204  return DenseElementsAttr::get(shapedType, boolAttr);
1205 }
1206 
1207 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1208  assert(operands.size() == 2 && "cmpi takes two operands");
1209 
1210  // cmpi(pred, x, x)
1211  if (getLhs() == getRhs()) {
1212  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1213  return getBoolAttribute(getType(), getContext(), val);
1214  }
1215 
1216  if (matchPattern(getRhs(), m_Zero())) {
1217  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1218  if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1219  // extsi(%x : i1 -> iN) != 0 -> %x
1220  if (getPredicate() == arith::CmpIPredicate::ne) {
1221  return extOp.getOperand();
1222  }
1223  }
1224  }
1225  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1226  if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
1227  // extui(%x : i1 -> iN) != 0 -> %x
1228  if (getPredicate() == arith::CmpIPredicate::ne) {
1229  return extOp.getOperand();
1230  }
1231  }
1232  }
1233  }
1234 
1235  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1236  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1237  if (!lhs || !rhs)
1238  return {};
1239 
1240  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1241  return BoolAttr::get(getContext(), val);
1242 }
1243 
1244 //===----------------------------------------------------------------------===//
1245 // CmpFOp
1246 //===----------------------------------------------------------------------===//
1247 
1248 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1249 /// comparison predicates.
1250 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1251  const APFloat &lhs, const APFloat &rhs) {
1252  auto cmpResult = lhs.compare(rhs);
1253  switch (predicate) {
1254  case arith::CmpFPredicate::AlwaysFalse:
1255  return false;
1256  case arith::CmpFPredicate::OEQ:
1257  return cmpResult == APFloat::cmpEqual;
1258  case arith::CmpFPredicate::OGT:
1259  return cmpResult == APFloat::cmpGreaterThan;
1260  case arith::CmpFPredicate::OGE:
1261  return cmpResult == APFloat::cmpGreaterThan ||
1262  cmpResult == APFloat::cmpEqual;
1263  case arith::CmpFPredicate::OLT:
1264  return cmpResult == APFloat::cmpLessThan;
1265  case arith::CmpFPredicate::OLE:
1266  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1267  case arith::CmpFPredicate::ONE:
1268  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1269  case arith::CmpFPredicate::ORD:
1270  return cmpResult != APFloat::cmpUnordered;
1271  case arith::CmpFPredicate::UEQ:
1272  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1273  case arith::CmpFPredicate::UGT:
1274  return cmpResult == APFloat::cmpUnordered ||
1275  cmpResult == APFloat::cmpGreaterThan;
1276  case arith::CmpFPredicate::UGE:
1277  return cmpResult == APFloat::cmpUnordered ||
1278  cmpResult == APFloat::cmpGreaterThan ||
1279  cmpResult == APFloat::cmpEqual;
1280  case arith::CmpFPredicate::ULT:
1281  return cmpResult == APFloat::cmpUnordered ||
1282  cmpResult == APFloat::cmpLessThan;
1283  case arith::CmpFPredicate::ULE:
1284  return cmpResult == APFloat::cmpUnordered ||
1285  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1286  case arith::CmpFPredicate::UNE:
1287  return cmpResult != APFloat::cmpEqual;
1288  case arith::CmpFPredicate::UNO:
1289  return cmpResult == APFloat::cmpUnordered;
1290  case arith::CmpFPredicate::AlwaysTrue:
1291  return true;
1292  }
1293  llvm_unreachable("unknown cmpf predicate kind");
1294 }
1295 
1296 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1297  assert(operands.size() == 2 && "cmpf takes two operands");
1298 
1299  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1300  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1301 
1302  // If one operand is NaN, making them both NaN does not change the result.
1303  if (lhs && lhs.getValue().isNaN())
1304  rhs = lhs;
1305  if (rhs && rhs.getValue().isNaN())
1306  lhs = rhs;
1307 
1308  if (!lhs || !rhs)
1309  return {};
1310 
1311  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1312  return BoolAttr::get(getContext(), val);
1313 }
1314 
1315 //===----------------------------------------------------------------------===//
1316 // Atomic Enum
1317 //===----------------------------------------------------------------------===//
1318 
1319 /// Returns the identity value attribute associated with an AtomicRMWKind op.
1320 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1321  OpBuilder &builder, Location loc) {
1322  switch (kind) {
1323  case AtomicRMWKind::maxf:
1324  return builder.getFloatAttr(
1325  resultType,
1326  APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1327  /*Negative=*/true));
1328  case AtomicRMWKind::addf:
1329  case AtomicRMWKind::addi:
1330  case AtomicRMWKind::maxu:
1331  case AtomicRMWKind::ori:
1332  return builder.getZeroAttr(resultType);
1333  case AtomicRMWKind::andi:
1334  return builder.getIntegerAttr(
1335  resultType,
1336  APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1337  case AtomicRMWKind::maxs:
1338  return builder.getIntegerAttr(
1339  resultType,
1340  APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1341  case AtomicRMWKind::minf:
1342  return builder.getFloatAttr(
1343  resultType,
1344  APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1345  /*Negative=*/false));
1346  case AtomicRMWKind::mins:
1347  return builder.getIntegerAttr(
1348  resultType,
1349  APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1350  case AtomicRMWKind::minu:
1351  return builder.getIntegerAttr(
1352  resultType,
1353  APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1354  case AtomicRMWKind::muli:
1355  return builder.getIntegerAttr(resultType, 1);
1356  case AtomicRMWKind::mulf:
1357  return builder.getFloatAttr(resultType, 1);
1358  // TODO: Add remaining reduction operations.
1359  default:
1360  (void)emitOptionalError(loc, "Reduction operation type not supported");
1361  break;
1362  }
1363  return nullptr;
1364 }
1365 
1366 /// Returns the identity value associated with an AtomicRMWKind op.
1367 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
1368  OpBuilder &builder, Location loc) {
1369  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
1370  return builder.create<arith::ConstantOp>(loc, attr);
1371 }
1372 
1373 /// Return the value obtained by applying the reduction operation kind
1374 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
1375 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
1376  Location loc, Value lhs, Value rhs) {
1377  switch (op) {
1378  case AtomicRMWKind::addf:
1379  return builder.create<arith::AddFOp>(loc, lhs, rhs);
1380  case AtomicRMWKind::addi:
1381  return builder.create<arith::AddIOp>(loc, lhs, rhs);
1382  case AtomicRMWKind::mulf:
1383  return builder.create<arith::MulFOp>(loc, lhs, rhs);
1384  case AtomicRMWKind::muli:
1385  return builder.create<arith::MulIOp>(loc, lhs, rhs);
1386  case AtomicRMWKind::maxf:
1387  return builder.create<arith::MaxFOp>(loc, lhs, rhs);
1388  case AtomicRMWKind::minf:
1389  return builder.create<arith::MinFOp>(loc, lhs, rhs);
1390  case AtomicRMWKind::maxs:
1391  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
1392  case AtomicRMWKind::mins:
1393  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
1394  case AtomicRMWKind::maxu:
1395  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
1396  case AtomicRMWKind::minu:
1397  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
1398  case AtomicRMWKind::ori:
1399  return builder.create<arith::OrIOp>(loc, lhs, rhs);
1400  case AtomicRMWKind::andi:
1401  return builder.create<arith::AndIOp>(loc, lhs, rhs);
1402  // TODO: Add remaining reduction operations.
1403  default:
1404  (void)emitOptionalError(loc, "Reduction operation type not supported");
1405  break;
1406  }
1407  return nullptr;
1408 }
1409 
1410 //===----------------------------------------------------------------------===//
1411 // TableGen'd op method definitions
1412 //===----------------------------------------------------------------------===//
1413 
1414 #define GET_OP_CLASSES
1415 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1416 
1417 //===----------------------------------------------------------------------===//
1418 // TableGen'd enum attribute definitions
1419 //===----------------------------------------------------------------------===//
1420 
1421 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
U cast() const
Definition: Attributes.h:123
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:282
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
int64_t ceil(Fraction f)
Definition: Fraction.h:57
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
This class represents a single result from folding an operation.
Definition: OpDefinition.h:244
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
bool isa() const
Definition: Attributes.h:107
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
detail::constant_int_value_matcher< 1 > m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:243
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static constexpr const bool value
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
LogicalResult emitOptionalError(Optional< Location > loc, Args &&... args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:464
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value attribute associated with an AtomicRMWKind op.
static bool classof(Operation *op)
unsigned getWidth()
Return the bitwidth of this float type.
An attribute that represents a reference to a dense vector or tensor object.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:618
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
U dyn_cast_or_null() const
Definition: Types.h:247
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:244
Attributes are known-constant values of operations.
Definition: Attributes.h:24
static LogicalResult verifyTruncateOp(Op op)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
This represents an operation in an abstracted form, suitable for use with the builder APIs...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:37
static bool classof(Operation *op)
std::tuple< Types... > * type_list
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:117
IndexType getIndexType()
Definition: Builders.cpp:48
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
This provides public APIs that all operations should have.
U cast() const
Definition: Value.h:107
detail::constant_int_value_matcher< 0 > m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:254
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static BoolAttr get(MLIRContext *context, bool value)
bool isa() const
Definition: Types.h:234
This class helps build Operations.
Definition: Builders.h:177
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
static LogicalResult verifyExtOp(Op op)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
U cast() const
Definition: Types.h:250
static bool classof(Operation *op)