MLIR  17.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <utility>
12 
15 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/ADT/APSInt.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/ADT/SmallVector.h"
28 
29 using namespace mlir;
30 using namespace mlir::arith;
31 
32 //===----------------------------------------------------------------------===//
33 // Pattern helpers
34 //===----------------------------------------------------------------------===//
35 
36 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
37  Attribute lhs, Attribute rhs) {
38  return builder.getIntegerAttr(res.getType(),
39  llvm::cast<IntegerAttr>(lhs).getInt() +
40  llvm::cast<IntegerAttr>(rhs).getInt());
41 }
42 
43 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
44  Attribute lhs, Attribute rhs) {
45  return builder.getIntegerAttr(res.getType(),
46  llvm::cast<IntegerAttr>(lhs).getInt() -
47  llvm::cast<IntegerAttr>(rhs).getInt());
48 }
49 
50 /// Invert an integer comparison predicate.
51 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
52  switch (pred) {
53  case arith::CmpIPredicate::eq:
54  return arith::CmpIPredicate::ne;
55  case arith::CmpIPredicate::ne:
56  return arith::CmpIPredicate::eq;
57  case arith::CmpIPredicate::slt:
58  return arith::CmpIPredicate::sge;
59  case arith::CmpIPredicate::sle:
60  return arith::CmpIPredicate::sgt;
61  case arith::CmpIPredicate::sgt:
62  return arith::CmpIPredicate::sle;
63  case arith::CmpIPredicate::sge:
64  return arith::CmpIPredicate::slt;
65  case arith::CmpIPredicate::ult:
66  return arith::CmpIPredicate::uge;
67  case arith::CmpIPredicate::ule:
68  return arith::CmpIPredicate::ugt;
69  case arith::CmpIPredicate::ugt:
70  return arith::CmpIPredicate::ule;
71  case arith::CmpIPredicate::uge:
72  return arith::CmpIPredicate::ult;
73  }
74  llvm_unreachable("unknown cmpi predicate kind");
75 }
76 
77 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
78  return arith::CmpIPredicateAttr::get(pred.getContext(),
79  invertPredicate(pred.getValue()));
80 }
81 
82 static int64_t getScalarOrElementWidth(Type type) {
83  Type elemTy = getElementTypeOrSelf(type);
84  if (elemTy.isIntOrFloat())
85  return elemTy.getIntOrFloatBitWidth();
86 
87  return -1;
88 }
89 
90 static int64_t getScalarOrElementWidth(Value value) {
91  return getScalarOrElementWidth(value.getType());
92 }
93 
95  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
96  return intAttr.getValue();
97 
98  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
99  if (llvm::isa<IntegerType>(splatAttr.getElementType()))
100  return splatAttr.getSplatValue<APInt>();
101 
102  return failure();
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // TableGen'd canonicalization patterns
107 //===----------------------------------------------------------------------===//
108 
109 namespace {
110 #include "ArithCanonicalization.inc"
111 } // namespace
112 
113 //===----------------------------------------------------------------------===//
114 // Common helpers
115 //===----------------------------------------------------------------------===//
116 
117 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
118 static Type getI1SameShape(Type type) {
119  auto i1Type = IntegerType::get(type.getContext(), 1);
120  if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
121  return RankedTensorType::get(tensorType.getShape(), i1Type);
122  if (llvm::isa<UnrankedTensorType>(type))
123  return UnrankedTensorType::get(i1Type);
124  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
125  return VectorType::get(vectorType.getShape(), i1Type,
126  vectorType.getNumScalableDims());
127  return i1Type;
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // ConstantOp
132 //===----------------------------------------------------------------------===//
133 
134 void arith::ConstantOp::getAsmResultNames(
135  function_ref<void(Value, StringRef)> setNameFn) {
136  auto type = getType();
137  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
138  auto intType = llvm::dyn_cast<IntegerType>(type);
139 
140  // Sugar i1 constants with 'true' and 'false'.
141  if (intType && intType.getWidth() == 1)
142  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
143 
144  // Otherwise, build a complex name with the value and type.
145  SmallString<32> specialNameBuffer;
146  llvm::raw_svector_ostream specialName(specialNameBuffer);
147  specialName << 'c' << intCst.getValue();
148  if (intType)
149  specialName << '_' << type;
150  setNameFn(getResult(), specialName.str());
151  } else {
152  setNameFn(getResult(), "cst");
153  }
154 }
155 
156 /// TODO: disallow arith.constant to return anything other than signless integer
157 /// or float like.
159  auto type = getType();
160  // The value's type must match the return type.
161  if (getValue().getType() != type) {
162  return emitOpError() << "value type " << getValue().getType()
163  << " must match return type: " << type;
164  }
165  // Integer values must be signless.
166  if (llvm::isa<IntegerType>(type) &&
167  !llvm::cast<IntegerType>(type).isSignless())
168  return emitOpError("integer return type must be signless");
169  // Any float or elements attribute are acceptable.
170  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
171  return emitOpError(
172  "value must be an integer, float, or elements attribute");
173  }
174  return success();
175 }
176 
177 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
178  // The value's type must be the same as the provided type.
179  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
180  if (!typedAttr || typedAttr.getType() != type)
181  return false;
182  // Integer values must be signless.
183  if (llvm::isa<IntegerType>(type) &&
184  !llvm::cast<IntegerType>(type).isSignless())
185  return false;
186  // Integer, float, and element attributes are buildable.
187  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
188 }
189 
190 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
191  Type type, Location loc) {
192  if (isBuildableWith(value, type))
193  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
194  return nullptr;
195 }
196 
197 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
198 
200  int64_t value, unsigned width) {
201  auto type = builder.getIntegerType(width);
202  arith::ConstantOp::build(builder, result, type,
203  builder.getIntegerAttr(type, value));
204 }
205 
207  int64_t value, Type type) {
208  assert(type.isSignlessInteger() &&
209  "ConstantIntOp can only have signless integer type values");
210  arith::ConstantOp::build(builder, result, type,
211  builder.getIntegerAttr(type, value));
212 }
213 
215  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
216  return constOp.getType().isSignlessInteger();
217  return false;
218 }
219 
221  const APFloat &value, FloatType type) {
222  arith::ConstantOp::build(builder, result, type,
223  builder.getFloatAttr(type, value));
224 }
225 
227  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
228  return llvm::isa<FloatType>(constOp.getType());
229  return false;
230 }
231 
233  int64_t value) {
234  arith::ConstantOp::build(builder, result, builder.getIndexType(),
235  builder.getIndexAttr(value));
236 }
237 
239  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
240  return constOp.getType().isIndex();
241  return false;
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // AddIOp
246 //===----------------------------------------------------------------------===//
247 
248 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
249  // addi(x, 0) -> x
250  if (matchPattern(getRhs(), m_Zero()))
251  return getLhs();
252 
253  // addi(subi(a, b), b) -> a
254  if (auto sub = getLhs().getDefiningOp<SubIOp>())
255  if (getRhs() == sub.getRhs())
256  return sub.getLhs();
257 
258  // addi(b, subi(a, b)) -> a
259  if (auto sub = getRhs().getDefiningOp<SubIOp>())
260  if (getLhs() == sub.getRhs())
261  return sub.getLhs();
262 
263  return constFoldBinaryOp<IntegerAttr>(
264  adaptor.getOperands(),
265  [](APInt a, const APInt &b) { return std::move(a) + b; });
266 }
267 
268 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
269  MLIRContext *context) {
270  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
271  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // AddUIExtendedOp
276 //===----------------------------------------------------------------------===//
277 
278 std::optional<SmallVector<int64_t, 4>>
279 arith::AddUIExtendedOp::getShapeForUnroll() {
280  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
281  return llvm::to_vector<4>(vt.getShape());
282  return std::nullopt;
283 }
284 
285 // Returns the overflow bit, assuming that `sum` is the result of unsigned
286 // addition of `operand` and another number.
287 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
288  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
289 }
290 
292 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
294  Type overflowTy = getOverflow().getType();
295  // addui_extended(x, 0) -> x, false
296  if (matchPattern(getRhs(), m_Zero())) {
297  Builder builder(getContext());
298  auto falseValue = builder.getZeroAttr(overflowTy);
299 
300  results.push_back(getLhs());
301  results.push_back(falseValue);
302  return success();
303  }
304 
305  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
306  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
307  // operands. If that succeeds, calculate the overflow bit based on the sum
308  // and the first (constant) operand, `lhs`.
309  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
310  adaptor.getOperands(),
311  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
312  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
313  ArrayRef({sumAttr, adaptor.getLhs()}),
314  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
316  if (!overflowAttr)
317  return failure();
318 
319  results.push_back(sumAttr);
320  results.push_back(overflowAttr);
321  return success();
322  }
323 
324  return failure();
325 }
326 
327 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
328  RewritePatternSet &patterns, MLIRContext *context) {
329  patterns.add<AddUIExtendedToAddI>(context);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // SubIOp
334 //===----------------------------------------------------------------------===//
335 
336 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
337  // subi(x,x) -> 0
338  if (getOperand(0) == getOperand(1))
339  return Builder(getContext()).getZeroAttr(getType());
340  // subi(x,0) -> x
341  if (matchPattern(getRhs(), m_Zero()))
342  return getLhs();
343 
344  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
345  // subi(addi(a, b), b) -> a
346  if (getRhs() == add.getRhs())
347  return add.getLhs();
348  // subi(addi(a, b), a) -> b
349  if (getRhs() == add.getLhs())
350  return add.getRhs();
351  }
352 
353  return constFoldBinaryOp<IntegerAttr>(
354  adaptor.getOperands(),
355  [](APInt a, const APInt &b) { return std::move(a) - b; });
356 }
357 
358 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
359  MLIRContext *context) {
360  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
361  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
362  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // MulIOp
367 //===----------------------------------------------------------------------===//
368 
369 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
370  // muli(x, 0) -> 0
371  if (matchPattern(getRhs(), m_Zero()))
372  return getRhs();
373  // muli(x, 1) -> x
374  if (matchPattern(getRhs(), m_One()))
375  return getOperand(0);
376  // TODO: Handle the overflow case.
377 
378  // default folder
379  return constFoldBinaryOp<IntegerAttr>(
380  adaptor.getOperands(),
381  [](const APInt &a, const APInt &b) { return a * b; });
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // MulSIExtendedOp
386 //===----------------------------------------------------------------------===//
387 
388 std::optional<SmallVector<int64_t, 4>>
389 arith::MulSIExtendedOp::getShapeForUnroll() {
390  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
391  return llvm::to_vector<4>(vt.getShape());
392  return std::nullopt;
393 }
394 
396 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
398  // mulsi_extended(x, 0) -> 0, 0
399  if (matchPattern(getRhs(), m_Zero())) {
400  Attribute zero = adaptor.getRhs();
401  results.push_back(zero);
402  results.push_back(zero);
403  return success();
404  }
405 
406  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
407  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
408  adaptor.getOperands(),
409  [](const APInt &a, const APInt &b) { return a * b; })) {
410  // Invoke the constant fold helper again to calculate the 'high' result.
411  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
412  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
413  unsigned bitWidth = a.getBitWidth();
414  APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
415  return fullProduct.extractBits(bitWidth, bitWidth);
416  });
417  assert(highAttr && "Unexpected constant-folding failure");
418 
419  results.push_back(lowAttr);
420  results.push_back(highAttr);
421  return success();
422  }
423 
424  return failure();
425 }
426 
427 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
428  RewritePatternSet &patterns, MLIRContext *context) {
429  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // MulUIExtendedOp
434 //===----------------------------------------------------------------------===//
435 
436 std::optional<SmallVector<int64_t, 4>>
437 arith::MulUIExtendedOp::getShapeForUnroll() {
438  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
439  return llvm::to_vector<4>(vt.getShape());
440  return std::nullopt;
441 }
442 
444 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
446  // mului_extended(x, 0) -> 0, 0
447  if (matchPattern(getRhs(), m_Zero())) {
448  Attribute zero = adaptor.getRhs();
449  results.push_back(zero);
450  results.push_back(zero);
451  return success();
452  }
453 
454  // mului_extended(x, 1) -> x, 0
455  if (matchPattern(getRhs(), m_One())) {
456  Builder builder(getContext());
457  Attribute zero = builder.getZeroAttr(getLhs().getType());
458  results.push_back(getLhs());
459  results.push_back(zero);
460  return success();
461  }
462 
463  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
464  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
465  adaptor.getOperands(),
466  [](const APInt &a, const APInt &b) { return a * b; })) {
467  // Invoke the constant fold helper again to calculate the 'high' result.
468  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
469  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
470  unsigned bitWidth = a.getBitWidth();
471  APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
472  return fullProduct.extractBits(bitWidth, bitWidth);
473  });
474  assert(highAttr && "Unexpected constant-folding failure");
475 
476  results.push_back(lowAttr);
477  results.push_back(highAttr);
478  return success();
479  }
480 
481  return failure();
482 }
483 
484 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
485  RewritePatternSet &patterns, MLIRContext *context) {
486  patterns.add<MulUIExtendedToMulI>(context);
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // DivUIOp
491 //===----------------------------------------------------------------------===//
492 
493 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
494  // divui (x, 1) -> x.
495  if (matchPattern(getRhs(), m_One()))
496  return getLhs();
497 
498  // Don't fold if it would require a division by zero.
499  bool div0 = false;
500  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
501  [&](APInt a, const APInt &b) {
502  if (div0 || !b) {
503  div0 = true;
504  return a;
505  }
506  return a.udiv(b);
507  });
508 
509  return div0 ? Attribute() : result;
510 }
511 
512 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
513  // X / 0 => UB
514  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // DivSIOp
520 //===----------------------------------------------------------------------===//
521 
522 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
523  // divsi (x, 1) -> x.
524  if (matchPattern(getRhs(), m_One()))
525  return getLhs();
526 
527  // Don't fold if it would overflow or if it requires a division by zero.
528  bool overflowOrDiv0 = false;
529  auto result = constFoldBinaryOp<IntegerAttr>(
530  adaptor.getOperands(), [&](APInt a, const APInt &b) {
531  if (overflowOrDiv0 || !b) {
532  overflowOrDiv0 = true;
533  return a;
534  }
535  return a.sdiv_ov(b, overflowOrDiv0);
536  });
537 
538  return overflowOrDiv0 ? Attribute() : result;
539 }
540 
541 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
542  bool mayHaveUB = true;
543 
544  APInt constRHS;
545  // X / 0 => UB
546  // INT_MIN / -1 => UB
547  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
548  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
549 
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Ceil and floor division folding helpers
555 //===----------------------------------------------------------------------===//
556 
557 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
558  bool &overflow) {
559  // Returns (a-1)/b + 1
560  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
561  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
562  return val.sadd_ov(one, overflow);
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // CeilDivUIOp
567 //===----------------------------------------------------------------------===//
568 
569 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
570  // ceildivui (x, 1) -> x.
571  if (matchPattern(getRhs(), m_One()))
572  return getLhs();
573 
574  bool overflowOrDiv0 = false;
575  auto result = constFoldBinaryOp<IntegerAttr>(
576  adaptor.getOperands(), [&](APInt a, const APInt &b) {
577  if (overflowOrDiv0 || !b) {
578  overflowOrDiv0 = true;
579  return a;
580  }
581  APInt quotient = a.udiv(b);
582  if (!a.urem(b))
583  return quotient;
584  APInt one(a.getBitWidth(), 1, true);
585  return quotient.uadd_ov(one, overflowOrDiv0);
586  });
587 
588  return overflowOrDiv0 ? Attribute() : result;
589 }
590 
591 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
592  // X / 0 => UB
593  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // CeilDivSIOp
599 //===----------------------------------------------------------------------===//
600 
601 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
602  // ceildivsi (x, 1) -> x.
603  if (matchPattern(getRhs(), m_One()))
604  return getLhs();
605 
606  // Don't fold if it would overflow or if it requires a division by zero.
607  bool overflowOrDiv0 = false;
608  auto result = constFoldBinaryOp<IntegerAttr>(
609  adaptor.getOperands(), [&](APInt a, const APInt &b) {
610  if (overflowOrDiv0 || !b) {
611  overflowOrDiv0 = true;
612  return a;
613  }
614  if (!a)
615  return a;
616  // After this point we know that neither a or b are zero.
617  unsigned bits = a.getBitWidth();
618  APInt zero = APInt::getZero(bits);
619  bool aGtZero = a.sgt(zero);
620  bool bGtZero = b.sgt(zero);
621  if (aGtZero && bGtZero) {
622  // Both positive, return ceil(a, b).
623  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
624  }
625  if (!aGtZero && !bGtZero) {
626  // Both negative, return ceil(-a, -b).
627  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
628  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
629  return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
630  }
631  if (!aGtZero && bGtZero) {
632  // A is negative, b is positive, return - ( -a / b).
633  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
634  APInt div = posA.sdiv_ov(b, overflowOrDiv0);
635  return zero.ssub_ov(div, overflowOrDiv0);
636  }
637  // A is positive, b is negative, return - (a / -b).
638  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
639  APInt div = a.sdiv_ov(posB, overflowOrDiv0);
640  return zero.ssub_ov(div, overflowOrDiv0);
641  });
642 
643  return overflowOrDiv0 ? Attribute() : result;
644 }
645 
646 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
647  bool mayHaveUB = true;
648 
649  APInt constRHS;
650  // X / 0 => UB
651  // INT_MIN / -1 => UB
652  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
653  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
654 
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // FloorDivSIOp
660 //===----------------------------------------------------------------------===//
661 
662 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
663  // floordivsi (x, 1) -> x.
664  if (matchPattern(getRhs(), m_One()))
665  return getLhs();
666 
667  // Don't fold if it would overflow or if it requires a division by zero.
668  bool overflowOrDiv0 = false;
669  auto result = constFoldBinaryOp<IntegerAttr>(
670  adaptor.getOperands(), [&](APInt a, const APInt &b) {
671  if (overflowOrDiv0 || !b) {
672  overflowOrDiv0 = true;
673  return a;
674  }
675  if (!a)
676  return a;
677  // After this point we know that neither a or b are zero.
678  unsigned bits = a.getBitWidth();
679  APInt zero = APInt::getZero(bits);
680  bool aGtZero = a.sgt(zero);
681  bool bGtZero = b.sgt(zero);
682  if (aGtZero && bGtZero) {
683  // Both positive, return a / b.
684  return a.sdiv_ov(b, overflowOrDiv0);
685  }
686  if (!aGtZero && !bGtZero) {
687  // Both negative, return -a / -b.
688  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
689  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
690  return posA.sdiv_ov(posB, overflowOrDiv0);
691  }
692  if (!aGtZero && bGtZero) {
693  // A is negative, b is positive, return - ceil(-a, b).
694  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
695  APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
696  return zero.ssub_ov(ceil, overflowOrDiv0);
697  }
698  // A is positive, b is negative, return - ceil(a, -b).
699  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
700  APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
701  return zero.ssub_ov(ceil, overflowOrDiv0);
702  });
703 
704  return overflowOrDiv0 ? Attribute() : result;
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // RemUIOp
709 //===----------------------------------------------------------------------===//
710 
711 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
712  // remui (x, 1) -> 0.
713  if (matchPattern(getRhs(), m_One()))
714  return Builder(getContext()).getZeroAttr(getType());
715 
716  // Don't fold if it would require a division by zero.
717  bool div0 = false;
718  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
719  [&](APInt a, const APInt &b) {
720  if (div0 || b.isZero()) {
721  div0 = true;
722  return a;
723  }
724  return a.urem(b);
725  });
726 
727  return div0 ? Attribute() : result;
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // RemSIOp
732 //===----------------------------------------------------------------------===//
733 
734 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
735  // remsi (x, 1) -> 0.
736  if (matchPattern(getRhs(), m_One()))
737  return Builder(getContext()).getZeroAttr(getType());
738 
739  // Don't fold if it would require a division by zero.
740  bool div0 = false;
741  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
742  [&](APInt a, const APInt &b) {
743  if (div0 || b.isZero()) {
744  div0 = true;
745  return a;
746  }
747  return a.srem(b);
748  });
749 
750  return div0 ? Attribute() : result;
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // AndIOp
755 //===----------------------------------------------------------------------===//
756 
757 /// Fold `and(a, and(a, b))` to `and(a, b)`
758 static Value foldAndIofAndI(arith::AndIOp op) {
759  for (bool reversePrev : {false, true}) {
760  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
761  .getDefiningOp<arith::AndIOp>();
762  if (!prev)
763  continue;
764 
765  Value other = (reversePrev ? op.getLhs() : op.getRhs());
766  if (other != prev.getLhs() && other != prev.getRhs())
767  continue;
768 
769  return prev.getResult();
770  }
771  return {};
772 }
773 
774 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
775  /// and(x, 0) -> 0
776  if (matchPattern(getRhs(), m_Zero()))
777  return getRhs();
778  /// and(x, allOnes) -> x
779  APInt intValue;
780  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
781  return getLhs();
782  /// and(x, not(x)) -> 0
783  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
784  m_ConstantInt(&intValue))) &&
785  intValue.isAllOnes())
786  return Builder(getContext()).getZeroAttr(getType());
787  /// and(not(x), x) -> 0
788  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
789  m_ConstantInt(&intValue))) &&
790  intValue.isAllOnes())
791  return Builder(getContext()).getZeroAttr(getType());
792 
793  /// and(a, and(a, b)) -> and(a, b)
794  if (Value result = foldAndIofAndI(*this))
795  return result;
796 
797  return constFoldBinaryOp<IntegerAttr>(
798  adaptor.getOperands(),
799  [](APInt a, const APInt &b) { return std::move(a) & b; });
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // OrIOp
804 //===----------------------------------------------------------------------===//
805 
806 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
807  /// or(x, 0) -> x
808  if (matchPattern(getRhs(), m_Zero()))
809  return getLhs();
810  /// or(x, <all ones>) -> <all ones>
811  if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
812  if (rhsAttr.getValue().isAllOnes())
813  return rhsAttr;
814 
815  APInt intValue;
816  /// or(x, xor(x, 1)) -> 1
817  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
818  m_ConstantInt(&intValue))) &&
819  intValue.isAllOnes())
820  return getRhs().getDefiningOp<XOrIOp>().getRhs();
821  /// or(xor(x, 1), x) -> 1
822  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
823  m_ConstantInt(&intValue))) &&
824  intValue.isAllOnes())
825  return getLhs().getDefiningOp<XOrIOp>().getRhs();
826 
827  return constFoldBinaryOp<IntegerAttr>(
828  adaptor.getOperands(),
829  [](APInt a, const APInt &b) { return std::move(a) | b; });
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // XOrIOp
834 //===----------------------------------------------------------------------===//
835 
836 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
837  /// xor(x, 0) -> x
838  if (matchPattern(getRhs(), m_Zero()))
839  return getLhs();
840  /// xor(x, x) -> 0
841  if (getLhs() == getRhs())
842  return Builder(getContext()).getZeroAttr(getType());
843  /// xor(xor(x, a), a) -> x
844  /// xor(xor(a, x), a) -> x
845  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
846  if (prev.getRhs() == getRhs())
847  return prev.getLhs();
848  if (prev.getLhs() == getRhs())
849  return prev.getRhs();
850  }
851  /// xor(a, xor(x, a)) -> x
852  /// xor(a, xor(a, x)) -> x
853  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
854  if (prev.getRhs() == getLhs())
855  return prev.getLhs();
856  if (prev.getLhs() == getLhs())
857  return prev.getRhs();
858  }
859 
860  return constFoldBinaryOp<IntegerAttr>(
861  adaptor.getOperands(),
862  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
863 }
864 
865 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
866  MLIRContext *context) {
867  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // NegFOp
872 //===----------------------------------------------------------------------===//
873 
874 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
875  /// negf(negf(x)) -> x
876  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
877  return op.getOperand();
878  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
879  [](const APFloat &a) { return -a; });
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // AddFOp
884 //===----------------------------------------------------------------------===//
885 
886 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
887  // addf(x, -0) -> x
888  if (matchPattern(getRhs(), m_NegZeroFloat()))
889  return getLhs();
890 
891  return constFoldBinaryOp<FloatAttr>(
892  adaptor.getOperands(),
893  [](const APFloat &a, const APFloat &b) { return a + b; });
894 }
895 
896 //===----------------------------------------------------------------------===//
897 // SubFOp
898 //===----------------------------------------------------------------------===//
899 
900 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
901  // subf(x, +0) -> x
902  if (matchPattern(getRhs(), m_PosZeroFloat()))
903  return getLhs();
904 
905  return constFoldBinaryOp<FloatAttr>(
906  adaptor.getOperands(),
907  [](const APFloat &a, const APFloat &b) { return a - b; });
908 }
909 
910 //===----------------------------------------------------------------------===//
911 // MaxFOp
912 //===----------------------------------------------------------------------===//
913 
914 OpFoldResult arith::MaxFOp::fold(FoldAdaptor adaptor) {
915  // maxf(x,x) -> x
916  if (getLhs() == getRhs())
917  return getRhs();
918 
919  // maxf(x, -inf) -> x
920  if (matchPattern(getRhs(), m_NegInfFloat()))
921  return getLhs();
922 
923  return constFoldBinaryOp<FloatAttr>(
924  adaptor.getOperands(),
925  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
926 }
927 
928 //===----------------------------------------------------------------------===//
929 // MaxSIOp
930 //===----------------------------------------------------------------------===//
931 
932 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
933  // maxsi(x,x) -> x
934  if (getLhs() == getRhs())
935  return getRhs();
936 
937  APInt intValue;
938  // maxsi(x,MAX_INT) -> MAX_INT
939  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
940  intValue.isMaxSignedValue())
941  return getRhs();
942 
943  // maxsi(x, MIN_INT) -> x
944  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
945  intValue.isMinSignedValue())
946  return getLhs();
947 
948  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
949  [](const APInt &a, const APInt &b) {
950  return llvm::APIntOps::smax(a, b);
951  });
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // MaxUIOp
956 //===----------------------------------------------------------------------===//
957 
958 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
959  // maxui(x,x) -> x
960  if (getLhs() == getRhs())
961  return getRhs();
962 
963  APInt intValue;
964  // maxui(x,MAX_INT) -> MAX_INT
965  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
966  return getRhs();
967 
968  // maxui(x, MIN_INT) -> x
969  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
970  return getLhs();
971 
972  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
973  [](const APInt &a, const APInt &b) {
974  return llvm::APIntOps::umax(a, b);
975  });
976 }
977 
978 //===----------------------------------------------------------------------===//
979 // MinFOp
980 //===----------------------------------------------------------------------===//
981 
982 OpFoldResult arith::MinFOp::fold(FoldAdaptor adaptor) {
983  // minf(x,x) -> x
984  if (getLhs() == getRhs())
985  return getRhs();
986 
987  // minf(x, +inf) -> x
988  if (matchPattern(getRhs(), m_PosInfFloat()))
989  return getLhs();
990 
991  return constFoldBinaryOp<FloatAttr>(
992  adaptor.getOperands(),
993  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
994 }
995 
996 //===----------------------------------------------------------------------===//
997 // MinSIOp
998 //===----------------------------------------------------------------------===//
999 
1000 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1001  // minsi(x,x) -> x
1002  if (getLhs() == getRhs())
1003  return getRhs();
1004 
1005  APInt intValue;
1006  // minsi(x,MIN_INT) -> MIN_INT
1007  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
1008  intValue.isMinSignedValue())
1009  return getRhs();
1010 
1011  // minsi(x, MAX_INT) -> x
1012  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
1013  intValue.isMaxSignedValue())
1014  return getLhs();
1015 
1016  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1017  [](const APInt &a, const APInt &b) {
1018  return llvm::APIntOps::smin(a, b);
1019  });
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // MinUIOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1027  // minui(x,x) -> x
1028  if (getLhs() == getRhs())
1029  return getRhs();
1030 
1031  APInt intValue;
1032  // minui(x,MIN_INT) -> MIN_INT
1033  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
1034  return getRhs();
1035 
1036  // minui(x, MAX_INT) -> x
1037  if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
1038  return getLhs();
1039 
1040  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1041  [](const APInt &a, const APInt &b) {
1042  return llvm::APIntOps::umin(a, b);
1043  });
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // MulFOp
1048 //===----------------------------------------------------------------------===//
1049 
1050 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1051  // mulf(x, 1) -> x
1052  if (matchPattern(getRhs(), m_OneFloat()))
1053  return getLhs();
1054 
1055  return constFoldBinaryOp<FloatAttr>(
1056  adaptor.getOperands(),
1057  [](const APFloat &a, const APFloat &b) { return a * b; });
1058 }
1059 
1060 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1061  MLIRContext *context) {
1062  patterns.add<MulFOfNegF>(context);
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // DivFOp
1067 //===----------------------------------------------------------------------===//
1068 
1069 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1070  // divf(x, 1) -> x
1071  if (matchPattern(getRhs(), m_OneFloat()))
1072  return getLhs();
1073 
1074  return constFoldBinaryOp<FloatAttr>(
1075  adaptor.getOperands(),
1076  [](const APFloat &a, const APFloat &b) { return a / b; });
1077 }
1078 
1079 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1080  MLIRContext *context) {
1081  patterns.add<DivFOfNegF>(context);
1082 }
1083 
1084 //===----------------------------------------------------------------------===//
1085 // RemFOp
1086 //===----------------------------------------------------------------------===//
1087 
1088 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1089  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1090  [](const APFloat &a, const APFloat &b) {
1091  APFloat result(a);
1092  (void)result.remainder(b);
1093  return result;
1094  });
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // Utility functions for verifying cast ops
1099 //===----------------------------------------------------------------------===//
1100 
// Packs a list of types into a single type (a pointer to std::tuple) so that a
// whole parameter pack can be handed around as one tag argument.
// NOTE(review): presumably used to pass the allowed type packs to
// getUnderlyingType; its call sites are only partially visible in this listing.
1101 template <typename... Types>
1102 using type_list = std::tuple<Types...> *;
1103 
// Checks `type` against two packs: a shaped `type` must be one of
// `ShapedTypes`, and the element type (or `type` itself when not shaped) must
// be one of `ElementTypes`. A null Type signals rejection.
// NOTE(review): the signature lines (listing lines 1108-1109) are missing from
// this listing.
1104 /// Returns a non-null type only if the provided type is one of the allowed
1105 /// types or one of the allowed shaped types of the allowed types. Returns the
1106 /// element type if a valid shaped type is provided.
1107 template <typename... ShapedTypes, typename... ElementTypes>
1110  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1111  return {};
1112 
// Non-shaped types are their own "element type" here.
1113  auto underlyingType = getElementTypeOrSelf(type);
1114  if (!llvm::isa<ElementTypes...>(underlyingType))
1115  return {};
1116 
1117  return underlyingType;
1118 }
1119 
// NOTE(review): the body (listing lines 1123-1124) is missing from this
// listing; presumably it forwards to getUnderlyingType with vector and tensor
// types as the allowed shaped containers — confirm against the full file.
1120 /// Get allowed underlying types for vectors and tensors.
1121 template <typename... ElementTypes>
1122 static Type getTypeIfLike(Type type) {
1125 }
1126 
// NOTE(review): the signature and part of the getUnderlyingType call
// (listing lines 1129, 1131-1132) are missing from this listing; per the doc
// comment below it additionally accepts memrefs as shaped containers.
1127 /// Get allowed underlying types for vectors, tensors, and memrefs.
1128 template <typename... ElementTypes>
1130  return getUnderlyingType(type,
1133 }
1134 
1135 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1136  return inputs.size() == 1 && outputs.size() == 1 &&
1137  succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1138 }
1139 
1140 //===----------------------------------------------------------------------===//
1141 // Verifiers for integer and floating point extension/truncation ops
1142 //===----------------------------------------------------------------------===//
1143 
// NOTE(review): the signature line (listing line 1146) is missing from this
// listing. Compares the element widths of operand and result (cast to
// ValType) and emits an error unless the result is strictly wider.
1144 // Extend ops can only extend to a wider type.
1145 template <typename ValType, typename Op>
1147  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1148  Type dstType = getElementTypeOrSelf(op.getType());
1149 
// Equal widths are rejected too: an extension must strictly widen.
1150  if (llvm::cast<ValType>(srcType).getWidth() >=
1151  llvm::cast<ValType>(dstType).getWidth())
1152  return op.emitError("result type ")
1153  << dstType << " must be wider than operand type " << srcType;
1154 
1155  return success();
1156 }
1157 
// NOTE(review): the signature line (listing line 1160) is missing from this
// listing. Compares the element widths of operand and result (cast to
// ValType) and emits an error unless the result is strictly narrower.
1158 // Truncate ops can only truncate to a shorter type.
1159 template <typename ValType, typename Op>
1161  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1162  Type dstType = getElementTypeOrSelf(op.getType());
1163 
// Equal widths are rejected too: a truncation must strictly narrow.
1164  if (llvm::cast<ValType>(srcType).getWidth() <=
1165  llvm::cast<ValType>(dstType).getWidth())
1166  return op.emitError("result type ")
1167  << dstType << " must be shorter than operand type " << srcType;
1168 
1169  return success();
1170 }
1171 
1172 /// Validate a cast that changes the width of a type.
1173 template <template <typename> class WidthComparator, typename... ElementTypes>
1174 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1175  if (!areValidCastInputsAndOutputs(inputs, outputs))
1176  return false;
1177 
1178  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1179  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1180  if (!srcType || !dstType)
1181  return false;
1182 
1183  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1184  srcType.getIntOrFloatBitWidth());
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // ExtUIOp
1189 //===----------------------------------------------------------------------===//
1190 
1191 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1192  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1193  getInMutable().assign(lhs.getIn());
1194  return getResult();
1195  }
1196 
1197  Type resType = getElementTypeOrSelf(getType());
1198  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1199  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1200  adaptor.getOperands(), getType(),
1201  [bitWidth](const APInt &a, bool &castStatus) {
1202  return a.zext(bitWidth);
1203  });
1204 }
1205 
1206 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1207  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1208 }
1209 
// NOTE(review): the signature line (listing line 1210) is missing from this
// listing; presumably this is arith::ExtUIOp::verify(). Delegates to
// verifyExtOp, which rejects a result integer type that is not strictly wider
// than the operand's.
1211  return verifyExtOp<IntegerType>(*this);
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // ExtSIOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1219  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1220  getInMutable().assign(lhs.getIn());
1221  return getResult();
1222  }
1223 
1224  Type resType = getElementTypeOrSelf(getType());
1225  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1226  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1227  adaptor.getOperands(), getType(),
1228  [bitWidth](const APInt &a, bool &castStatus) {
1229  return a.sext(bitWidth);
1230  });
1231 }
1232 
1233 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1234  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1235 }
1236 
1237 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1238  MLIRContext *context) {
1239  patterns.add<ExtSIOfExtUI>(context);
1240 }
1241 
// NOTE(review): the signature line (listing line 1242) is missing from this
// listing; presumably this is arith::ExtSIOp::verify(). Delegates to
// verifyExtOp, which rejects a result integer type that is not strictly wider
// than the operand's.
1243  return verifyExtOp<IntegerType>(*this);
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // ExtFOp
1248 //===----------------------------------------------------------------------===//
1249 
1250 /// Always fold extension of FP constants.
1251 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1252  auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
1253  if (!constOperand)
1254  return {};
1255 
1256  // Convert to target type via 'double'.
1257  return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
1258 }
1259 
1260 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1261  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1262 }
1263 
1264 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // TruncIOp
1268 //===----------------------------------------------------------------------===//
1269 
1270 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1271  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1272  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1273  Value src = getOperand().getDefiningOp()->getOperand(0);
1274  Type srcType = getElementTypeOrSelf(src.getType());
1275  Type dstType = getElementTypeOrSelf(getType());
1276  // trunci(zexti(a)) -> trunci(a)
1277  // trunci(sexti(a)) -> trunci(a)
1278  if (llvm::cast<IntegerType>(srcType).getWidth() >
1279  llvm::cast<IntegerType>(dstType).getWidth()) {
1280  setOperand(src);
1281  return getResult();
1282  }
1283  // trunci(zexti(a)) -> a
1284  // trunci(sexti(a)) -> a
1285  return src;
1286  }
1287 
1288  // trunci(trunci(a)) -> trunci(a))
1289  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1290  setOperand(getOperand().getDefiningOp()->getOperand(0));
1291  return getResult();
1292  }
1293 
1294  Type resType = getElementTypeOrSelf(getType());
1295  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1296  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1297  adaptor.getOperands(), getType(),
1298  [bitWidth](const APInt &a, bool &castStatus) {
1299  return a.trunc(bitWidth);
1300  });
1301 }
1302 
1303 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1304  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1305 }
1306 
1307 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1308  MLIRContext *context) {
1309  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1310  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1311  context);
1312 }
1313 
// NOTE(review): the signature line (listing line 1314) is missing from this
// listing; presumably this is arith::TruncIOp::verify(). Delegates to
// verifyTruncateOp, which rejects a result integer type that is not strictly
// narrower than the operand's.
1315  return verifyTruncateOp<IntegerType>(*this);
1316 }
1317 
1318 //===----------------------------------------------------------------------===//
1319 // TruncFOp
1320 //===----------------------------------------------------------------------===//
1321 
1322 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
1323 /// can be represented without precision loss or rounding.
1324 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1325  auto constOperand = adaptor.getIn();
1326  if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
1327  return {};
1328 
1329  // Convert to target type via 'double'.
1330  double sourceValue =
1331  llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
1332  auto targetAttr = FloatAttr::get(getType(), sourceValue);
1333 
1334  // Propagate if constant's value does not change after truncation.
1335  if (sourceValue == targetAttr.getValue().convertToDouble())
1336  return targetAttr;
1337 
1338  return {};
1339 }
1340 
1341 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1342  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1343 }
1344 
// NOTE(review): the signature line (listing line 1345) is missing from this
// listing; presumably this is arith::TruncFOp::verify(). Delegates to
// verifyTruncateOp, which rejects a result float type that is not strictly
// narrower than the operand's.
1346  return verifyTruncateOp<FloatType>(*this);
1347 }
1348 
1349 //===----------------------------------------------------------------------===//
1350 // AndIOp
1351 //===----------------------------------------------------------------------===//
1352 
1353 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1354  MLIRContext *context) {
1355  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1356 }
1357 
1358 //===----------------------------------------------------------------------===//
1359 // OrIOp
1360 //===----------------------------------------------------------------------===//
1361 
1362 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1363  MLIRContext *context) {
1364  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // Verifiers for casts between integers and floats.
1369 //===----------------------------------------------------------------------===//
1370 
1371 template <typename From, typename To>
1372 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1373  if (!areValidCastInputsAndOutputs(inputs, outputs))
1374  return false;
1375 
1376  auto srcType = getTypeIfLike<From>(inputs.front());
1377  auto dstType = getTypeIfLike<To>(outputs.back());
1378 
1379  return srcType && dstType;
1380 }
1381 
1382 //===----------------------------------------------------------------------===//
1383 // UIToFPOp
1384 //===----------------------------------------------------------------------===//
1385 
1386 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1387  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1388 }
1389 
1390 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1391  Type resEleType = getElementTypeOrSelf(getType());
1392  return constFoldCastOp<IntegerAttr, FloatAttr>(
1393  adaptor.getOperands(), getType(),
1394  [&resEleType](const APInt &a, bool &castStatus) {
1395  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1396  APFloat apf(floatTy.getFloatSemantics(),
1397  APInt::getZero(floatTy.getWidth()));
1398  apf.convertFromAPInt(a, /*IsSigned=*/false,
1399  APFloat::rmNearestTiesToEven);
1400  return apf;
1401  });
1402 }
1403 
1404 //===----------------------------------------------------------------------===//
1405 // SIToFPOp
1406 //===----------------------------------------------------------------------===//
1407 
1408 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1409  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1410 }
1411 
1412 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1413  Type resEleType = getElementTypeOrSelf(getType());
1414  return constFoldCastOp<IntegerAttr, FloatAttr>(
1415  adaptor.getOperands(), getType(),
1416  [&resEleType](const APInt &a, bool &castStatus) {
1417  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1418  APFloat apf(floatTy.getFloatSemantics(),
1419  APInt::getZero(floatTy.getWidth()));
1420  apf.convertFromAPInt(a, /*IsSigned=*/true,
1421  APFloat::rmNearestTiesToEven);
1422  return apf;
1423  });
1424 }
1425 //===----------------------------------------------------------------------===//
1426 // FPToUIOp
1427 //===----------------------------------------------------------------------===//
1428 
1429 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1430  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1431 }
1432 
1433 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1434  Type resType = getElementTypeOrSelf(getType());
1435  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1436  return constFoldCastOp<FloatAttr, IntegerAttr>(
1437  adaptor.getOperands(), getType(),
1438  [&bitWidth](const APFloat &a, bool &castStatus) {
1439  bool ignored;
1440  APSInt api(bitWidth, /*isUnsigned=*/true);
1441  castStatus = APFloat::opInvalidOp !=
1442  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1443  return api;
1444  });
1445 }
1446 
1447 //===----------------------------------------------------------------------===//
1448 // FPToSIOp
1449 //===----------------------------------------------------------------------===//
1450 
1451 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1452  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1453 }
1454 
1455 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1456  Type resType = getElementTypeOrSelf(getType());
1457  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1458  return constFoldCastOp<FloatAttr, IntegerAttr>(
1459  adaptor.getOperands(), getType(),
1460  [&bitWidth](const APFloat &a, bool &castStatus) {
1461  bool ignored;
1462  APSInt api(bitWidth, /*isUnsigned=*/false);
1463  castStatus = APFloat::opInvalidOp !=
1464  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1465  return api;
1466  });
1467 }
1468 
1469 //===----------------------------------------------------------------------===//
1470 // IndexCastOp
1471 //===----------------------------------------------------------------------===//
1472 
1473 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1474  if (!areValidCastInputsAndOutputs(inputs, outputs))
1475  return false;
1476 
1477  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1478  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1479  if (!srcType || !dstType)
1480  return false;
1481 
1482  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1483  (srcType.isSignlessInteger() && dstType.isIndex());
1484 }
1485 
1486 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1487  TypeRange outputs) {
1488  return areIndexCastCompatible(inputs, outputs);
1489 }
1490 
1491 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1492  // index_cast(constant) -> constant
1493  unsigned resultBitwidth = 64; // Default for index integer attributes.
1494  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1495  resultBitwidth = intTy.getWidth();
1496 
1497  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1498  adaptor.getOperands(), getType(),
1499  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1500  return a.sextOrTrunc(resultBitwidth);
1501  });
1502 }
1503 
1504 void arith::IndexCastOp::getCanonicalizationPatterns(
1505  RewritePatternSet &patterns, MLIRContext *context) {
1506  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1507 }
1508 
1509 //===----------------------------------------------------------------------===//
1510 // IndexCastUIOp
1511 //===----------------------------------------------------------------------===//
1512 
1513 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1514  TypeRange outputs) {
1515  return areIndexCastCompatible(inputs, outputs);
1516 }
1517 
1518 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1519  // index_castui(constant) -> constant
1520  unsigned resultBitwidth = 64; // Default for index integer attributes.
1521  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1522  resultBitwidth = intTy.getWidth();
1523 
1524  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1525  adaptor.getOperands(), getType(),
1526  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1527  return a.zextOrTrunc(resultBitwidth);
1528  });
1529 }
1530 
1531 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1532  RewritePatternSet &patterns, MLIRContext *context) {
1533  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // BitcastOp
1538 //===----------------------------------------------------------------------===//
1539 
1540 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1541  if (!areValidCastInputsAndOutputs(inputs, outputs))
1542  return false;
1543 
1544  auto srcType =
1545  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1546  auto dstType =
1547  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1548  if (!srcType || !dstType)
1549  return false;
1550 
1551  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1552 }
1553 
1554 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1555  auto resType = getType();
1556  auto operand = adaptor.getIn();
1557  if (!operand)
1558  return {};
1559 
1560  /// Bitcast dense elements.
1561  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1562  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1563  /// Other shaped types unhandled.
1564  if (llvm::isa<ShapedType>(resType))
1565  return {};
1566 
1567  /// Bitcast integer or float to integer or float.
1568  APInt bits = llvm::isa<FloatAttr>(operand)
1569  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1570  : llvm::cast<IntegerAttr>(operand).getValue();
1571 
1572  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1573  return FloatAttr::get(resType,
1574  APFloat(resFloatType.getFloatSemantics(), bits));
1575  return IntegerAttr::get(resType, bits);
1576 }
1577 
1578 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1579  MLIRContext *context) {
1580  patterns.add<BitcastOfBitcast>(context);
1581 }
1582 
1583 //===----------------------------------------------------------------------===//
1584 // CmpIOp
1585 //===----------------------------------------------------------------------===//
1586 
1587 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1588 /// comparison predicates.
1589 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1590  const APInt &lhs, const APInt &rhs) {
1591  switch (predicate) {
1592  case arith::CmpIPredicate::eq:
1593  return lhs.eq(rhs);
1594  case arith::CmpIPredicate::ne:
1595  return lhs.ne(rhs);
1596  case arith::CmpIPredicate::slt:
1597  return lhs.slt(rhs);
1598  case arith::CmpIPredicate::sle:
1599  return lhs.sle(rhs);
1600  case arith::CmpIPredicate::sgt:
1601  return lhs.sgt(rhs);
1602  case arith::CmpIPredicate::sge:
1603  return lhs.sge(rhs);
1604  case arith::CmpIPredicate::ult:
1605  return lhs.ult(rhs);
1606  case arith::CmpIPredicate::ule:
1607  return lhs.ule(rhs);
1608  case arith::CmpIPredicate::ugt:
1609  return lhs.ugt(rhs);
1610  case arith::CmpIPredicate::uge:
1611  return lhs.uge(rhs);
1612  }
1613  llvm_unreachable("unknown cmpi predicate kind");
1614 }
1615 
1616 /// Returns true if the predicate is true for two equal operands.
1617 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1618  switch (predicate) {
1619  case arith::CmpIPredicate::eq:
1620  case arith::CmpIPredicate::sle:
1621  case arith::CmpIPredicate::sge:
1622  case arith::CmpIPredicate::ule:
1623  case arith::CmpIPredicate::uge:
1624  return true;
1625  case arith::CmpIPredicate::ne:
1626  case arith::CmpIPredicate::slt:
1627  case arith::CmpIPredicate::sgt:
1628  case arith::CmpIPredicate::ult:
1629  case arith::CmpIPredicate::ugt:
1630  return false;
1631  }
1632  llvm_unreachable("unknown cmpi predicate kind");
1633 }
1634 
1635 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1636  auto boolAttr = BoolAttr::get(ctx, value);
1637  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
1638  if (!shapedType)
1639  return boolAttr;
1640  return DenseElementsAttr::get(shapedType, boolAttr);
1641 }
1642 
1643 static std::optional<int64_t> getIntegerWidth(Type t) {
1644  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1645  return intType.getWidth();
1646  }
1647  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1648  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1649  }
1650  return std::nullopt;
1651 }
1652 
1653 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1654  // cmpi(pred, x, x)
1655  if (getLhs() == getRhs()) {
1656  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1657  return getBoolAttribute(getType(), getContext(), val);
1658  }
1659 
1660  if (matchPattern(getRhs(), m_Zero())) {
1661  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1662  // extsi(%x : i1 -> iN) != 0 -> %x
1663  std::optional<int64_t> integerWidth =
1664  getIntegerWidth(extOp.getOperand().getType());
1665  if (integerWidth && integerWidth.value() == 1 &&
1666  getPredicate() == arith::CmpIPredicate::ne)
1667  return extOp.getOperand();
1668  }
1669  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1670  // extui(%x : i1 -> iN) != 0 -> %x
1671  std::optional<int64_t> integerWidth =
1672  getIntegerWidth(extOp.getOperand().getType());
1673  if (integerWidth && integerWidth.value() == 1 &&
1674  getPredicate() == arith::CmpIPredicate::ne)
1675  return extOp.getOperand();
1676  }
1677  }
1678 
1679  // Move constant to the right side.
1680  if (adaptor.getLhs() && !adaptor.getRhs()) {
1681  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1682  using Pred = CmpIPredicate;
1683  const std::pair<Pred, Pred> invPreds[] = {
1684  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1685  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1686  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1687  {Pred::ne, Pred::ne},
1688  };
1689  Pred origPred = getPredicate();
1690  for (auto pred : invPreds) {
1691  if (origPred == pred.first) {
1692  setPredicate(pred.second);
1693  Value lhs = getLhs();
1694  Value rhs = getRhs();
1695  getLhsMutable().assign(rhs);
1696  getRhsMutable().assign(lhs);
1697  return getResult();
1698  }
1699  }
1700  llvm_unreachable("unknown cmpi predicate kind");
1701  }
1702 
1703  // We are moving constants to the right side; So if lhs is constant rhs is
1704  // guaranteed to be a constant.
1705  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1706  return constFoldBinaryOp<IntegerAttr>(
1707  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1708  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1709  return APInt(1,
1710  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1711  });
1712  }
1713 
1714  return {};
1715 }
1716 
1717 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1718  MLIRContext *context) {
1719  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // CmpFOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1727 /// comparison predicates.
1728 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1729  const APFloat &lhs, const APFloat &rhs) {
1730  auto cmpResult = lhs.compare(rhs);
1731  switch (predicate) {
1732  case arith::CmpFPredicate::AlwaysFalse:
1733  return false;
1734  case arith::CmpFPredicate::OEQ:
1735  return cmpResult == APFloat::cmpEqual;
1736  case arith::CmpFPredicate::OGT:
1737  return cmpResult == APFloat::cmpGreaterThan;
1738  case arith::CmpFPredicate::OGE:
1739  return cmpResult == APFloat::cmpGreaterThan ||
1740  cmpResult == APFloat::cmpEqual;
1741  case arith::CmpFPredicate::OLT:
1742  return cmpResult == APFloat::cmpLessThan;
1743  case arith::CmpFPredicate::OLE:
1744  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1745  case arith::CmpFPredicate::ONE:
1746  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1747  case arith::CmpFPredicate::ORD:
1748  return cmpResult != APFloat::cmpUnordered;
1749  case arith::CmpFPredicate::UEQ:
1750  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1751  case arith::CmpFPredicate::UGT:
1752  return cmpResult == APFloat::cmpUnordered ||
1753  cmpResult == APFloat::cmpGreaterThan;
1754  case arith::CmpFPredicate::UGE:
1755  return cmpResult == APFloat::cmpUnordered ||
1756  cmpResult == APFloat::cmpGreaterThan ||
1757  cmpResult == APFloat::cmpEqual;
1758  case arith::CmpFPredicate::ULT:
1759  return cmpResult == APFloat::cmpUnordered ||
1760  cmpResult == APFloat::cmpLessThan;
1761  case arith::CmpFPredicate::ULE:
1762  return cmpResult == APFloat::cmpUnordered ||
1763  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1764  case arith::CmpFPredicate::UNE:
1765  return cmpResult != APFloat::cmpEqual;
1766  case arith::CmpFPredicate::UNO:
1767  return cmpResult == APFloat::cmpUnordered;
1768  case arith::CmpFPredicate::AlwaysTrue:
1769  return true;
1770  }
1771  llvm_unreachable("unknown cmpf predicate kind");
1772 }
1773 
1774 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1775  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1776  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1777 
1778  // If one operand is NaN, making them both NaN does not change the result.
1779  if (lhs && lhs.getValue().isNaN())
1780  rhs = lhs;
1781  if (rhs && rhs.getValue().isNaN())
1782  lhs = rhs;
1783 
1784  if (!lhs || !rhs)
1785  return {};
1786 
1787  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1788  return BoolAttr::get(getContext(), val);
1789 }
1790 
1791 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1792 public:
1794 
1795  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1796  bool isUnsigned) {
1797  using namespace arith;
1798  switch (pred) {
1799  case CmpFPredicate::UEQ:
1800  case CmpFPredicate::OEQ:
1801  return CmpIPredicate::eq;
1802  case CmpFPredicate::UGT:
1803  case CmpFPredicate::OGT:
1804  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1805  case CmpFPredicate::UGE:
1806  case CmpFPredicate::OGE:
1807  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1808  case CmpFPredicate::ULT:
1809  case CmpFPredicate::OLT:
1810  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1811  case CmpFPredicate::ULE:
1812  case CmpFPredicate::OLE:
1813  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1814  case CmpFPredicate::UNE:
1815  case CmpFPredicate::ONE:
1816  return CmpIPredicate::ne;
1817  default:
1818  llvm_unreachable("Unexpected predicate!");
1819  }
1820  }
1821 
1823  PatternRewriter &rewriter) const override {
1824  FloatAttr flt;
1825  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1826  return failure();
1827 
1828  const APFloat &rhs = flt.getValue();
1829 
1830  // Don't attempt to fold a nan.
1831  if (rhs.isNaN())
1832  return failure();
1833 
1834  // Get the width of the mantissa. We don't want to hack on conversions that
1835  // might lose information from the integer, e.g. "i64 -> float"
1836  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1837  int mantissaWidth = floatTy.getFPMantissaWidth();
1838  if (mantissaWidth <= 0)
1839  return failure();
1840 
1841  bool isUnsigned;
1842  Value intVal;
1843 
1844  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1845  isUnsigned = false;
1846  intVal = si.getIn();
1847  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1848  isUnsigned = true;
1849  intVal = ui.getIn();
1850  } else {
1851  return failure();
1852  }
1853 
1854  // Check to see that the input is converted from an integer type that is
1855  // small enough that preserves all bits.
1856  auto intTy = llvm::cast<IntegerType>(intVal.getType());
1857  auto intWidth = intTy.getWidth();
1858 
1859  // Number of bits representing values, as opposed to the sign
1860  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1861 
1862  // Following test does NOT adjust intWidth downwards for signed inputs,
1863  // because the most negative value still requires all the mantissa bits
1864  // to distinguish it from one less than that value.
1865  if ((int)intWidth > mantissaWidth) {
1866  // Conversion would lose accuracy. Check if loss can impact comparison.
1867  int exponent = ilogb(rhs);
1868  if (exponent == APFloat::IEK_Inf) {
1869  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1870  if (maxExponent < (int)valueBits) {
1871  // Conversion could create infinity.
1872  return failure();
1873  }
1874  } else {
1875  // Note that if rhs is zero or NaN, then Exp is negative
1876  // and first condition is trivially false.
1877  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1878  // Conversion could affect comparison.
1879  return failure();
1880  }
1881  }
1882  }
1883 
1884  // Convert to equivalent cmpi predicate
1885  CmpIPredicate pred;
1886  switch (op.getPredicate()) {
1887  case CmpFPredicate::ORD:
1888  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1889  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1890  /*width=*/1);
1891  return success();
1892  case CmpFPredicate::UNO:
1893  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1894  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1895  /*width=*/1);
1896  return success();
1897  default:
1898  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1899  break;
1900  }
1901 
1902  if (!isUnsigned) {
1903  // If the rhs value is > SignedMax, fold the comparison. This handles
1904  // +INF and large values.
1905  APFloat signedMax(rhs.getSemantics());
1906  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1907  APFloat::rmNearestTiesToEven);
1908  if (signedMax < rhs) { // smax < 13123.0
1909  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1910  pred == CmpIPredicate::sle)
1911  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1912  /*width=*/1);
1913  else
1914  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1915  /*width=*/1);
1916  return success();
1917  }
1918  } else {
1919  // If the rhs value is > UnsignedMax, fold the comparison. This handles
1920  // +INF and large values.
1921  APFloat unsignedMax(rhs.getSemantics());
1922  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1923  APFloat::rmNearestTiesToEven);
1924  if (unsignedMax < rhs) { // umax < 13123.0
1925  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1926  pred == CmpIPredicate::ule)
1927  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1928  /*width=*/1);
1929  else
1930  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1931  /*width=*/1);
1932  return success();
1933  }
1934  }
1935 
1936  if (!isUnsigned) {
1937  // See if the rhs value is < SignedMin.
1938  APFloat signedMin(rhs.getSemantics());
1939  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1940  APFloat::rmNearestTiesToEven);
1941  if (signedMin > rhs) { // smin > 12312.0
1942  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1943  pred == CmpIPredicate::sge)
1944  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1945  /*width=*/1);
1946  else
1947  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1948  /*width=*/1);
1949  return success();
1950  }
1951  } else {
1952  // See if the rhs value is < UnsignedMin.
1953  APFloat unsignedMin(rhs.getSemantics());
1954  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1955  APFloat::rmNearestTiesToEven);
1956  if (unsignedMin > rhs) { // umin > 12312.0
1957  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1958  pred == CmpIPredicate::uge)
1959  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1960  /*width=*/1);
1961  else
1962  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1963  /*width=*/1);
1964  return success();
1965  }
1966  }
1967 
1968  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1969  // [0, UMAX], but it may still be fractional. See if it is fractional by
1970  // casting the FP value to the integer value and back, checking for
1971  // equality. Don't do this for zero, because -0.0 is not fractional.
1972  bool ignored;
1973  APSInt rhsInt(intWidth, isUnsigned);
1974  if (APFloat::opInvalidOp ==
1975  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1976  // Undefined behavior invoked - the destination type can't represent
1977  // the input constant.
1978  return failure();
1979  }
1980 
1981  if (!rhs.isZero()) {
1982  APFloat apf(floatTy.getFloatSemantics(),
1983  APInt::getZero(floatTy.getWidth()));
1984  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1985 
1986  bool equal = apf == rhs;
1987  if (!equal) {
1988  // If we had a comparison against a fractional value, we have to adjust
1989  // the compare predicate and sometimes the value. rhsInt is rounded
1990  // towards zero at this point.
1991  switch (pred) {
1992  case CmpIPredicate::ne: // (float)int != 4.4 --> true
1993  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1994  /*width=*/1);
1995  return success();
1996  case CmpIPredicate::eq: // (float)int == 4.4 --> false
1997  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1998  /*width=*/1);
1999  return success();
2000  case CmpIPredicate::ule:
2001  // (float)int <= 4.4 --> int <= 4
2002  // (float)int <= -4.4 --> false
2003  if (rhs.isNegative()) {
2004  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2005  /*width=*/1);
2006  return success();
2007  }
2008  break;
2009  case CmpIPredicate::sle:
2010  // (float)int <= 4.4 --> int <= 4
2011  // (float)int <= -4.4 --> int < -4
2012  if (rhs.isNegative())
2013  pred = CmpIPredicate::slt;
2014  break;
2015  case CmpIPredicate::ult:
2016  // (float)int < -4.4 --> false
2017  // (float)int < 4.4 --> int <= 4
2018  if (rhs.isNegative()) {
2019  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2020  /*width=*/1);
2021  return success();
2022  }
2023  pred = CmpIPredicate::ule;
2024  break;
2025  case CmpIPredicate::slt:
2026  // (float)int < -4.4 --> int < -4
2027  // (float)int < 4.4 --> int <= 4
2028  if (!rhs.isNegative())
2029  pred = CmpIPredicate::sle;
2030  break;
2031  case CmpIPredicate::ugt:
2032  // (float)int > 4.4 --> int > 4
2033  // (float)int > -4.4 --> true
2034  if (rhs.isNegative()) {
2035  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2036  /*width=*/1);
2037  return success();
2038  }
2039  break;
2040  case CmpIPredicate::sgt:
2041  // (float)int > 4.4 --> int > 4
2042  // (float)int > -4.4 --> int >= -4
2043  if (rhs.isNegative())
2044  pred = CmpIPredicate::sge;
2045  break;
2046  case CmpIPredicate::uge:
2047  // (float)int >= -4.4 --> true
2048  // (float)int >= 4.4 --> int > 4
2049  if (rhs.isNegative()) {
2050  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2051  /*width=*/1);
2052  return success();
2053  }
2054  pred = CmpIPredicate::ugt;
2055  break;
2056  case CmpIPredicate::sge:
2057  // (float)int >= -4.4 --> int >= -4
2058  // (float)int >= 4.4 --> int > 4
2059  if (!rhs.isNegative())
2060  pred = CmpIPredicate::sgt;
2061  break;
2062  }
2063  }
2064  }
2065 
2066  // Lower this FP comparison into an appropriate integer version of the
2067  // comparison.
2068  rewriter.replaceOpWithNewOp<CmpIOp>(
2069  op, pred, intVal,
2070  rewriter.create<ConstantOp>(
2071  op.getLoc(), intVal.getType(),
2072  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2073  return success();
2074  }
2075 };
2076 
2077 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2078  MLIRContext *context) {
2079  patterns.insert<CmpFIntToFPConst>(context);
2080 }
2081 
2082 //===----------------------------------------------------------------------===//
2083 // SelectOp
2084 //===----------------------------------------------------------------------===//
2085 
2086 // Transforms a select of a boolean to arithmetic operations
2087 //
2088 // arith.select %arg, %x, %y : i1
2089 //
2090 // becomes
2091 //
2092 // and(%arg, %x) or and(!%arg, %y)
2093 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
2095 
2096  LogicalResult matchAndRewrite(arith::SelectOp op,
2097  PatternRewriter &rewriter) const override {
2098  if (!op.getType().isInteger(1))
2099  return failure();
2100 
2101  Value falseConstant =
2102  rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
2103  Value notCondition = rewriter.create<arith::XOrIOp>(
2104  op.getLoc(), op.getCondition(), falseConstant);
2105 
2106  Value trueVal = rewriter.create<arith::AndIOp>(
2107  op.getLoc(), op.getCondition(), op.getTrueValue());
2108  Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
2109  op.getFalseValue());
2110  rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
2111  return success();
2112  }
2113 };
2114 
2115 // select %arg, %c1, %c0 => extui %arg
2116 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2118 
2119  LogicalResult matchAndRewrite(arith::SelectOp op,
2120  PatternRewriter &rewriter) const override {
2121  // Cannot extui i1 to i1, or i1 to f32
2122  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2123  return failure();
2124 
2125  // select %x, c1, %c0 => extui %arg
2126  if (matchPattern(op.getTrueValue(), m_One()) &&
2127  matchPattern(op.getFalseValue(), m_Zero())) {
2128  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2129  op.getCondition());
2130  return success();
2131  }
2132 
2133  // select %x, c0, %c1 => extui (xor %arg, true)
2134  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2135  matchPattern(op.getFalseValue(), m_One())) {
2136  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2137  op, op.getType(),
2138  rewriter.create<arith::XOrIOp>(
2139  op.getLoc(), op.getCondition(),
2140  rewriter.create<arith::ConstantIntOp>(
2141  op.getLoc(), 1, op.getCondition().getType())));
2142  return success();
2143  }
2144 
2145  return failure();
2146  }
2147 };
2148 
2149 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2150  MLIRContext *context) {
2151  results.add<SelectI1Simplify, SelectToExtUI>(context);
2152 }
2153 
2154 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2155  Value trueVal = getTrueValue();
2156  Value falseVal = getFalseValue();
2157  if (trueVal == falseVal)
2158  return trueVal;
2159 
2160  Value condition = getCondition();
2161 
2162  // select true, %0, %1 => %0
2163  if (matchPattern(condition, m_One()))
2164  return trueVal;
2165 
2166  // select false, %0, %1 => %1
2167  if (matchPattern(condition, m_Zero()))
2168  return falseVal;
2169 
2170  // select %x, true, false => %x
2171  if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
2172  matchPattern(getFalseValue(), m_Zero()))
2173  return condition;
2174 
2175  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2176  auto pred = cmp.getPredicate();
2177  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2178  auto cmpLhs = cmp.getLhs();
2179  auto cmpRhs = cmp.getRhs();
2180 
2181  // %0 = arith.cmpi eq, %arg0, %arg1
2182  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2183 
2184  // %0 = arith.cmpi ne, %arg0, %arg1
2185  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2186 
2187  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2188  (cmpRhs == trueVal && cmpLhs == falseVal))
2189  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2190  }
2191  }
2192 
2193  // Constant-fold constant operands over non-splat constant condition.
2194  // select %cst_vec, %cst0, %cst1 => %cst2
2195  if (auto cond =
2196  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2197  if (auto lhs =
2198  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2199  if (auto rhs =
2200  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2201  SmallVector<Attribute> results;
2202  results.reserve(static_cast<size_t>(cond.getNumElements()));
2203  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2204  cond.value_end<BoolAttr>());
2205  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2206  lhs.value_end<Attribute>());
2207  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2208  rhs.value_end<Attribute>());
2209 
2210  for (auto [condVal, lhsVal, rhsVal] :
2211  llvm::zip_equal(condVals, lhsVals, rhsVals))
2212  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2213 
2214  return DenseElementsAttr::get(lhs.getType(), results);
2215  }
2216  }
2217  }
2218 
2219  return nullptr;
2220 }
2221 
2222 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2223  Type conditionType, resultType;
2225  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2226  parser.parseOptionalAttrDict(result.attributes) ||
2227  parser.parseColonType(resultType))
2228  return failure();
2229 
2230  // Check for the explicit condition type if this is a masked tensor or vector.
2231  if (succeeded(parser.parseOptionalComma())) {
2232  conditionType = resultType;
2233  if (parser.parseType(resultType))
2234  return failure();
2235  } else {
2236  conditionType = parser.getBuilder().getI1Type();
2237  }
2238 
2239  result.addTypes(resultType);
2240  return parser.resolveOperands(operands,
2241  {conditionType, resultType, resultType},
2242  parser.getNameLoc(), result.operands);
2243 }
2244 
2246  p << " " << getOperands();
2247  p.printOptionalAttrDict((*this)->getAttrs());
2248  p << " : ";
2249  if (ShapedType condType =
2250  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2251  p << condType << ", ";
2252  p << getType();
2253 }
2254 
2256  Type conditionType = getCondition().getType();
2257  if (conditionType.isSignlessInteger(1))
2258  return success();
2259 
2260  // If the result type is a vector or tensor, the type can be a mask with the
2261  // same elements.
2262  Type resultType = getType();
2263  if (!llvm::isa<TensorType, VectorType>(resultType))
2264  return emitOpError() << "expected condition to be a signless i1, but got "
2265  << conditionType;
2266  Type shapedConditionType = getI1SameShape(resultType);
2267  if (conditionType != shapedConditionType) {
2268  return emitOpError() << "expected condition type to have the same shape "
2269  "as the result type, expected "
2270  << shapedConditionType << ", but got "
2271  << conditionType;
2272  }
2273  return success();
2274 }
2275 //===----------------------------------------------------------------------===//
2276 // ShLIOp
2277 //===----------------------------------------------------------------------===//
2278 
2279 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2280  // shli(x, 0) -> x
2281  if (matchPattern(getRhs(), m_Zero()))
2282  return getLhs();
2283  // Don't fold if shifting more than the bit width.
2284  bool bounded = false;
2285  auto result = constFoldBinaryOp<IntegerAttr>(
2286  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2287  bounded = b.ule(b.getBitWidth());
2288  return a.shl(b);
2289  });
2290  return bounded ? result : Attribute();
2291 }
2292 
2293 //===----------------------------------------------------------------------===//
2294 // ShRUIOp
2295 //===----------------------------------------------------------------------===//
2296 
2297 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2298  // shrui(x, 0) -> x
2299  if (matchPattern(getRhs(), m_Zero()))
2300  return getLhs();
2301  // Don't fold if shifting more than the bit width.
2302  bool bounded = false;
2303  auto result = constFoldBinaryOp<IntegerAttr>(
2304  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2305  bounded = b.ule(b.getBitWidth());
2306  return a.lshr(b);
2307  });
2308  return bounded ? result : Attribute();
2309 }
2310 
2311 //===----------------------------------------------------------------------===//
2312 // ShRSIOp
2313 //===----------------------------------------------------------------------===//
2314 
2315 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2316  // shrsi(x, 0) -> x
2317  if (matchPattern(getRhs(), m_Zero()))
2318  return getLhs();
2319  // Don't fold if shifting more than the bit width.
2320  bool bounded = false;
2321  auto result = constFoldBinaryOp<IntegerAttr>(
2322  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2323  bounded = b.ule(b.getBitWidth());
2324  return a.ashr(b);
2325  });
2326  return bounded ? result : Attribute();
2327 }
2328 
2329 //===----------------------------------------------------------------------===//
2330 // Atomic Enum
2331 //===----------------------------------------------------------------------===//
2332 
2333 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2334 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2335  OpBuilder &builder, Location loc) {
2336  switch (kind) {
2337  case AtomicRMWKind::maxf:
2338  return builder.getFloatAttr(
2339  resultType,
2340  APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
2341  /*Negative=*/true));
2342  case AtomicRMWKind::addf:
2343  case AtomicRMWKind::addi:
2344  case AtomicRMWKind::maxu:
2345  case AtomicRMWKind::ori:
2346  return builder.getZeroAttr(resultType);
2347  case AtomicRMWKind::andi:
2348  return builder.getIntegerAttr(
2349  resultType,
2350  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2351  case AtomicRMWKind::maxs:
2352  return builder.getIntegerAttr(
2353  resultType, APInt::getSignedMinValue(
2354  llvm::cast<IntegerType>(resultType).getWidth()));
2355  case AtomicRMWKind::minf:
2356  return builder.getFloatAttr(
2357  resultType,
2358  APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
2359  /*Negative=*/false));
2360  case AtomicRMWKind::mins:
2361  return builder.getIntegerAttr(
2362  resultType, APInt::getSignedMaxValue(
2363  llvm::cast<IntegerType>(resultType).getWidth()));
2364  case AtomicRMWKind::minu:
2365  return builder.getIntegerAttr(
2366  resultType,
2367  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2368  case AtomicRMWKind::muli:
2369  return builder.getIntegerAttr(resultType, 1);
2370  case AtomicRMWKind::mulf:
2371  return builder.getFloatAttr(resultType, 1);
2372  // TODO: Add remaining reduction operations.
2373  default:
2374  (void)emitOptionalError(loc, "Reduction operation type not supported");
2375  break;
2376  }
2377  return nullptr;
2378 }
2379 
2380 /// Returns the identity value associated with an AtomicRMWKind op.
2381 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2382  OpBuilder &builder, Location loc) {
2383  auto attr = getIdentityValueAttr(op, resultType, builder, loc);
2384  return builder.create<arith::ConstantOp>(loc, attr);
2385 }
2386 
2387 /// Return the value obtained by applying the reduction operation kind
2388 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2389 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2390  Location loc, Value lhs, Value rhs) {
2391  switch (op) {
2392  case AtomicRMWKind::addf:
2393  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2394  case AtomicRMWKind::addi:
2395  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2396  case AtomicRMWKind::mulf:
2397  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2398  case AtomicRMWKind::muli:
2399  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2400  case AtomicRMWKind::maxf:
2401  return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2402  case AtomicRMWKind::minf:
2403  return builder.create<arith::MinFOp>(loc, lhs, rhs);
2404  case AtomicRMWKind::maxs:
2405  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2406  case AtomicRMWKind::mins:
2407  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2408  case AtomicRMWKind::maxu:
2409  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2410  case AtomicRMWKind::minu:
2411  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2412  case AtomicRMWKind::ori:
2413  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2414  case AtomicRMWKind::andi:
2415  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2416  // TODO: Add remaining reduction operations.
2417  default:
2418  (void)emitOptionalError(loc, "Reduction operation type not supported");
2419  break;
2420  }
2421  return nullptr;
2422 }
2423 
2424 //===----------------------------------------------------------------------===//
2425 // TableGen'd op method definitions
2426 //===----------------------------------------------------------------------===//
2427 
2428 #define GET_OP_CLASSES
2429 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2430 
2431 //===----------------------------------------------------------------------===//
2432 // TableGen'd enum attribute definitions
2433 //===----------------------------------------------------------------------===//
2434 
2435 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1174
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1122
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1617
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1108
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:557
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:94
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:36
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1473
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1372
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1146
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:43
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:82
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:758
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1129
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:118
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1643
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1135
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
Definition: ArithOps.cpp:1635
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1102
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:287
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1160
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:699
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1822
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1795
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:122
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:225
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:248
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:85
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:318
IntegerType getI1Type()
Definition: Builders.cpp:71
IndexType getIndexType()
Definition: Builders.cpp:69
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:202
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:433
This class represents a single result from folding an operation.
Definition: OpDefinition.h:265
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:266
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:700
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:514
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (of any width; a separate overload checks for a specific width).
Definition: Types.cpp:64
bool isIndex() const
Definition: Types.cpp:55
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:220
static bool classof(Operation *op)
Definition: ArithOps.cpp:226
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:232
static bool classof(Operation *op)
Definition: ArithOps.cpp:238
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:53
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:199
static bool classof(Operation *op)
Definition: ArithOps.cpp:214
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2381
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1589
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2334
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Definition: ArithOps.cpp:2389
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:51
auto m_Val(Value v)
Definition: Matchers.h:413
MPInt ceil(const Fraction &f)
Definition: Fraction.h:70
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:378
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair of corresponding entries has a compatible shape.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:322
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:355
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:366
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to `bind_value`.
Definition: Matchers.h:401
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:361
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:348
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:327
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:287
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:340
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:374
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float one (1.0).
Definition: Matchers.h:332
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2096
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2119
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes