MLIR  19.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 /// Invert an integer comparison predicate.
68 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
69  switch (pred) {
70  case arith::CmpIPredicate::eq:
71  return arith::CmpIPredicate::ne;
72  case arith::CmpIPredicate::ne:
73  return arith::CmpIPredicate::eq;
74  case arith::CmpIPredicate::slt:
75  return arith::CmpIPredicate::sge;
76  case arith::CmpIPredicate::sle:
77  return arith::CmpIPredicate::sgt;
78  case arith::CmpIPredicate::sgt:
79  return arith::CmpIPredicate::sle;
80  case arith::CmpIPredicate::sge:
81  return arith::CmpIPredicate::slt;
82  case arith::CmpIPredicate::ult:
83  return arith::CmpIPredicate::uge;
84  case arith::CmpIPredicate::ule:
85  return arith::CmpIPredicate::ugt;
86  case arith::CmpIPredicate::ugt:
87  return arith::CmpIPredicate::ule;
88  case arith::CmpIPredicate::uge:
89  return arith::CmpIPredicate::ult;
90  }
91  llvm_unreachable("unknown cmpi predicate kind");
92 }
93 
94 /// Equivalent to
95 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
96 ///
97 /// Not possible to implement as chain of calls as this would introduce a
98 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
99 /// on the LLVM dialect and on translation to LLVM.
100 static llvm::RoundingMode
101 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
102  switch (roundingMode) {
103  case RoundingMode::downward:
104  return llvm::RoundingMode::TowardNegative;
105  case RoundingMode::to_nearest_away:
106  return llvm::RoundingMode::NearestTiesToAway;
107  case RoundingMode::to_nearest_even:
108  return llvm::RoundingMode::NearestTiesToEven;
109  case RoundingMode::toward_zero:
110  return llvm::RoundingMode::TowardZero;
111  case RoundingMode::upward:
112  return llvm::RoundingMode::TowardPositive;
113  }
114  llvm_unreachable("Unhandled rounding mode");
115 }
116 
117 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
118  return arith::CmpIPredicateAttr::get(pred.getContext(),
119  invertPredicate(pred.getValue()));
120 }
121 
122 static int64_t getScalarOrElementWidth(Type type) {
123  Type elemTy = getElementTypeOrSelf(type);
124  if (elemTy.isIntOrFloat())
125  return elemTy.getIntOrFloatBitWidth();
126 
127  return -1;
128 }
129 
130 static int64_t getScalarOrElementWidth(Value value) {
131  return getScalarOrElementWidth(value.getType());
132 }
133 
135  APInt value;
136  if (matchPattern(attr, m_ConstantInt(&value)))
137  return value;
138 
139  return failure();
140 }
141 
142 static Attribute getBoolAttribute(Type type, bool value) {
143  auto boolAttr = BoolAttr::get(type.getContext(), value);
144  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
145  if (!shapedType)
146  return boolAttr;
147  return DenseElementsAttr::get(shapedType, boolAttr);
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // TableGen'd canonicalization patterns
152 //===----------------------------------------------------------------------===//
153 
154 namespace {
155 #include "ArithCanonicalization.inc"
156 } // namespace
157 
158 //===----------------------------------------------------------------------===//
159 // Common helpers
160 //===----------------------------------------------------------------------===//
161 
162 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
163 static Type getI1SameShape(Type type) {
164  auto i1Type = IntegerType::get(type.getContext(), 1);
165  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
166  return shapedType.cloneWith(std::nullopt, i1Type);
167  if (llvm::isa<UnrankedTensorType>(type))
168  return UnrankedTensorType::get(i1Type);
169  return i1Type;
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // ConstantOp
174 //===----------------------------------------------------------------------===//
175 
176 void arith::ConstantOp::getAsmResultNames(
177  function_ref<void(Value, StringRef)> setNameFn) {
178  auto type = getType();
179  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
180  auto intType = llvm::dyn_cast<IntegerType>(type);
181 
182  // Sugar i1 constants with 'true' and 'false'.
183  if (intType && intType.getWidth() == 1)
184  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
185 
186  // Otherwise, build a complex name with the value and type.
187  SmallString<32> specialNameBuffer;
188  llvm::raw_svector_ostream specialName(specialNameBuffer);
189  specialName << 'c' << intCst.getValue();
190  if (intType)
191  specialName << '_' << type;
192  setNameFn(getResult(), specialName.str());
193  } else {
194  setNameFn(getResult(), "cst");
195  }
196 }
197 
198 /// TODO: disallow arith.constant to return anything other than signless integer
199 /// or float like.
201  auto type = getType();
202  // The value's type must match the return type.
203  if (getValue().getType() != type) {
204  return emitOpError() << "value type " << getValue().getType()
205  << " must match return type: " << type;
206  }
207  // Integer values must be signless.
208  if (llvm::isa<IntegerType>(type) &&
209  !llvm::cast<IntegerType>(type).isSignless())
210  return emitOpError("integer return type must be signless");
211  // Any float or elements attribute are acceptable.
212  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
213  return emitOpError(
214  "value must be an integer, float, or elements attribute");
215  }
216 
217  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
218  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
219  // However, this would most likely require updating the lowerings to LLVM.
220  auto vecType = dyn_cast<VectorType>(type);
221  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
222  return emitOpError(
223  "intializing scalable vectors with elements attribute is not supported"
224  " unless it's a vector splat");
225  return success();
226 }
227 
228 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
229  // The value's type must be the same as the provided type.
230  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
231  if (!typedAttr || typedAttr.getType() != type)
232  return false;
233  // Integer values must be signless.
234  if (llvm::isa<IntegerType>(type) &&
235  !llvm::cast<IntegerType>(type).isSignless())
236  return false;
237  // Integer, float, and element attributes are buildable.
238  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
239 }
240 
241 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
242  Type type, Location loc) {
243  if (isBuildableWith(value, type))
244  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
245  return nullptr;
246 }
247 
248 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
249 
251  int64_t value, unsigned width) {
252  auto type = builder.getIntegerType(width);
253  arith::ConstantOp::build(builder, result, type,
254  builder.getIntegerAttr(type, value));
255 }
256 
258  int64_t value, Type type) {
259  assert(type.isSignlessInteger() &&
260  "ConstantIntOp can only have signless integer type values");
261  arith::ConstantOp::build(builder, result, type,
262  builder.getIntegerAttr(type, value));
263 }
264 
266  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
267  return constOp.getType().isSignlessInteger();
268  return false;
269 }
270 
272  const APFloat &value, FloatType type) {
273  arith::ConstantOp::build(builder, result, type,
274  builder.getFloatAttr(type, value));
275 }
276 
278  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
279  return llvm::isa<FloatType>(constOp.getType());
280  return false;
281 }
282 
284  int64_t value) {
285  arith::ConstantOp::build(builder, result, builder.getIndexType(),
286  builder.getIndexAttr(value));
287 }
288 
290  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
291  return constOp.getType().isIndex();
292  return false;
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // AddIOp
297 //===----------------------------------------------------------------------===//
298 
299 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
300  // addi(x, 0) -> x
301  if (matchPattern(adaptor.getRhs(), m_Zero()))
302  return getLhs();
303 
304  // addi(subi(a, b), b) -> a
305  if (auto sub = getLhs().getDefiningOp<SubIOp>())
306  if (getRhs() == sub.getRhs())
307  return sub.getLhs();
308 
309  // addi(b, subi(a, b)) -> a
310  if (auto sub = getRhs().getDefiningOp<SubIOp>())
311  if (getLhs() == sub.getRhs())
312  return sub.getLhs();
313 
314  return constFoldBinaryOp<IntegerAttr>(
315  adaptor.getOperands(),
316  [](APInt a, const APInt &b) { return std::move(a) + b; });
317 }
318 
319 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
320  MLIRContext *context) {
321  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
322  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // AddUIExtendedOp
327 //===----------------------------------------------------------------------===//
328 
329 std::optional<SmallVector<int64_t, 4>>
330 arith::AddUIExtendedOp::getShapeForUnroll() {
331  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
332  return llvm::to_vector<4>(vt.getShape());
333  return std::nullopt;
334 }
335 
336 // Returns the overflow bit, assuming that `sum` is the result of unsigned
337 // addition of `operand` and another number.
338 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
339  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
340 }
341 
343 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
345  Type overflowTy = getOverflow().getType();
346  // addui_extended(x, 0) -> x, false
347  if (matchPattern(getRhs(), m_Zero())) {
348  Builder builder(getContext());
349  auto falseValue = builder.getZeroAttr(overflowTy);
350 
351  results.push_back(getLhs());
352  results.push_back(falseValue);
353  return success();
354  }
355 
356  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
357  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
358  // operands. If that succeeds, calculate the overflow bit based on the sum
359  // and the first (constant) operand, `lhs`.
360  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
361  adaptor.getOperands(),
362  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
363  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
364  ArrayRef({sumAttr, adaptor.getLhs()}),
365  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
367  if (!overflowAttr)
368  return failure();
369 
370  results.push_back(sumAttr);
371  results.push_back(overflowAttr);
372  return success();
373  }
374 
375  return failure();
376 }
377 
378 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
379  RewritePatternSet &patterns, MLIRContext *context) {
380  patterns.add<AddUIExtendedToAddI>(context);
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // SubIOp
385 //===----------------------------------------------------------------------===//
386 
387 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
388  // subi(x,x) -> 0
389  if (getOperand(0) == getOperand(1))
390  return Builder(getContext()).getZeroAttr(getType());
391  // subi(x,0) -> x
392  if (matchPattern(adaptor.getRhs(), m_Zero()))
393  return getLhs();
394 
395  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
396  // subi(addi(a, b), b) -> a
397  if (getRhs() == add.getRhs())
398  return add.getLhs();
399  // subi(addi(a, b), a) -> b
400  if (getRhs() == add.getLhs())
401  return add.getRhs();
402  }
403 
404  return constFoldBinaryOp<IntegerAttr>(
405  adaptor.getOperands(),
406  [](APInt a, const APInt &b) { return std::move(a) - b; });
407 }
408 
409 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
410  MLIRContext *context) {
411  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
412  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
413  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // MulIOp
418 //===----------------------------------------------------------------------===//
419 
420 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
421  // muli(x, 0) -> 0
422  if (matchPattern(adaptor.getRhs(), m_Zero()))
423  return getRhs();
424  // muli(x, 1) -> x
425  if (matchPattern(adaptor.getRhs(), m_One()))
426  return getLhs();
427  // TODO: Handle the overflow case.
428 
429  // default folder
430  return constFoldBinaryOp<IntegerAttr>(
431  adaptor.getOperands(),
432  [](const APInt &a, const APInt &b) { return a * b; });
433 }
434 
435 void arith::MulIOp::getAsmResultNames(
436  function_ref<void(Value, StringRef)> setNameFn) {
437  if (!isa<IndexType>(getType()))
438  return;
439 
440  // Match vector.vscale by name to avoid depending on the vector dialect (which
441  // is a circular dependency).
442  auto isVscale = [](Operation *op) {
443  return op && op->getName().getStringRef() == "vector.vscale";
444  };
445 
446  IntegerAttr baseValue;
447  auto isVscaleExpr = [&](Value a, Value b) {
448  return matchPattern(a, m_Constant(&baseValue)) &&
449  isVscale(b.getDefiningOp());
450  };
451 
452  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
453  return;
454 
455  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
456  SmallString<32> specialNameBuffer;
457  llvm::raw_svector_ostream specialName(specialNameBuffer);
458  specialName << 'c' << baseValue.getInt() << "_vscale";
459  setNameFn(getResult(), specialName.str());
460 }
461 
462 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
463  MLIRContext *context) {
464  patterns.add<MulIMulIConstant>(context);
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // MulSIExtendedOp
469 //===----------------------------------------------------------------------===//
470 
471 std::optional<SmallVector<int64_t, 4>>
472 arith::MulSIExtendedOp::getShapeForUnroll() {
473  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
474  return llvm::to_vector<4>(vt.getShape());
475  return std::nullopt;
476 }
477 
479 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
481  // mulsi_extended(x, 0) -> 0, 0
482  if (matchPattern(adaptor.getRhs(), m_Zero())) {
483  Attribute zero = adaptor.getRhs();
484  results.push_back(zero);
485  results.push_back(zero);
486  return success();
487  }
488 
489  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
490  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
491  adaptor.getOperands(),
492  [](const APInt &a, const APInt &b) { return a * b; })) {
493  // Invoke the constant fold helper again to calculate the 'high' result.
494  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
495  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
496  return llvm::APIntOps::mulhs(a, b);
497  });
498  assert(highAttr && "Unexpected constant-folding failure");
499 
500  results.push_back(lowAttr);
501  results.push_back(highAttr);
502  return success();
503  }
504 
505  return failure();
506 }
507 
508 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
509  RewritePatternSet &patterns, MLIRContext *context) {
510  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // MulUIExtendedOp
515 //===----------------------------------------------------------------------===//
516 
517 std::optional<SmallVector<int64_t, 4>>
518 arith::MulUIExtendedOp::getShapeForUnroll() {
519  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
520  return llvm::to_vector<4>(vt.getShape());
521  return std::nullopt;
522 }
523 
525 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
527  // mului_extended(x, 0) -> 0, 0
528  if (matchPattern(adaptor.getRhs(), m_Zero())) {
529  Attribute zero = adaptor.getRhs();
530  results.push_back(zero);
531  results.push_back(zero);
532  return success();
533  }
534 
535  // mului_extended(x, 1) -> x, 0
536  if (matchPattern(adaptor.getRhs(), m_One())) {
537  Builder builder(getContext());
538  Attribute zero = builder.getZeroAttr(getLhs().getType());
539  results.push_back(getLhs());
540  results.push_back(zero);
541  return success();
542  }
543 
544  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
545  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
546  adaptor.getOperands(),
547  [](const APInt &a, const APInt &b) { return a * b; })) {
548  // Invoke the constant fold helper again to calculate the 'high' result.
549  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
550  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
551  return llvm::APIntOps::mulhu(a, b);
552  });
553  assert(highAttr && "Unexpected constant-folding failure");
554 
555  results.push_back(lowAttr);
556  results.push_back(highAttr);
557  return success();
558  }
559 
560  return failure();
561 }
562 
563 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
564  RewritePatternSet &patterns, MLIRContext *context) {
565  patterns.add<MulUIExtendedToMulI>(context);
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // DivUIOp
570 //===----------------------------------------------------------------------===//
571 
572 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
573  // divui (x, 1) -> x.
574  if (matchPattern(adaptor.getRhs(), m_One()))
575  return getLhs();
576 
577  // Don't fold if it would require a division by zero.
578  bool div0 = false;
579  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
580  [&](APInt a, const APInt &b) {
581  if (div0 || !b) {
582  div0 = true;
583  return a;
584  }
585  return a.udiv(b);
586  });
587 
588  return div0 ? Attribute() : result;
589 }
590 
591 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
592  // X / 0 => UB
593  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // DivSIOp
599 //===----------------------------------------------------------------------===//
600 
601 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
602  // divsi (x, 1) -> x.
603  if (matchPattern(adaptor.getRhs(), m_One()))
604  return getLhs();
605 
606  // Don't fold if it would overflow or if it requires a division by zero.
607  bool overflowOrDiv0 = false;
608  auto result = constFoldBinaryOp<IntegerAttr>(
609  adaptor.getOperands(), [&](APInt a, const APInt &b) {
610  if (overflowOrDiv0 || !b) {
611  overflowOrDiv0 = true;
612  return a;
613  }
614  return a.sdiv_ov(b, overflowOrDiv0);
615  });
616 
617  return overflowOrDiv0 ? Attribute() : result;
618 }
619 
620 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
621  bool mayHaveUB = true;
622 
623  APInt constRHS;
624  // X / 0 => UB
625  // INT_MIN / -1 => UB
626  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
627  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
628 
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // Ceil and floor division folding helpers
634 //===----------------------------------------------------------------------===//
635 
636 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
637  bool &overflow) {
638  // Returns (a-1)/b + 1
639  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
640  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
641  return val.sadd_ov(one, overflow);
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // CeilDivUIOp
646 //===----------------------------------------------------------------------===//
647 
648 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
649  // ceildivui (x, 1) -> x.
650  if (matchPattern(adaptor.getRhs(), m_One()))
651  return getLhs();
652 
653  bool overflowOrDiv0 = false;
654  auto result = constFoldBinaryOp<IntegerAttr>(
655  adaptor.getOperands(), [&](APInt a, const APInt &b) {
656  if (overflowOrDiv0 || !b) {
657  overflowOrDiv0 = true;
658  return a;
659  }
660  APInt quotient = a.udiv(b);
661  if (!a.urem(b))
662  return quotient;
663  APInt one(a.getBitWidth(), 1, true);
664  return quotient.uadd_ov(one, overflowOrDiv0);
665  });
666 
667  return overflowOrDiv0 ? Attribute() : result;
668 }
669 
670 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
671  // X / 0 => UB
672  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // CeilDivSIOp
678 //===----------------------------------------------------------------------===//
679 
680 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
681  // ceildivsi (x, 1) -> x.
682  if (matchPattern(adaptor.getRhs(), m_One()))
683  return getLhs();
684 
685  // Don't fold if it would overflow or if it requires a division by zero.
686  bool overflowOrDiv0 = false;
687  auto result = constFoldBinaryOp<IntegerAttr>(
688  adaptor.getOperands(), [&](APInt a, const APInt &b) {
689  if (overflowOrDiv0 || !b) {
690  overflowOrDiv0 = true;
691  return a;
692  }
693  if (!a)
694  return a;
695  // After this point we know that neither a or b are zero.
696  unsigned bits = a.getBitWidth();
697  APInt zero = APInt::getZero(bits);
698  bool aGtZero = a.sgt(zero);
699  bool bGtZero = b.sgt(zero);
700  if (aGtZero && bGtZero) {
701  // Both positive, return ceil(a, b).
702  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
703  }
704  if (!aGtZero && !bGtZero) {
705  // Both negative, return ceil(-a, -b).
706  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
707  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
708  return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
709  }
710  if (!aGtZero && bGtZero) {
711  // A is negative, b is positive, return - ( -a / b).
712  APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713  APInt div = posA.sdiv_ov(b, overflowOrDiv0);
714  return zero.ssub_ov(div, overflowOrDiv0);
715  }
716  // A is positive, b is negative, return - (a / -b).
717  APInt posB = zero.ssub_ov(b, overflowOrDiv0);
718  APInt div = a.sdiv_ov(posB, overflowOrDiv0);
719  return zero.ssub_ov(div, overflowOrDiv0);
720  });
721 
722  return overflowOrDiv0 ? Attribute() : result;
723 }
724 
725 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
726  bool mayHaveUB = true;
727 
728  APInt constRHS;
729  // X / 0 => UB
730  // INT_MIN / -1 => UB
731  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
732  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
733 
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // FloorDivSIOp
739 //===----------------------------------------------------------------------===//
740 
741 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
742  // floordivsi (x, 1) -> x.
743  if (matchPattern(adaptor.getRhs(), m_One()))
744  return getLhs();
745 
746  // Don't fold if it would overflow or if it requires a division by zero.
747  bool overflowOrDiv = false;
748  auto result = constFoldBinaryOp<IntegerAttr>(
749  adaptor.getOperands(), [&](APInt a, const APInt &b) {
750  if (b.isZero()) {
751  overflowOrDiv = true;
752  return a;
753  }
754  return a.sfloordiv_ov(b, overflowOrDiv);
755  });
756 
757  return overflowOrDiv ? Attribute() : result;
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // RemUIOp
762 //===----------------------------------------------------------------------===//
763 
764 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
765  // remui (x, 1) -> 0.
766  if (matchPattern(adaptor.getRhs(), m_One()))
767  return Builder(getContext()).getZeroAttr(getType());
768 
769  // Don't fold if it would require a division by zero.
770  bool div0 = false;
771  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
772  [&](APInt a, const APInt &b) {
773  if (div0 || b.isZero()) {
774  div0 = true;
775  return a;
776  }
777  return a.urem(b);
778  });
779 
780  return div0 ? Attribute() : result;
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // RemSIOp
785 //===----------------------------------------------------------------------===//
786 
787 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
788  // remsi (x, 1) -> 0.
789  if (matchPattern(adaptor.getRhs(), m_One()))
790  return Builder(getContext()).getZeroAttr(getType());
791 
792  // Don't fold if it would require a division by zero.
793  bool div0 = false;
794  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
795  [&](APInt a, const APInt &b) {
796  if (div0 || b.isZero()) {
797  div0 = true;
798  return a;
799  }
800  return a.srem(b);
801  });
802 
803  return div0 ? Attribute() : result;
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // AndIOp
808 //===----------------------------------------------------------------------===//
809 
810 /// Fold `and(a, and(a, b))` to `and(a, b)`
811 static Value foldAndIofAndI(arith::AndIOp op) {
812  for (bool reversePrev : {false, true}) {
813  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
814  .getDefiningOp<arith::AndIOp>();
815  if (!prev)
816  continue;
817 
818  Value other = (reversePrev ? op.getLhs() : op.getRhs());
819  if (other != prev.getLhs() && other != prev.getRhs())
820  continue;
821 
822  return prev.getResult();
823  }
824  return {};
825 }
826 
827 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
828  /// and(x, 0) -> 0
829  if (matchPattern(adaptor.getRhs(), m_Zero()))
830  return getRhs();
831  /// and(x, allOnes) -> x
832  APInt intValue;
833  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
834  intValue.isAllOnes())
835  return getLhs();
836  /// and(x, not(x)) -> 0
837  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
838  m_ConstantInt(&intValue))) &&
839  intValue.isAllOnes())
840  return Builder(getContext()).getZeroAttr(getType());
841  /// and(not(x), x) -> 0
842  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
843  m_ConstantInt(&intValue))) &&
844  intValue.isAllOnes())
845  return Builder(getContext()).getZeroAttr(getType());
846 
847  /// and(a, and(a, b)) -> and(a, b)
848  if (Value result = foldAndIofAndI(*this))
849  return result;
850 
851  return constFoldBinaryOp<IntegerAttr>(
852  adaptor.getOperands(),
853  [](APInt a, const APInt &b) { return std::move(a) & b; });
854 }
855 
856 //===----------------------------------------------------------------------===//
857 // OrIOp
858 //===----------------------------------------------------------------------===//
859 
860 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
861  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
862  /// or(x, 0) -> x
863  if (rhsVal.isZero())
864  return getLhs();
865  /// or(x, <all ones>) -> <all ones>
866  if (rhsVal.isAllOnes())
867  return adaptor.getRhs();
868  }
869 
870  APInt intValue;
871  /// or(x, xor(x, 1)) -> 1
872  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
873  m_ConstantInt(&intValue))) &&
874  intValue.isAllOnes())
875  return getRhs().getDefiningOp<XOrIOp>().getRhs();
876  /// or(xor(x, 1), x) -> 1
877  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
878  m_ConstantInt(&intValue))) &&
879  intValue.isAllOnes())
880  return getLhs().getDefiningOp<XOrIOp>().getRhs();
881 
882  return constFoldBinaryOp<IntegerAttr>(
883  adaptor.getOperands(),
884  [](APInt a, const APInt &b) { return std::move(a) | b; });
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // XOrIOp
889 //===----------------------------------------------------------------------===//
890 
891 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
892  /// xor(x, 0) -> x
893  if (matchPattern(adaptor.getRhs(), m_Zero()))
894  return getLhs();
895  /// xor(x, x) -> 0
896  if (getLhs() == getRhs())
897  return Builder(getContext()).getZeroAttr(getType());
898  /// xor(xor(x, a), a) -> x
899  /// xor(xor(a, x), a) -> x
900  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
901  if (prev.getRhs() == getRhs())
902  return prev.getLhs();
903  if (prev.getLhs() == getRhs())
904  return prev.getRhs();
905  }
906  /// xor(a, xor(x, a)) -> x
907  /// xor(a, xor(a, x)) -> x
908  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
909  if (prev.getRhs() == getLhs())
910  return prev.getLhs();
911  if (prev.getLhs() == getLhs())
912  return prev.getRhs();
913  }
914 
915  return constFoldBinaryOp<IntegerAttr>(
916  adaptor.getOperands(),
917  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
918 }
919 
920 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
921  MLIRContext *context) {
922  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // NegFOp
927 //===----------------------------------------------------------------------===//
928 
929 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
930  /// negf(negf(x)) -> x
931  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
932  return op.getOperand();
933  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
934  [](const APFloat &a) { return -a; });
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // AddFOp
939 //===----------------------------------------------------------------------===//
940 
941 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
942  // addf(x, -0) -> x
943  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
944  return getLhs();
945 
946  return constFoldBinaryOp<FloatAttr>(
947  adaptor.getOperands(),
948  [](const APFloat &a, const APFloat &b) { return a + b; });
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // SubFOp
953 //===----------------------------------------------------------------------===//
954 
955 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
956  // subf(x, +0) -> x
957  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
958  return getLhs();
959 
960  return constFoldBinaryOp<FloatAttr>(
961  adaptor.getOperands(),
962  [](const APFloat &a, const APFloat &b) { return a - b; });
963 }
964 
965 //===----------------------------------------------------------------------===//
966 // MaximumFOp
967 //===----------------------------------------------------------------------===//
968 
969 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
970  // maximumf(x,x) -> x
971  if (getLhs() == getRhs())
972  return getRhs();
973 
974  // maximumf(x, -inf) -> x
975  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
976  return getLhs();
977 
978  return constFoldBinaryOp<FloatAttr>(
979  adaptor.getOperands(),
980  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
981 }
982 
983 //===----------------------------------------------------------------------===//
984 // MaxNumFOp
985 //===----------------------------------------------------------------------===//
986 
987 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
988  // maxnumf(x,x) -> x
989  if (getLhs() == getRhs())
990  return getRhs();
991 
992  // maxnumf(x, -inf) -> x
993  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
994  return getLhs();
995 
996  return constFoldBinaryOp<FloatAttr>(
997  adaptor.getOperands(),
998  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // MaxSIOp
1003 //===----------------------------------------------------------------------===//
1004 
1005 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1006  // maxsi(x,x) -> x
1007  if (getLhs() == getRhs())
1008  return getRhs();
1009 
1010  if (APInt intValue;
1011  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1012  // maxsi(x,MAX_INT) -> MAX_INT
1013  if (intValue.isMaxSignedValue())
1014  return getRhs();
1015  // maxsi(x, MIN_INT) -> x
1016  if (intValue.isMinSignedValue())
1017  return getLhs();
1018  }
1019 
1020  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1021  [](const APInt &a, const APInt &b) {
1022  return llvm::APIntOps::smax(a, b);
1023  });
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // MaxUIOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1031  // maxui(x,x) -> x
1032  if (getLhs() == getRhs())
1033  return getRhs();
1034 
1035  if (APInt intValue;
1036  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1037  // maxui(x,MAX_INT) -> MAX_INT
1038  if (intValue.isMaxValue())
1039  return getRhs();
1040  // maxui(x, MIN_INT) -> x
1041  if (intValue.isMinValue())
1042  return getLhs();
1043  }
1044 
1045  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1046  [](const APInt &a, const APInt &b) {
1047  return llvm::APIntOps::umax(a, b);
1048  });
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // MinimumFOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1056  // minimumf(x,x) -> x
1057  if (getLhs() == getRhs())
1058  return getRhs();
1059 
1060  // minimumf(x, +inf) -> x
1061  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1062  return getLhs();
1063 
1064  return constFoldBinaryOp<FloatAttr>(
1065  adaptor.getOperands(),
1066  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1067 }
1068 
1069 //===----------------------------------------------------------------------===//
1070 // MinNumFOp
1071 //===----------------------------------------------------------------------===//
1072 
1073 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1074  // minnumf(x,x) -> x
1075  if (getLhs() == getRhs())
1076  return getRhs();
1077 
1078  // minnumf(x, +inf) -> x
1079  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1080  return getLhs();
1081 
1082  return constFoldBinaryOp<FloatAttr>(
1083  adaptor.getOperands(),
1084  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // MinSIOp
1089 //===----------------------------------------------------------------------===//
1090 
1091 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1092  // minsi(x,x) -> x
1093  if (getLhs() == getRhs())
1094  return getRhs();
1095 
1096  if (APInt intValue;
1097  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1098  // minsi(x,MIN_INT) -> MIN_INT
1099  if (intValue.isMinSignedValue())
1100  return getRhs();
1101  // minsi(x, MAX_INT) -> x
1102  if (intValue.isMaxSignedValue())
1103  return getLhs();
1104  }
1105 
1106  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1107  [](const APInt &a, const APInt &b) {
1108  return llvm::APIntOps::smin(a, b);
1109  });
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // MinUIOp
1114 //===----------------------------------------------------------------------===//
1115 
1116 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1117  // minui(x,x) -> x
1118  if (getLhs() == getRhs())
1119  return getRhs();
1120 
1121  if (APInt intValue;
1122  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1123  // minui(x,MIN_INT) -> MIN_INT
1124  if (intValue.isMinValue())
1125  return getRhs();
1126  // minui(x, MAX_INT) -> x
1127  if (intValue.isMaxValue())
1128  return getLhs();
1129  }
1130 
1131  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1132  [](const APInt &a, const APInt &b) {
1133  return llvm::APIntOps::umin(a, b);
1134  });
1135 }
1136 
1137 //===----------------------------------------------------------------------===//
1138 // MulFOp
1139 //===----------------------------------------------------------------------===//
1140 
1141 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1142  // mulf(x, 1) -> x
1143  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1144  return getLhs();
1145 
1146  return constFoldBinaryOp<FloatAttr>(
1147  adaptor.getOperands(),
1148  [](const APFloat &a, const APFloat &b) { return a * b; });
1149 }
1150 
1151 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1152  MLIRContext *context) {
1153  patterns.add<MulFOfNegF>(context);
1154 }
1155 
1156 //===----------------------------------------------------------------------===//
1157 // DivFOp
1158 //===----------------------------------------------------------------------===//
1159 
1160 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1161  // divf(x, 1) -> x
1162  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1163  return getLhs();
1164 
1165  return constFoldBinaryOp<FloatAttr>(
1166  adaptor.getOperands(),
1167  [](const APFloat &a, const APFloat &b) { return a / b; });
1168 }
1169 
1170 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1171  MLIRContext *context) {
1172  patterns.add<DivFOfNegF>(context);
1173 }
1174 
1175 //===----------------------------------------------------------------------===//
1176 // RemFOp
1177 //===----------------------------------------------------------------------===//
1178 
1179 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1180  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1181  [](const APFloat &a, const APFloat &b) {
1182  APFloat result(a);
1183  (void)result.remainder(b);
1184  return result;
1185  });
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // Utility functions for verifying cast ops
1190 //===----------------------------------------------------------------------===//
1191 
/// Tag type carrying a pack of types; used only for template argument
/// deduction by the cast-verification helpers in this file.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1194 
1195 /// Returns a non-null type only if the provided type is one of the allowed
1196 /// types or one of the allowed shaped types of the allowed types. Returns the
1197 /// element type if a valid shaped type is provided.
1198 template <typename... ShapedTypes, typename... ElementTypes>
1201  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1202  return {};
1203 
1204  auto underlyingType = getElementTypeOrSelf(type);
1205  if (!llvm::isa<ElementTypes...>(underlyingType))
1206  return {};
1207 
1208  return underlyingType;
1209 }
1210 
1211 /// Get allowed underlying types for vectors and tensors.
1212 template <typename... ElementTypes>
1213 static Type getTypeIfLike(Type type) {
1216 }
1217 
1218 /// Get allowed underlying types for vectors, tensors, and memrefs.
1219 template <typename... ElementTypes>
1221  return getUnderlyingType(type,
1224 }
1225 
1226 /// Return false if both types are ranked tensor with mismatching encoding.
1227 static bool hasSameEncoding(Type typeA, Type typeB) {
1228  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1229  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1230  if (!rankedTensorA || !rankedTensorB)
1231  return true;
1232  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1233 }
1234 
1235 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1236  if (inputs.size() != 1 || outputs.size() != 1)
1237  return false;
1238  if (!hasSameEncoding(inputs.front(), outputs.front()))
1239  return false;
1240  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // Verifiers for integer and floating point extension/truncation ops
1245 //===----------------------------------------------------------------------===//
1246 
1247 // Extend ops can only extend to a wider type.
1248 template <typename ValType, typename Op>
1250  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1251  Type dstType = getElementTypeOrSelf(op.getType());
1252 
1253  if (llvm::cast<ValType>(srcType).getWidth() >=
1254  llvm::cast<ValType>(dstType).getWidth())
1255  return op.emitError("result type ")
1256  << dstType << " must be wider than operand type " << srcType;
1257 
1258  return success();
1259 }
1260 
1261 // Truncate ops can only truncate to a shorter type.
1262 template <typename ValType, typename Op>
1264  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1265  Type dstType = getElementTypeOrSelf(op.getType());
1266 
1267  if (llvm::cast<ValType>(srcType).getWidth() <=
1268  llvm::cast<ValType>(dstType).getWidth())
1269  return op.emitError("result type ")
1270  << dstType << " must be shorter than operand type " << srcType;
1271 
1272  return success();
1273 }
1274 
1275 /// Validate a cast that changes the width of a type.
1276 template <template <typename> class WidthComparator, typename... ElementTypes>
1277 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1278  if (!areValidCastInputsAndOutputs(inputs, outputs))
1279  return false;
1280 
1281  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1282  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1283  if (!srcType || !dstType)
1284  return false;
1285 
1286  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1287  srcType.getIntOrFloatBitWidth());
1288 }
1289 
1290 /// Attempts to convert `sourceValue` to an APFloat value with
1291 /// `targetSemantics` and `roundingMode`, without any information loss.
1293  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1294  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1295  bool losesInfo = false;
1296  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1297  if (losesInfo || status != APFloat::opOK)
1298  return failure();
1299 
1300  return sourceValue;
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // ExtUIOp
1305 //===----------------------------------------------------------------------===//
1306 
1307 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1308  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1309  getInMutable().assign(lhs.getIn());
1310  return getResult();
1311  }
1312 
1313  Type resType = getElementTypeOrSelf(getType());
1314  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1315  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1316  adaptor.getOperands(), getType(),
1317  [bitWidth](const APInt &a, bool &castStatus) {
1318  return a.zext(bitWidth);
1319  });
1320 }
1321 
1322 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1323  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1324 }
1325 
1327  return verifyExtOp<IntegerType>(*this);
1328 }
1329 
1330 //===----------------------------------------------------------------------===//
1331 // ExtSIOp
1332 //===----------------------------------------------------------------------===//
1333 
1334 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1335  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1336  getInMutable().assign(lhs.getIn());
1337  return getResult();
1338  }
1339 
1340  Type resType = getElementTypeOrSelf(getType());
1341  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1342  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1343  adaptor.getOperands(), getType(),
1344  [bitWidth](const APInt &a, bool &castStatus) {
1345  return a.sext(bitWidth);
1346  });
1347 }
1348 
1349 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1350  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1351 }
1352 
1353 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1354  MLIRContext *context) {
1355  patterns.add<ExtSIOfExtUI>(context);
1356 }
1357 
1359  return verifyExtOp<IntegerType>(*this);
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // ExtFOp
1364 //===----------------------------------------------------------------------===//
1365 
1366 /// Fold extension of float constants when there is no information loss due the
1367 /// difference in fp semantics.
1368 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1369  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1370  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1371  return constFoldCastOp<FloatAttr, FloatAttr>(
1372  adaptor.getOperands(), getType(),
1373  [&targetSemantics](const APFloat &a, bool &castStatus) {
1374  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1375  if (failed(result)) {
1376  castStatus = false;
1377  return a;
1378  }
1379  return *result;
1380  });
1381 }
1382 
1383 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1384  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1385 }
1386 
1387 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1388 
1389 //===----------------------------------------------------------------------===//
1390 // TruncIOp
1391 //===----------------------------------------------------------------------===//
1392 
1393 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1394  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1395  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1396  Value src = getOperand().getDefiningOp()->getOperand(0);
1397  Type srcType = getElementTypeOrSelf(src.getType());
1398  Type dstType = getElementTypeOrSelf(getType());
1399  // trunci(zexti(a)) -> trunci(a)
1400  // trunci(sexti(a)) -> trunci(a)
1401  if (llvm::cast<IntegerType>(srcType).getWidth() >
1402  llvm::cast<IntegerType>(dstType).getWidth()) {
1403  setOperand(src);
1404  return getResult();
1405  }
1406 
1407  // trunci(zexti(a)) -> a
1408  // trunci(sexti(a)) -> a
1409  if (srcType == dstType)
1410  return src;
1411  }
1412 
1413  // trunci(trunci(a)) -> trunci(a))
1414  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1415  setOperand(getOperand().getDefiningOp()->getOperand(0));
1416  return getResult();
1417  }
1418 
1419  Type resType = getElementTypeOrSelf(getType());
1420  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1421  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1422  adaptor.getOperands(), getType(),
1423  [bitWidth](const APInt &a, bool &castStatus) {
1424  return a.trunc(bitWidth);
1425  });
1426 }
1427 
1428 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1429  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1430 }
1431 
1432 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1433  MLIRContext *context) {
1434  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1435  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1436  context);
1437 }
1438 
1440  return verifyTruncateOp<IntegerType>(*this);
1441 }
1442 
1443 //===----------------------------------------------------------------------===//
1444 // TruncFOp
1445 //===----------------------------------------------------------------------===//
1446 
1447 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1448 /// can be represented without precision loss.
1449 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1450  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1451  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1452  return constFoldCastOp<FloatAttr, FloatAttr>(
1453  adaptor.getOperands(), getType(),
1454  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1455  RoundingMode roundingMode =
1456  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1457  llvm::RoundingMode llvmRoundingMode =
1458  convertArithRoundingModeToLLVMIR(roundingMode);
1459  FailureOr<APFloat> result =
1460  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1461  if (failed(result)) {
1462  castStatus = false;
1463  return a;
1464  }
1465  return *result;
1466  });
1467 }
1468 
1469 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1470  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1471 }
1472 
1474  return verifyTruncateOp<FloatType>(*this);
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // AndIOp
1479 //===----------------------------------------------------------------------===//
1480 
1481 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1482  MLIRContext *context) {
1483  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1484 }
1485 
1486 //===----------------------------------------------------------------------===//
1487 // OrIOp
1488 //===----------------------------------------------------------------------===//
1489 
1490 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1491  MLIRContext *context) {
1492  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1493 }
1494 
1495 //===----------------------------------------------------------------------===//
1496 // Verifiers for casts between integers and floats.
1497 //===----------------------------------------------------------------------===//
1498 
1499 template <typename From, typename To>
1500 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1501  if (!areValidCastInputsAndOutputs(inputs, outputs))
1502  return false;
1503 
1504  auto srcType = getTypeIfLike<From>(inputs.front());
1505  auto dstType = getTypeIfLike<To>(outputs.back());
1506 
1507  return srcType && dstType;
1508 }
1509 
1510 //===----------------------------------------------------------------------===//
1511 // UIToFPOp
1512 //===----------------------------------------------------------------------===//
1513 
1514 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1515  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1516 }
1517 
1518 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1519  Type resEleType = getElementTypeOrSelf(getType());
1520  return constFoldCastOp<IntegerAttr, FloatAttr>(
1521  adaptor.getOperands(), getType(),
1522  [&resEleType](const APInt &a, bool &castStatus) {
1523  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1524  APFloat apf(floatTy.getFloatSemantics(),
1525  APInt::getZero(floatTy.getWidth()));
1526  apf.convertFromAPInt(a, /*IsSigned=*/false,
1527  APFloat::rmNearestTiesToEven);
1528  return apf;
1529  });
1530 }
1531 
1532 //===----------------------------------------------------------------------===//
1533 // SIToFPOp
1534 //===----------------------------------------------------------------------===//
1535 
1536 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1537  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1538 }
1539 
1540 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1541  Type resEleType = getElementTypeOrSelf(getType());
1542  return constFoldCastOp<IntegerAttr, FloatAttr>(
1543  adaptor.getOperands(), getType(),
1544  [&resEleType](const APInt &a, bool &castStatus) {
1545  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1546  APFloat apf(floatTy.getFloatSemantics(),
1547  APInt::getZero(floatTy.getWidth()));
1548  apf.convertFromAPInt(a, /*IsSigned=*/true,
1549  APFloat::rmNearestTiesToEven);
1550  return apf;
1551  });
1552 }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // FPToUIOp
1556 //===----------------------------------------------------------------------===//
1557 
1558 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1559  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1560 }
1561 
1562 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1563  Type resType = getElementTypeOrSelf(getType());
1564  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1565  return constFoldCastOp<FloatAttr, IntegerAttr>(
1566  adaptor.getOperands(), getType(),
1567  [&bitWidth](const APFloat &a, bool &castStatus) {
1568  bool ignored;
1569  APSInt api(bitWidth, /*isUnsigned=*/true);
1570  castStatus = APFloat::opInvalidOp !=
1571  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1572  return api;
1573  });
1574 }
1575 
1576 //===----------------------------------------------------------------------===//
1577 // FPToSIOp
1578 //===----------------------------------------------------------------------===//
1579 
1580 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1581  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1582 }
1583 
1584 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1585  Type resType = getElementTypeOrSelf(getType());
1586  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1587  return constFoldCastOp<FloatAttr, IntegerAttr>(
1588  adaptor.getOperands(), getType(),
1589  [&bitWidth](const APFloat &a, bool &castStatus) {
1590  bool ignored;
1591  APSInt api(bitWidth, /*isUnsigned=*/false);
1592  castStatus = APFloat::opInvalidOp !=
1593  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1594  return api;
1595  });
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // IndexCastOp
1600 //===----------------------------------------------------------------------===//
1601 
1602 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1603  if (!areValidCastInputsAndOutputs(inputs, outputs))
1604  return false;
1605 
1606  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1607  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1608  if (!srcType || !dstType)
1609  return false;
1610 
1611  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1612  (srcType.isSignlessInteger() && dstType.isIndex());
1613 }
1614 
1615 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1616  TypeRange outputs) {
1617  return areIndexCastCompatible(inputs, outputs);
1618 }
1619 
1620 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1621  // index_cast(constant) -> constant
1622  unsigned resultBitwidth = 64; // Default for index integer attributes.
1623  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1624  resultBitwidth = intTy.getWidth();
1625 
1626  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1627  adaptor.getOperands(), getType(),
1628  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1629  return a.sextOrTrunc(resultBitwidth);
1630  });
1631 }
1632 
1633 void arith::IndexCastOp::getCanonicalizationPatterns(
1634  RewritePatternSet &patterns, MLIRContext *context) {
1635  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1636 }
1637 
1638 //===----------------------------------------------------------------------===//
1639 // IndexCastUIOp
1640 //===----------------------------------------------------------------------===//
1641 
1642 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1643  TypeRange outputs) {
1644  return areIndexCastCompatible(inputs, outputs);
1645 }
1646 
1647 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1648  // index_castui(constant) -> constant
1649  unsigned resultBitwidth = 64; // Default for index integer attributes.
1650  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1651  resultBitwidth = intTy.getWidth();
1652 
1653  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1654  adaptor.getOperands(), getType(),
1655  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1656  return a.zextOrTrunc(resultBitwidth);
1657  });
1658 }
1659 
1660 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1661  RewritePatternSet &patterns, MLIRContext *context) {
1662  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1663 }
1664 
1665 //===----------------------------------------------------------------------===//
1666 // BitcastOp
1667 //===----------------------------------------------------------------------===//
1668 
1669 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1670  if (!areValidCastInputsAndOutputs(inputs, outputs))
1671  return false;
1672 
1673  auto srcType =
1674  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1675  auto dstType =
1676  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1677  if (!srcType || !dstType)
1678  return false;
1679 
1680  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1681 }
1682 
1683 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1684  auto resType = getType();
1685  auto operand = adaptor.getIn();
1686  if (!operand)
1687  return {};
1688 
1689  /// Bitcast dense elements.
1690  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1691  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1692  /// Other shaped types unhandled.
1693  if (llvm::isa<ShapedType>(resType))
1694  return {};
1695 
1696  /// Bitcast integer or float to integer or float.
1697  APInt bits = llvm::isa<FloatAttr>(operand)
1698  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1699  : llvm::cast<IntegerAttr>(operand).getValue();
1700 
1701  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1702  return FloatAttr::get(resType,
1703  APFloat(resFloatType.getFloatSemantics(), bits));
1704  return IntegerAttr::get(resType, bits);
1705 }
1706 
1707 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1708  MLIRContext *context) {
1709  patterns.add<BitcastOfBitcast>(context);
1710 }
1711 
1712 //===----------------------------------------------------------------------===//
1713 // CmpIOp
1714 //===----------------------------------------------------------------------===//
1715 
1716 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1717 /// comparison predicates.
1718 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1719  const APInt &lhs, const APInt &rhs) {
1720  switch (predicate) {
1721  case arith::CmpIPredicate::eq:
1722  return lhs.eq(rhs);
1723  case arith::CmpIPredicate::ne:
1724  return lhs.ne(rhs);
1725  case arith::CmpIPredicate::slt:
1726  return lhs.slt(rhs);
1727  case arith::CmpIPredicate::sle:
1728  return lhs.sle(rhs);
1729  case arith::CmpIPredicate::sgt:
1730  return lhs.sgt(rhs);
1731  case arith::CmpIPredicate::sge:
1732  return lhs.sge(rhs);
1733  case arith::CmpIPredicate::ult:
1734  return lhs.ult(rhs);
1735  case arith::CmpIPredicate::ule:
1736  return lhs.ule(rhs);
1737  case arith::CmpIPredicate::ugt:
1738  return lhs.ugt(rhs);
1739  case arith::CmpIPredicate::uge:
1740  return lhs.uge(rhs);
1741  }
1742  llvm_unreachable("unknown cmpi predicate kind");
1743 }
1744 
1745 /// Returns true if the predicate is true for two equal operands.
1746 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1747  switch (predicate) {
1748  case arith::CmpIPredicate::eq:
1749  case arith::CmpIPredicate::sle:
1750  case arith::CmpIPredicate::sge:
1751  case arith::CmpIPredicate::ule:
1752  case arith::CmpIPredicate::uge:
1753  return true;
1754  case arith::CmpIPredicate::ne:
1755  case arith::CmpIPredicate::slt:
1756  case arith::CmpIPredicate::sgt:
1757  case arith::CmpIPredicate::ult:
1758  case arith::CmpIPredicate::ugt:
1759  return false;
1760  }
1761  llvm_unreachable("unknown cmpi predicate kind");
1762 }
1763 
1764 static std::optional<int64_t> getIntegerWidth(Type t) {
1765  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1766  return intType.getWidth();
1767  }
1768  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1769  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1770  }
1771  return std::nullopt;
1772 }
1773 
1774 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1775  // cmpi(pred, x, x)
1776  if (getLhs() == getRhs()) {
1777  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1778  return getBoolAttribute(getType(), val);
1779  }
1780 
1781  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1782  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1783  // extsi(%x : i1 -> iN) != 0 -> %x
1784  std::optional<int64_t> integerWidth =
1785  getIntegerWidth(extOp.getOperand().getType());
1786  if (integerWidth && integerWidth.value() == 1 &&
1787  getPredicate() == arith::CmpIPredicate::ne)
1788  return extOp.getOperand();
1789  }
1790  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1791  // extui(%x : i1 -> iN) != 0 -> %x
1792  std::optional<int64_t> integerWidth =
1793  getIntegerWidth(extOp.getOperand().getType());
1794  if (integerWidth && integerWidth.value() == 1 &&
1795  getPredicate() == arith::CmpIPredicate::ne)
1796  return extOp.getOperand();
1797  }
1798  }
1799 
1800  // Move constant to the right side.
1801  if (adaptor.getLhs() && !adaptor.getRhs()) {
1802  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1803  using Pred = CmpIPredicate;
1804  const std::pair<Pred, Pred> invPreds[] = {
1805  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1806  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1807  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1808  {Pred::ne, Pred::ne},
1809  };
1810  Pred origPred = getPredicate();
1811  for (auto pred : invPreds) {
1812  if (origPred == pred.first) {
1813  setPredicate(pred.second);
1814  Value lhs = getLhs();
1815  Value rhs = getRhs();
1816  getLhsMutable().assign(rhs);
1817  getRhsMutable().assign(lhs);
1818  return getResult();
1819  }
1820  }
1821  llvm_unreachable("unknown cmpi predicate kind");
1822  }
1823 
1824  // We are moving constants to the right side; So if lhs is constant rhs is
1825  // guaranteed to be a constant.
1826  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1827  return constFoldBinaryOp<IntegerAttr>(
1828  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1829  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1830  return APInt(1,
1831  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1832  });
1833  }
1834 
1835  return {};
1836 }
1837 
/// Registers the canonicalization rewrite patterns for CmpIOp (CmpIExtSI and
/// CmpIExtUI) into `patterns`, constructing each of them with `context`.
1838 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1839  MLIRContext *context) {
1840  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1841 }
1842 
1843 //===----------------------------------------------------------------------===//
1844 // CmpFOp
1845 //===----------------------------------------------------------------------===//
1846 
1847 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1848 /// comparison predicates.
1849 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1850  const APFloat &lhs, const APFloat &rhs) {
1851  auto cmpResult = lhs.compare(rhs);
1852  switch (predicate) {
1853  case arith::CmpFPredicate::AlwaysFalse:
1854  return false;
1855  case arith::CmpFPredicate::OEQ:
1856  return cmpResult == APFloat::cmpEqual;
1857  case arith::CmpFPredicate::OGT:
1858  return cmpResult == APFloat::cmpGreaterThan;
1859  case arith::CmpFPredicate::OGE:
1860  return cmpResult == APFloat::cmpGreaterThan ||
1861  cmpResult == APFloat::cmpEqual;
1862  case arith::CmpFPredicate::OLT:
1863  return cmpResult == APFloat::cmpLessThan;
1864  case arith::CmpFPredicate::OLE:
1865  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1866  case arith::CmpFPredicate::ONE:
1867  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1868  case arith::CmpFPredicate::ORD:
1869  return cmpResult != APFloat::cmpUnordered;
1870  case arith::CmpFPredicate::UEQ:
1871  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1872  case arith::CmpFPredicate::UGT:
1873  return cmpResult == APFloat::cmpUnordered ||
1874  cmpResult == APFloat::cmpGreaterThan;
1875  case arith::CmpFPredicate::UGE:
1876  return cmpResult == APFloat::cmpUnordered ||
1877  cmpResult == APFloat::cmpGreaterThan ||
1878  cmpResult == APFloat::cmpEqual;
1879  case arith::CmpFPredicate::ULT:
1880  return cmpResult == APFloat::cmpUnordered ||
1881  cmpResult == APFloat::cmpLessThan;
1882  case arith::CmpFPredicate::ULE:
1883  return cmpResult == APFloat::cmpUnordered ||
1884  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1885  case arith::CmpFPredicate::UNE:
1886  return cmpResult != APFloat::cmpEqual;
1887  case arith::CmpFPredicate::UNO:
1888  return cmpResult == APFloat::cmpUnordered;
1889  case arith::CmpFPredicate::AlwaysTrue:
1890  return true;
1891  }
1892  llvm_unreachable("unknown cmpf predicate kind");
1893 }
1894 
1895 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1896  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1897  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1898 
1899  // If one operand is NaN, making them both NaN does not change the result.
1900  if (lhs && lhs.getValue().isNaN())
1901  rhs = lhs;
1902  if (rhs && rhs.getValue().isNaN())
1903  lhs = rhs;
1904 
1905  if (!lhs || !rhs)
1906  return {};
1907 
1908  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1909  return BoolAttr::get(getContext(), val);
1910 }
1911 
1912 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1913 public:
1915 
1916  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1917  bool isUnsigned) {
1918  using namespace arith;
1919  switch (pred) {
1920  case CmpFPredicate::UEQ:
1921  case CmpFPredicate::OEQ:
1922  return CmpIPredicate::eq;
1923  case CmpFPredicate::UGT:
1924  case CmpFPredicate::OGT:
1925  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1926  case CmpFPredicate::UGE:
1927  case CmpFPredicate::OGE:
1928  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1929  case CmpFPredicate::ULT:
1930  case CmpFPredicate::OLT:
1931  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1932  case CmpFPredicate::ULE:
1933  case CmpFPredicate::OLE:
1934  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1935  case CmpFPredicate::UNE:
1936  case CmpFPredicate::ONE:
1937  return CmpIPredicate::ne;
1938  default:
1939  llvm_unreachable("Unexpected predicate!");
1940  }
1941  }
1942 
1944  PatternRewriter &rewriter) const override {
1945  FloatAttr flt;
1946  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1947  return failure();
1948 
1949  const APFloat &rhs = flt.getValue();
1950 
1951  // Don't attempt to fold a nan.
1952  if (rhs.isNaN())
1953  return failure();
1954 
1955  // Get the width of the mantissa. We don't want to hack on conversions that
1956  // might lose information from the integer, e.g. "i64 -> float"
1957  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1958  int mantissaWidth = floatTy.getFPMantissaWidth();
1959  if (mantissaWidth <= 0)
1960  return failure();
1961 
1962  bool isUnsigned;
1963  Value intVal;
1964 
1965  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1966  isUnsigned = false;
1967  intVal = si.getIn();
1968  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1969  isUnsigned = true;
1970  intVal = ui.getIn();
1971  } else {
1972  return failure();
1973  }
1974 
1975  // Check to see that the input is converted from an integer type that is
1976  // small enough that preserves all bits.
1977  auto intTy = llvm::cast<IntegerType>(intVal.getType());
1978  auto intWidth = intTy.getWidth();
1979 
1980  // Number of bits representing values, as opposed to the sign
1981  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1982 
1983  // Following test does NOT adjust intWidth downwards for signed inputs,
1984  // because the most negative value still requires all the mantissa bits
1985  // to distinguish it from one less than that value.
1986  if ((int)intWidth > mantissaWidth) {
1987  // Conversion would lose accuracy. Check if loss can impact comparison.
1988  int exponent = ilogb(rhs);
1989  if (exponent == APFloat::IEK_Inf) {
1990  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1991  if (maxExponent < (int)valueBits) {
1992  // Conversion could create infinity.
1993  return failure();
1994  }
1995  } else {
1996  // Note that if rhs is zero or NaN, then Exp is negative
1997  // and first condition is trivially false.
1998  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1999  // Conversion could affect comparison.
2000  return failure();
2001  }
2002  }
2003  }
2004 
2005  // Convert to equivalent cmpi predicate
2006  CmpIPredicate pred;
2007  switch (op.getPredicate()) {
2008  case CmpFPredicate::ORD:
2009  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2010  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2011  /*width=*/1);
2012  return success();
2013  case CmpFPredicate::UNO:
2014  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2015  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2016  /*width=*/1);
2017  return success();
2018  default:
2019  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2020  break;
2021  }
2022 
2023  if (!isUnsigned) {
2024  // If the rhs value is > SignedMax, fold the comparison. This handles
2025  // +INF and large values.
2026  APFloat signedMax(rhs.getSemantics());
2027  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2028  APFloat::rmNearestTiesToEven);
2029  if (signedMax < rhs) { // smax < 13123.0
2030  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2031  pred == CmpIPredicate::sle)
2032  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2033  /*width=*/1);
2034  else
2035  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2036  /*width=*/1);
2037  return success();
2038  }
2039  } else {
2040  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2041  // +INF and large values.
2042  APFloat unsignedMax(rhs.getSemantics());
2043  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2044  APFloat::rmNearestTiesToEven);
2045  if (unsignedMax < rhs) { // umax < 13123.0
2046  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2047  pred == CmpIPredicate::ule)
2048  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2049  /*width=*/1);
2050  else
2051  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2052  /*width=*/1);
2053  return success();
2054  }
2055  }
2056 
2057  if (!isUnsigned) {
2058  // See if the rhs value is < SignedMin.
2059  APFloat signedMin(rhs.getSemantics());
2060  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2061  APFloat::rmNearestTiesToEven);
2062  if (signedMin > rhs) { // smin > 12312.0
2063  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2064  pred == CmpIPredicate::sge)
2065  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2066  /*width=*/1);
2067  else
2068  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2069  /*width=*/1);
2070  return success();
2071  }
2072  } else {
2073  // See if the rhs value is < UnsignedMin.
2074  APFloat unsignedMin(rhs.getSemantics());
2075  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2076  APFloat::rmNearestTiesToEven);
2077  if (unsignedMin > rhs) { // umin > 12312.0
2078  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2079  pred == CmpIPredicate::uge)
2080  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2081  /*width=*/1);
2082  else
2083  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2084  /*width=*/1);
2085  return success();
2086  }
2087  }
2088 
2089  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2090  // [0, UMAX], but it may still be fractional. See if it is fractional by
2091  // casting the FP value to the integer value and back, checking for
2092  // equality. Don't do this for zero, because -0.0 is not fractional.
2093  bool ignored;
2094  APSInt rhsInt(intWidth, isUnsigned);
2095  if (APFloat::opInvalidOp ==
2096  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2097  // Undefined behavior invoked - the destination type can't represent
2098  // the input constant.
2099  return failure();
2100  }
2101 
2102  if (!rhs.isZero()) {
2103  APFloat apf(floatTy.getFloatSemantics(),
2104  APInt::getZero(floatTy.getWidth()));
2105  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2106 
2107  bool equal = apf == rhs;
2108  if (!equal) {
2109  // If we had a comparison against a fractional value, we have to adjust
2110  // the compare predicate and sometimes the value. rhsInt is rounded
2111  // towards zero at this point.
2112  switch (pred) {
2113  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2114  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2115  /*width=*/1);
2116  return success();
2117  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2118  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2119  /*width=*/1);
2120  return success();
2121  case CmpIPredicate::ule:
2122  // (float)int <= 4.4 --> int <= 4
2123  // (float)int <= -4.4 --> false
2124  if (rhs.isNegative()) {
2125  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2126  /*width=*/1);
2127  return success();
2128  }
2129  break;
2130  case CmpIPredicate::sle:
2131  // (float)int <= 4.4 --> int <= 4
2132  // (float)int <= -4.4 --> int < -4
2133  if (rhs.isNegative())
2134  pred = CmpIPredicate::slt;
2135  break;
2136  case CmpIPredicate::ult:
2137  // (float)int < -4.4 --> false
2138  // (float)int < 4.4 --> int <= 4
2139  if (rhs.isNegative()) {
2140  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2141  /*width=*/1);
2142  return success();
2143  }
2144  pred = CmpIPredicate::ule;
2145  break;
2146  case CmpIPredicate::slt:
2147  // (float)int < -4.4 --> int < -4
2148  // (float)int < 4.4 --> int <= 4
2149  if (!rhs.isNegative())
2150  pred = CmpIPredicate::sle;
2151  break;
2152  case CmpIPredicate::ugt:
2153  // (float)int > 4.4 --> int > 4
2154  // (float)int > -4.4 --> true
2155  if (rhs.isNegative()) {
2156  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2157  /*width=*/1);
2158  return success();
2159  }
2160  break;
2161  case CmpIPredicate::sgt:
2162  // (float)int > 4.4 --> int > 4
2163  // (float)int > -4.4 --> int >= -4
2164  if (rhs.isNegative())
2165  pred = CmpIPredicate::sge;
2166  break;
2167  case CmpIPredicate::uge:
2168  // (float)int >= -4.4 --> true
2169  // (float)int >= 4.4 --> int > 4
2170  if (rhs.isNegative()) {
2171  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2172  /*width=*/1);
2173  return success();
2174  }
2175  pred = CmpIPredicate::ugt;
2176  break;
2177  case CmpIPredicate::sge:
2178  // (float)int >= -4.4 --> int >= -4
2179  // (float)int >= 4.4 --> int > 4
2180  if (!rhs.isNegative())
2181  pred = CmpIPredicate::sgt;
2182  break;
2183  }
2184  }
2185  }
2186 
2187  // Lower this FP comparison into an appropriate integer version of the
2188  // comparison.
2189  rewriter.replaceOpWithNewOp<CmpIOp>(
2190  op, pred, intVal,
2191  rewriter.create<ConstantOp>(
2192  op.getLoc(), intVal.getType(),
2193  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2194  return success();
2195  }
2196 };
2197 
/// Registers the canonicalization rewrite pattern for CmpFOp
/// (CmpFIntToFPConst) into `patterns`, constructing it with `context`.
2198 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2199  MLIRContext *context) {
2200  patterns.insert<CmpFIntToFPConst>(context);
2201 }
2202 
2203 //===----------------------------------------------------------------------===//
2204 // SelectOp
2205 //===----------------------------------------------------------------------===//
2206 
2207 // select %arg, %c1, %c0 => extui %arg
2208 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2210 
2211  LogicalResult matchAndRewrite(arith::SelectOp op,
2212  PatternRewriter &rewriter) const override {
2213  // Cannot extui i1 to i1, or i1 to f32
2214  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2215  return failure();
2216 
2217  // select %x, c1, %c0 => extui %arg
2218  if (matchPattern(op.getTrueValue(), m_One()) &&
2219  matchPattern(op.getFalseValue(), m_Zero())) {
2220  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2221  op.getCondition());
2222  return success();
2223  }
2224 
2225  // select %x, c0, %c1 => extui (xor %arg, true)
2226  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2227  matchPattern(op.getFalseValue(), m_One())) {
2228  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2229  op, op.getType(),
2230  rewriter.create<arith::XOrIOp>(
2231  op.getLoc(), op.getCondition(),
2232  rewriter.create<arith::ConstantIntOp>(
2233  op.getLoc(), 1, op.getCondition().getType())));
2234  return success();
2235  }
2236 
2237  return failure();
2238  }
2239 };
2240 
/// Registers the canonicalization rewrite patterns for SelectOp
/// (RedundantSelectFalse, RedundantSelectTrue, SelectNotCond, SelectI1ToNot
/// and SelectToExtUI) into `results`, constructing each with `context`.
2241 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2242  MLIRContext *context) {
2243  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2244  SelectI1ToNot, SelectToExtUI>(context);
2245 }
2246 
2247 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2248  Value trueVal = getTrueValue();
2249  Value falseVal = getFalseValue();
2250  if (trueVal == falseVal)
2251  return trueVal;
2252 
2253  Value condition = getCondition();
2254 
2255  // select true, %0, %1 => %0
2256  if (matchPattern(adaptor.getCondition(), m_One()))
2257  return trueVal;
2258 
2259  // select false, %0, %1 => %1
2260  if (matchPattern(adaptor.getCondition(), m_Zero()))
2261  return falseVal;
2262 
2263  // If either operand is fully poisoned, return the other.
2264  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2265  return falseVal;
2266 
2267  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2268  return trueVal;
2269 
2270  // select %x, true, false => %x
2271  if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
2272  matchPattern(adaptor.getFalseValue(), m_Zero()))
2273  return condition;
2274 
2275  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2276  auto pred = cmp.getPredicate();
2277  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2278  auto cmpLhs = cmp.getLhs();
2279  auto cmpRhs = cmp.getRhs();
2280 
2281  // %0 = arith.cmpi eq, %arg0, %arg1
2282  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2283 
2284  // %0 = arith.cmpi ne, %arg0, %arg1
2285  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2286 
2287  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2288  (cmpRhs == trueVal && cmpLhs == falseVal))
2289  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2290  }
2291  }
2292 
2293  // Constant-fold constant operands over non-splat constant condition.
2294  // select %cst_vec, %cst0, %cst1 => %cst2
2295  if (auto cond =
2296  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2297  if (auto lhs =
2298  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2299  if (auto rhs =
2300  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2301  SmallVector<Attribute> results;
2302  results.reserve(static_cast<size_t>(cond.getNumElements()));
2303  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2304  cond.value_end<BoolAttr>());
2305  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2306  lhs.value_end<Attribute>());
2307  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2308  rhs.value_end<Attribute>());
2309 
2310  for (auto [condVal, lhsVal, rhsVal] :
2311  llvm::zip_equal(condVals, lhsVals, rhsVals))
2312  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2313 
2314  return DenseElementsAttr::get(lhs.getType(), results);
2315  }
2316  }
2317  }
2318 
2319  return nullptr;
2320 }
2321 
2323  Type conditionType, resultType;
2325  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2326  parser.parseOptionalAttrDict(result.attributes) ||
2327  parser.parseColonType(resultType))
2328  return failure();
2329 
2330  // Check for the explicit condition type if this is a masked tensor or vector.
2331  if (succeeded(parser.parseOptionalComma())) {
2332  conditionType = resultType;
2333  if (parser.parseType(resultType))
2334  return failure();
2335  } else {
2336  conditionType = parser.getBuilder().getI1Type();
2337  }
2338 
2339  result.addTypes(resultType);
2340  return parser.resolveOperands(operands,
2341  {conditionType, resultType, resultType},
2342  parser.getNameLoc(), result.operands);
2343 }
2344 
2346  p << " " << getOperands();
2347  p.printOptionalAttrDict((*this)->getAttrs());
2348  p << " : ";
2349  if (ShapedType condType =
2350  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2351  p << condType << ", ";
2352  p << getType();
2353 }
2354 
2356  Type conditionType = getCondition().getType();
2357  if (conditionType.isSignlessInteger(1))
2358  return success();
2359 
2360  // If the result type is a vector or tensor, the type can be a mask with the
2361  // same elements.
2362  Type resultType = getType();
2363  if (!llvm::isa<TensorType, VectorType>(resultType))
2364  return emitOpError() << "expected condition to be a signless i1, but got "
2365  << conditionType;
2366  Type shapedConditionType = getI1SameShape(resultType);
2367  if (conditionType != shapedConditionType) {
2368  return emitOpError() << "expected condition type to have the same shape "
2369  "as the result type, expected "
2370  << shapedConditionType << ", but got "
2371  << conditionType;
2372  }
2373  return success();
2374 }
2375 //===----------------------------------------------------------------------===//
2376 // ShLIOp
2377 //===----------------------------------------------------------------------===//
2378 
2379 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2380  // shli(x, 0) -> x
2381  if (matchPattern(adaptor.getRhs(), m_Zero()))
2382  return getLhs();
2383  // Don't fold if shifting more or equal than the bit width.
2384  bool bounded = false;
2385  auto result = constFoldBinaryOp<IntegerAttr>(
2386  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2387  bounded = b.ult(b.getBitWidth());
2388  return a.shl(b);
2389  });
2390  return bounded ? result : Attribute();
2391 }
2392 
2393 //===----------------------------------------------------------------------===//
2394 // ShRUIOp
2395 //===----------------------------------------------------------------------===//
2396 
2397 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2398  // shrui(x, 0) -> x
2399  if (matchPattern(adaptor.getRhs(), m_Zero()))
2400  return getLhs();
2401  // Don't fold if shifting more or equal than the bit width.
2402  bool bounded = false;
2403  auto result = constFoldBinaryOp<IntegerAttr>(
2404  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2405  bounded = b.ult(b.getBitWidth());
2406  return a.lshr(b);
2407  });
2408  return bounded ? result : Attribute();
2409 }
2410 
2411 //===----------------------------------------------------------------------===//
2412 // ShRSIOp
2413 //===----------------------------------------------------------------------===//
2414 
2415 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2416  // shrsi(x, 0) -> x
2417  if (matchPattern(adaptor.getRhs(), m_Zero()))
2418  return getLhs();
2419  // Don't fold if shifting more or equal than the bit width.
2420  bool bounded = false;
2421  auto result = constFoldBinaryOp<IntegerAttr>(
2422  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2423  bounded = b.ult(b.getBitWidth());
2424  return a.ashr(b);
2425  });
2426  return bounded ? result : Attribute();
2427 }
2428 
2429 //===----------------------------------------------------------------------===//
2430 // Atomic Enum
2431 //===----------------------------------------------------------------------===//
2432 
2433 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2434 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2435  OpBuilder &builder, Location loc,
2436  bool useOnlyFiniteValue) {
2437  switch (kind) {
2438  case AtomicRMWKind::maximumf: {
2439  const llvm::fltSemantics &semantic =
2440  llvm::cast<FloatType>(resultType).getFloatSemantics();
2441  APFloat identity = useOnlyFiniteValue
2442  ? APFloat::getLargest(semantic, /*Negative=*/true)
2443  : APFloat::getInf(semantic, /*Negative=*/true);
2444  return builder.getFloatAttr(resultType, identity);
2445  }
2446  case AtomicRMWKind::addf:
2447  case AtomicRMWKind::addi:
2448  case AtomicRMWKind::maxu:
2449  case AtomicRMWKind::ori:
2450  return builder.getZeroAttr(resultType);
2451  case AtomicRMWKind::andi:
2452  return builder.getIntegerAttr(
2453  resultType,
2454  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2455  case AtomicRMWKind::maxs:
2456  return builder.getIntegerAttr(
2457  resultType, APInt::getSignedMinValue(
2458  llvm::cast<IntegerType>(resultType).getWidth()));
2459  case AtomicRMWKind::minimumf: {
2460  const llvm::fltSemantics &semantic =
2461  llvm::cast<FloatType>(resultType).getFloatSemantics();
2462  APFloat identity = useOnlyFiniteValue
2463  ? APFloat::getLargest(semantic, /*Negative=*/false)
2464  : APFloat::getInf(semantic, /*Negative=*/false);
2465 
2466  return builder.getFloatAttr(resultType, identity);
2467  }
2468  case AtomicRMWKind::mins:
2469  return builder.getIntegerAttr(
2470  resultType, APInt::getSignedMaxValue(
2471  llvm::cast<IntegerType>(resultType).getWidth()));
2472  case AtomicRMWKind::minu:
2473  return builder.getIntegerAttr(
2474  resultType,
2475  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2476  case AtomicRMWKind::muli:
2477  return builder.getIntegerAttr(resultType, 1);
2478  case AtomicRMWKind::mulf:
2479  return builder.getFloatAttr(resultType, 1);
2480  // TODO: Add remaining reduction operations.
2481  default:
2482  (void)emitOptionalError(loc, "Reduction operation type not supported");
2483  break;
2484  }
2485  return nullptr;
2486 }
2487 
2488 /// Return the identity numeric value associated to the give op.
2489 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2490  std::optional<AtomicRMWKind> maybeKind =
2492  // Floating-point operations.
2493  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2494  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2495  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2496  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2497  // Integer operations.
2498  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2499  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2500  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2501  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2502  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2503  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2504  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2505  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2506  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2507  .Default([](Operation *op) { return std::nullopt; });
2508  if (!maybeKind) {
2509  op->emitError() << "Unknown neutral element for: " << *op;
2510  return std::nullopt;
2511  }
2512 
2513  bool useOnlyFiniteValue = false;
2514  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2515  if (fmfOpInterface) {
2516  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2517  useOnlyFiniteValue =
2518  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2519  }
2520 
2521  // Builder only used as helper for attribute creation.
2522  OpBuilder b(op->getContext());
2523  Type resultType = op->getResult(0).getType();
2524 
2525  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2526  useOnlyFiniteValue);
2527 }
2528 
2529 /// Returns the identity value associated with an AtomicRMWKind op.
2530 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2531  OpBuilder &builder, Location loc,
2532  bool useOnlyFiniteValue) {
2533  auto attr =
2534  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2535  return builder.create<arith::ConstantOp>(loc, attr);
2536 }
2537 
2538 /// Return the value obtained by applying the reduction operation kind
2539 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2540 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2541  Location loc, Value lhs, Value rhs) {
2542  switch (op) {
2543  case AtomicRMWKind::addf:
2544  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2545  case AtomicRMWKind::addi:
2546  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2547  case AtomicRMWKind::mulf:
2548  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2549  case AtomicRMWKind::muli:
2550  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2551  case AtomicRMWKind::maximumf:
2552  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2553  case AtomicRMWKind::minimumf:
2554  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2555  case AtomicRMWKind::maxnumf:
2556  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2557  case AtomicRMWKind::minnumf:
2558  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2559  case AtomicRMWKind::maxs:
2560  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2561  case AtomicRMWKind::mins:
2562  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2563  case AtomicRMWKind::maxu:
2564  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2565  case AtomicRMWKind::minu:
2566  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2567  case AtomicRMWKind::ori:
2568  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2569  case AtomicRMWKind::andi:
2570  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2571  // TODO: Add remaining reduction operations.
2572  default:
2573  (void)emitOptionalError(loc, "Reduction operation type not supported");
2574  break;
2575  }
2576  return nullptr;
2577 }
2578 
2579 //===----------------------------------------------------------------------===//
2580 // TableGen'd op method definitions
2581 //===----------------------------------------------------------------------===//
2582 
2583 #define GET_OP_CLASSES
2584 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2585 
2586 //===----------------------------------------------------------------------===//
2587 // TableGen'd enum attribute definitions
2588 //===----------------------------------------------------------------------===//
2589 
2590 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1277
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:62
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:101
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1213
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1746
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1227
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1199
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:636
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:134
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:52
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:142
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1292
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1602
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1500
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1249
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:57
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:122
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b).
Definition: ArithOps.cpp:811
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1220
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:163
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1764
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1235
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1193
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:43
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:338
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1263
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1943
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1916
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This provides public APIs that all operations should have.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:67
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:271
static bool classof(Operation *op)
Definition: ArithOps.cpp:277
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:283
static bool classof(Operation *op)
Definition: ArithOps.cpp:289
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:53
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:250
static bool classof(Operation *op)
Definition: ArithOps.cpp:265
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2489
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1718
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2434
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2540
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2530
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:68
auto m_Val(Value v)
Definition: Matchers.h:450
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:345
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
Definition: Matchers.h:384
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:371
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:350
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:363
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float one.
Definition: Matchers.h:355
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2211
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes