1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/CommonFolders.h"
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
43 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
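// Editorial illustration (not part of the upstream file): these thin wrappers
// bind applyToIntegerAttrs to the standard APInt functors. For example, with
// an i32 result and two IntegerAttr operands holding 2 and 3, addIntegerAttrs
// yields the i32 attribute 5, i.e. std::plus<APInt>()(APInt(32, 2), APInt(32, 3)).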
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
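// Editorial illustration: the merge is a bitwise intersection of the two flag
// sets, so a flag survives only if both source ops guarantee it. For example,
// merging {nsw, nuw} with {nsw} yields {nsw}, and merging {nsw} with {none}
// yields {none}.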
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
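// Editorial note: the inversion above is the logical negation of the
// comparison (e.g. slt becomes sge), not the operand-swapped predicate
// (which would map slt to sgt).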
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
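// Editorial illustration: for a scalar type such as i1 this returns a plain
// BoolAttr, while for an i1-shaped type such as vector<4xi1> it returns a
// splat DenseElementsAttr, e.g. dense<true> : vector<4xi1>.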
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
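// Editorial illustration: getI1SameShape maps f32 to i1, vector<4xf32> to
// vector<4xi1>, and tensor<?x8xi32> to tensor<?x8xi1>, preserving the shape
// (and scalability for vectors) while swapping in the i1 element type.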
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // The value's type must match the return type.
211  if (getValue().getType() != type) {
212  return emitOpError() << "value type " << getValue().getType()
213  << " must match return type: " << type;
214  }
215  // Integer values must be signless.
216  if (llvm::isa<IntegerType>(type) &&
217  !llvm::cast<IntegerType>(type).isSignless())
218  return emitOpError("integer return type must be signless");
219  // Any float or elements attribute are acceptable.
220  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
221  return emitOpError(
222  "value must be an integer, float, or elements attribute");
223  }
224 
225  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
226  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
227  // However, this would most likely require updating the lowerings to LLVM.
228  auto vecType = dyn_cast<VectorType>(type);
229  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
230  return emitOpError(
231  "initializing scalable vectors with elements attribute is not supported"
232  " unless it's a vector splat");
233  return success();
234 }
235 
236 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
237  // The value's type must be the same as the provided type.
238  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
239  if (!typedAttr || typedAttr.getType() != type)
240  return false;
241  // Integer values must be signless.
242  if (llvm::isa<IntegerType>(type) &&
243  !llvm::cast<IntegerType>(type).isSignless())
244  return false;
245  // Integer, float, and element attributes are buildable.
246  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
247 }
248 
249 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
250  Type type, Location loc) {
251  if (isBuildableWith(value, type))
252  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
253  return nullptr;
254 }
255 
256 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
257 
258 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
259  int64_t value, unsigned width) {
260  auto type = builder.getIntegerType(width);
261  arith::ConstantOp::build(builder, result, type,
262  builder.getIntegerAttr(type, value));
263 }
264 
265 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
266  int64_t value, Type type) {
267  assert(type.isSignlessInteger() &&
268  "ConstantIntOp can only have signless integer type values");
269  arith::ConstantOp::build(builder, result, type,
270  builder.getIntegerAttr(type, value));
271 }
272 
273 bool arith::ConstantIntOp::classof(Operation *op) {
274  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
275  return constOp.getType().isSignlessInteger();
276  return false;
277 }
278 
279 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
280  const APFloat &value, FloatType type) {
281  arith::ConstantOp::build(builder, result, type,
282  builder.getFloatAttr(type, value));
283 }
284 
285 bool arith::ConstantFloatOp::classof(Operation *op) {
286  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
287  return llvm::isa<FloatType>(constOp.getType());
288  return false;
289 }
290 
291 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
292  int64_t value) {
293  arith::ConstantOp::build(builder, result, builder.getIndexType(),
294  builder.getIndexAttr(value));
295 }
296 
297 bool arith::ConstantIndexOp::classof(Operation *op) {
298  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
299  return constOp.getType().isIndex();
300  return false;
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // AddIOp
305 //===----------------------------------------------------------------------===//
306 
307 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
308  // addi(x, 0) -> x
309  if (matchPattern(adaptor.getRhs(), m_Zero()))
310  return getLhs();
311 
312  // addi(subi(a, b), b) -> a
313  if (auto sub = getLhs().getDefiningOp<SubIOp>())
314  if (getRhs() == sub.getRhs())
315  return sub.getLhs();
316 
317  // addi(b, subi(a, b)) -> a
318  if (auto sub = getRhs().getDefiningOp<SubIOp>())
319  if (getLhs() == sub.getRhs())
320  return sub.getLhs();
321 
322  return constFoldBinaryOp<IntegerAttr>(
323  adaptor.getOperands(),
324  [](APInt a, const APInt &b) { return std::move(a) + b; });
325 }
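// Editorial illustration of the folds above, written as MLIR IR:
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32    // folds to %a
//   %2 = arith.addi %x, %zero : i32 // folds to %x when %zero is constant 0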
326 
327 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
328  MLIRContext *context) {
329  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
330  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // AddUIExtendedOp
335 //===----------------------------------------------------------------------===//
336 
337 std::optional<SmallVector<int64_t, 4>>
338 arith::AddUIExtendedOp::getShapeForUnroll() {
339  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
340  return llvm::to_vector<4>(vt.getShape());
341  return std::nullopt;
342 }
343 
344 // Returns the overflow bit, assuming that `sum` is the result of unsigned
345 // addition of `operand` and another number.
346 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
347  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
348 }
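// Editorial worked example (assuming 8-bit operands): 200 + 100 wraps to 44,
// and since 44 u< 100 the overflow bit is 1; 200 + 50 = 250 does not wrap,
// 250 u>= 50, so the overflow bit is 0.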
349 
350 LogicalResult
351 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352  SmallVectorImpl<OpFoldResult> &results) {
353  Type overflowTy = getOverflow().getType();
354  // addui_extended(x, 0) -> x, false
355  if (matchPattern(getRhs(), m_Zero())) {
356  Builder builder(getContext());
357  auto falseValue = builder.getZeroAttr(overflowTy);
358 
359  results.push_back(getLhs());
360  results.push_back(falseValue);
361  return success();
362  }
363 
364  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
365  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
366  // operands. If that succeeds, calculate the overflow bit based on the sum
367  // and the first (constant) operand, `lhs`.
368  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
369  adaptor.getOperands(),
370  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
371  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
372  ArrayRef({sumAttr, adaptor.getLhs()}),
373  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
374  calculateUnsignedOverflow);
375  if (!overflowAttr)
376  return failure();
377 
378  results.push_back(sumAttr);
379  results.push_back(overflowAttr);
380  return success();
381  }
382 
383  return failure();
384 }
385 
386 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
387  RewritePatternSet &patterns, MLIRContext *context) {
388  patterns.add<AddUIExtendedToAddI>(context);
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // SubIOp
393 //===----------------------------------------------------------------------===//
394 
395 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
396  // subi(x,x) -> 0
397  if (getOperand(0) == getOperand(1))
398  return Builder(getContext()).getZeroAttr(getType());
399  // subi(x,0) -> x
400  if (matchPattern(adaptor.getRhs(), m_Zero()))
401  return getLhs();
402 
403  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
404  // subi(addi(a, b), b) -> a
405  if (getRhs() == add.getRhs())
406  return add.getLhs();
407  // subi(addi(a, b), a) -> b
408  if (getRhs() == add.getLhs())
409  return add.getRhs();
410  }
411 
412  return constFoldBinaryOp<IntegerAttr>(
413  adaptor.getOperands(),
414  [](APInt a, const APInt &b) { return std::move(a) - b; });
415 }
416 
417 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
418  MLIRContext *context) {
419  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
420  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
421  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // MulIOp
426 //===----------------------------------------------------------------------===//
427 
428 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
429  // muli(x, 0) -> 0
430  if (matchPattern(adaptor.getRhs(), m_Zero()))
431  return getRhs();
432  // muli(x, 1) -> x
433  if (matchPattern(adaptor.getRhs(), m_One()))
434  return getLhs();
435  // TODO: Handle the overflow case.
436 
437  // default folder
438  return constFoldBinaryOp<IntegerAttr>(
439  adaptor.getOperands(),
440  [](const APInt &a, const APInt &b) { return a * b; });
441 }
442 
443 void arith::MulIOp::getAsmResultNames(
444  function_ref<void(Value, StringRef)> setNameFn) {
445  if (!isa<IndexType>(getType()))
446  return;
447 
448  // Match vector.vscale by name to avoid depending on the vector dialect (which
449  // is a circular dependency).
450  auto isVscale = [](Operation *op) {
451  return op && op->getName().getStringRef() == "vector.vscale";
452  };
453 
454  IntegerAttr baseValue;
455  auto isVscaleExpr = [&](Value a, Value b) {
456  return matchPattern(a, m_Constant(&baseValue)) &&
457  isVscale(b.getDefiningOp());
458  };
459 
460  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
461  return;
462 
463  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
464  SmallString<32> specialNameBuffer;
465  llvm::raw_svector_ostream specialName(specialNameBuffer);
466  specialName << 'c' << baseValue.getInt() << "_vscale";
467  setNameFn(getResult(), specialName.str());
468 }
469 
470 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
471  MLIRContext *context) {
472  patterns.add<MulIMulIConstant>(context);
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // MulSIExtendedOp
477 //===----------------------------------------------------------------------===//
478 
479 std::optional<SmallVector<int64_t, 4>>
480 arith::MulSIExtendedOp::getShapeForUnroll() {
481  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
482  return llvm::to_vector<4>(vt.getShape());
483  return std::nullopt;
484 }
485 
486 LogicalResult
487 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
488  SmallVectorImpl<OpFoldResult> &results) {
489  // mulsi_extended(x, 0) -> 0, 0
490  if (matchPattern(adaptor.getRhs(), m_Zero())) {
491  Attribute zero = adaptor.getRhs();
492  results.push_back(zero);
493  results.push_back(zero);
494  return success();
495  }
496 
497  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
498  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
499  adaptor.getOperands(),
500  [](const APInt &a, const APInt &b) { return a * b; })) {
501  // Invoke the constant fold helper again to calculate the 'high' result.
502  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
503  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
504  return llvm::APIntOps::mulhs(a, b);
505  });
506  assert(highAttr && "Unexpected constant-folding failure");
507 
508  results.push_back(lowAttr);
509  results.push_back(highAttr);
510  return success();
511  }
512 
513  return failure();
514 }
515 
516 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
517  RewritePatternSet &patterns, MLIRContext *context) {
518  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // MulUIExtendedOp
523 //===----------------------------------------------------------------------===//
524 
525 std::optional<SmallVector<int64_t, 4>>
526 arith::MulUIExtendedOp::getShapeForUnroll() {
527  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
528  return llvm::to_vector<4>(vt.getShape());
529  return std::nullopt;
530 }
531 
532 LogicalResult
533 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
534  SmallVectorImpl<OpFoldResult> &results) {
535  // mului_extended(x, 0) -> 0, 0
536  if (matchPattern(adaptor.getRhs(), m_Zero())) {
537  Attribute zero = adaptor.getRhs();
538  results.push_back(zero);
539  results.push_back(zero);
540  return success();
541  }
542 
543  // mului_extended(x, 1) -> x, 0
544  if (matchPattern(adaptor.getRhs(), m_One())) {
545  Builder builder(getContext());
546  Attribute zero = builder.getZeroAttr(getLhs().getType());
547  results.push_back(getLhs());
548  results.push_back(zero);
549  return success();
550  }
551 
552  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
553  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
554  adaptor.getOperands(),
555  [](const APInt &a, const APInt &b) { return a * b; })) {
556  // Invoke the constant fold helper again to calculate the 'high' result.
557  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
558  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
559  return llvm::APIntOps::mulhu(a, b);
560  });
561  assert(highAttr && "Unexpected constant-folding failure");
562 
563  results.push_back(lowAttr);
564  results.push_back(highAttr);
565  return success();
566  }
567 
568  return failure();
569 }
570 
571 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
572  RewritePatternSet &patterns, MLIRContext *context) {
573  patterns.add<MulUIExtendedToMulI>(context);
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // DivUIOp
578 //===----------------------------------------------------------------------===//
579 
580 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
581  // divui (x, 1) -> x.
582  if (matchPattern(adaptor.getRhs(), m_One()))
583  return getLhs();
584 
585  // Don't fold if it would require a division by zero.
586  bool div0 = false;
587  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
588  [&](APInt a, const APInt &b) {
589  if (div0 || !b) {
590  div0 = true;
591  return a;
592  }
593  return a.udiv(b);
594  });
595 
596  return div0 ? Attribute() : result;
597 }
598 
599 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
600  // X / 0 => UB
601  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
602  : Speculation::NotSpeculatable;
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // DivSIOp
607 //===----------------------------------------------------------------------===//
608 
609 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
610  // divsi (x, 1) -> x.
611  if (matchPattern(adaptor.getRhs(), m_One()))
612  return getLhs();
613 
614  // Don't fold if it would overflow or if it requires a division by zero.
615  bool overflowOrDiv0 = false;
616  auto result = constFoldBinaryOp<IntegerAttr>(
617  adaptor.getOperands(), [&](APInt a, const APInt &b) {
618  if (overflowOrDiv0 || !b) {
619  overflowOrDiv0 = true;
620  return a;
621  }
622  return a.sdiv_ov(b, overflowOrDiv0);
623  });
624 
625  return overflowOrDiv0 ? Attribute() : result;
626 }
627 
628 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
629  bool mayHaveUB = true;
630 
631  APInt constRHS;
632  // X / 0 => UB
633  // INT_MIN / -1 => UB
634  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
635  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
636 
637  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // Ceil and floor division folding helpers
642 //===----------------------------------------------------------------------===//
643 
644 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
645  bool &overflow) {
646  // Returns (a-1)/b + 1
647  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
648  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
649  return val.sadd_ov(one, overflow);
650 }
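// Editorial worked example: for a = 7, b = 3 this computes (7-1)/3 + 1 = 3,
// i.e. ceil(7/3); for a = 6, b = 3 it computes (6-1)/3 + 1 = 2. The callers
// below only pass strictly positive a and b, so (a-1) never goes negative.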
651 
652 //===----------------------------------------------------------------------===//
653 // CeilDivUIOp
654 //===----------------------------------------------------------------------===//
655 
656 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
657  // ceildivui (x, 1) -> x.
658  if (matchPattern(adaptor.getRhs(), m_One()))
659  return getLhs();
660 
661  bool overflowOrDiv0 = false;
662  auto result = constFoldBinaryOp<IntegerAttr>(
663  adaptor.getOperands(), [&](APInt a, const APInt &b) {
664  if (overflowOrDiv0 || !b) {
665  overflowOrDiv0 = true;
666  return a;
667  }
668  APInt quotient = a.udiv(b);
669  if (!a.urem(b))
670  return quotient;
671  APInt one(a.getBitWidth(), 1, true);
672  return quotient.uadd_ov(one, overflowOrDiv0);
673  });
674 
675  return overflowOrDiv0 ? Attribute() : result;
676 }
677 
678 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
679  // X / 0 => UB
680  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
681  : Speculation::NotSpeculatable;
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // CeilDivSIOp
686 //===----------------------------------------------------------------------===//
687 
688 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
689  // ceildivsi (x, 1) -> x.
690  if (matchPattern(adaptor.getRhs(), m_One()))
691  return getLhs();
692 
693  // Don't fold if it would overflow or if it requires a division by zero.
694  // TODO: This hook won't fold operations where a = MININT, because
695  // negating MININT overflows. This can be improved.
696  bool overflowOrDiv0 = false;
697  auto result = constFoldBinaryOp<IntegerAttr>(
698  adaptor.getOperands(), [&](APInt a, const APInt &b) {
699  if (overflowOrDiv0 || !b) {
700  overflowOrDiv0 = true;
701  return a;
702  }
703  if (!a)
704  return a;
705  // After this point we know that neither a nor b is zero.
706  unsigned bits = a.getBitWidth();
707  APInt zero = APInt::getZero(bits);
708  bool aGtZero = a.sgt(zero);
709  bool bGtZero = b.sgt(zero);
710  if (aGtZero && bGtZero) {
711  // Both positive, return ceil(a, b).
712  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
713  }
714 
715  // No folding happens if any of the intermediate arithmetic operations
716  // overflows.
717  bool overflowNegA = false;
718  bool overflowNegB = false;
719  bool overflowDiv = false;
720  bool overflowNegRes = false;
721  if (!aGtZero && !bGtZero) {
722  // Both negative, return ceil(-a, -b).
723  APInt posA = zero.ssub_ov(a, overflowNegA);
724  APInt posB = zero.ssub_ov(b, overflowNegB);
725  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
726  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
727  return res;
728  }
729  if (!aGtZero && bGtZero) {
730  // A is negative, b is positive, return - ( -a / b).
731  APInt posA = zero.ssub_ov(a, overflowNegA);
732  APInt div = posA.sdiv_ov(b, overflowDiv);
733  APInt res = zero.ssub_ov(div, overflowNegRes);
734  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
735  return res;
736  }
737  // A is positive, b is negative, return - (a / -b).
738  APInt posB = zero.ssub_ov(b, overflowNegB);
739  APInt div = a.sdiv_ov(posB, overflowDiv);
740  APInt res = zero.ssub_ov(div, overflowNegRes);
741 
742  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
743  return res;
744  });
745 
746  return overflowOrDiv0 ? Attribute() : result;
747 }
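// Editorial illustration of the sign handling above: ceildivsi(-7, 2) is
// computed as -((-a) / b) = -(7 / 2) = -3, which matches ceil(-3.5), and
// ceildivsi(-7, -2) as ceil(7 / 2) = 4.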
748 
749 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
750  bool mayHaveUB = true;
751 
752  APInt constRHS;
753  // X / 0 => UB
754  // INT_MIN / -1 => UB
755  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
756  mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
757 
758  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // FloorDivSIOp
763 //===----------------------------------------------------------------------===//
764 
765 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
766  // floordivsi (x, 1) -> x.
767  if (matchPattern(adaptor.getRhs(), m_One()))
768  return getLhs();
769 
770  // Don't fold if it would overflow or if it requires a division by zero.
771  bool overflowOrDiv = false;
772  auto result = constFoldBinaryOp<IntegerAttr>(
773  adaptor.getOperands(), [&](APInt a, const APInt &b) {
774  if (b.isZero()) {
775  overflowOrDiv = true;
776  return a;
777  }
778  return a.sfloordiv_ov(b, overflowOrDiv);
779  });
780 
781  return overflowOrDiv ? Attribute() : result;
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // RemUIOp
786 //===----------------------------------------------------------------------===//
787 
788 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
789  // remui (x, 1) -> 0.
790  if (matchPattern(adaptor.getRhs(), m_One()))
791  return Builder(getContext()).getZeroAttr(getType());
792 
793  // Don't fold if it would require a division by zero.
794  bool div0 = false;
795  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
796  [&](APInt a, const APInt &b) {
797  if (div0 || b.isZero()) {
798  div0 = true;
799  return a;
800  }
801  return a.urem(b);
802  });
803 
804  return div0 ? Attribute() : result;
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // RemSIOp
809 //===----------------------------------------------------------------------===//
810 
811 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
812  // remsi (x, 1) -> 0.
813  if (matchPattern(adaptor.getRhs(), m_One()))
814  return Builder(getContext()).getZeroAttr(getType());
815 
816  // Don't fold if it would require a division by zero.
817  bool div0 = false;
818  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
819  [&](APInt a, const APInt &b) {
820  if (div0 || b.isZero()) {
821  div0 = true;
822  return a;
823  }
824  return a.srem(b);
825  });
826 
827  return div0 ? Attribute() : result;
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // AndIOp
832 //===----------------------------------------------------------------------===//
833 
834 /// Fold `and(a, and(a, b))` to `and(a, b)`
835 static Value foldAndIofAndI(arith::AndIOp op) {
836  for (bool reversePrev : {false, true}) {
837  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
838  .getDefiningOp<arith::AndIOp>();
839  if (!prev)
840  continue;
841 
842  Value other = (reversePrev ? op.getLhs() : op.getRhs());
843  if (other != prev.getLhs() && other != prev.getRhs())
844  continue;
845 
846  return prev.getResult();
847  }
848  return {};
849 }
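// Editorial illustration: the loop above covers both operand orders, so
//   and(%a, and(%a, %b)), and(%a, and(%b, %a)),
//   and(and(%a, %b), %a), and(and(%b, %a), %a)
// all fold to the inner and-result.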
850 
851 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
852  /// and(x, 0) -> 0
853  if (matchPattern(adaptor.getRhs(), m_Zero()))
854  return getRhs();
855  /// and(x, allOnes) -> x
856  APInt intValue;
857  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
858  intValue.isAllOnes())
859  return getLhs();
860  /// and(x, not(x)) -> 0
861  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
862  m_ConstantInt(&intValue))) &&
863  intValue.isAllOnes())
864  return Builder(getContext()).getZeroAttr(getType());
865  /// and(not(x), x) -> 0
866  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
867  m_ConstantInt(&intValue))) &&
868  intValue.isAllOnes())
869  return Builder(getContext()).getZeroAttr(getType());
870 
871  /// and(a, and(a, b)) -> and(a, b)
872  if (Value result = foldAndIofAndI(*this))
873  return result;
874 
875  return constFoldBinaryOp<IntegerAttr>(
876  adaptor.getOperands(),
877  [](APInt a, const APInt &b) { return std::move(a) & b; });
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // OrIOp
882 //===----------------------------------------------------------------------===//
883 
884 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
885  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
886  /// or(x, 0) -> x
887  if (rhsVal.isZero())
888  return getLhs();
889  /// or(x, <all ones>) -> <all ones>
890  if (rhsVal.isAllOnes())
891  return adaptor.getRhs();
892  }
893 
894  APInt intValue;
895  /// or(x, xor(x, 1)) -> 1
896  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
897  m_ConstantInt(&intValue))) &&
898  intValue.isAllOnes())
899  return getRhs().getDefiningOp<XOrIOp>().getRhs();
900  /// or(xor(x, 1), x) -> 1
901  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
902  m_ConstantInt(&intValue))) &&
903  intValue.isAllOnes())
904  return getLhs().getDefiningOp<XOrIOp>().getRhs();
905 
906  return constFoldBinaryOp<IntegerAttr>(
907  adaptor.getOperands(),
908  [](APInt a, const APInt &b) { return std::move(a) | b; });
909 }
910 
911 //===----------------------------------------------------------------------===//
912 // XOrIOp
913 //===----------------------------------------------------------------------===//
914 
915 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
916  /// xor(x, 0) -> x
917  if (matchPattern(adaptor.getRhs(), m_Zero()))
918  return getLhs();
919  /// xor(x, x) -> 0
920  if (getLhs() == getRhs())
921  return Builder(getContext()).getZeroAttr(getType());
922  /// xor(xor(x, a), a) -> x
923  /// xor(xor(a, x), a) -> x
924  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
925  if (prev.getRhs() == getRhs())
926  return prev.getLhs();
927  if (prev.getLhs() == getRhs())
928  return prev.getRhs();
929  }
930  /// xor(a, xor(x, a)) -> x
931  /// xor(a, xor(a, x)) -> x
932  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
933  if (prev.getRhs() == getLhs())
934  return prev.getLhs();
935  if (prev.getLhs() == getLhs())
936  return prev.getRhs();
937  }
938 
939  return constFoldBinaryOp<IntegerAttr>(
940  adaptor.getOperands(),
941  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
942 }
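// Editorial worked example for the xor-of-xor folds: with x = 0b1010 and
// a = 0b0110, xor(x, a) = 0b1100 and xor(0b1100, a) = 0b1010 = x, so the
// outer op folds directly to x.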
943 
944 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
945  MLIRContext *context) {
946  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
947 }
948 
949 //===----------------------------------------------------------------------===//
950 // NegFOp
951 //===----------------------------------------------------------------------===//
952 
953 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
954  /// negf(negf(x)) -> x
955  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
956  return op.getOperand();
957  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
958  [](const APFloat &a) { return -a; });
959 }
960 
961 //===----------------------------------------------------------------------===//
962 // AddFOp
963 //===----------------------------------------------------------------------===//
964 
965 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
966  // addf(x, -0) -> x
967  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
968  return getLhs();
969 
970  return constFoldBinaryOp<FloatAttr>(
971  adaptor.getOperands(),
972  [](const APFloat &a, const APFloat &b) { return a + b; });
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // SubFOp
977 //===----------------------------------------------------------------------===//
978 
979 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
980  // subf(x, +0) -> x
981  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
982  return getLhs();
983 
984  return constFoldBinaryOp<FloatAttr>(
985  adaptor.getOperands(),
986  [](const APFloat &a, const APFloat &b) { return a - b; });
987 }
988 
989 //===----------------------------------------------------------------------===//
990 // MaximumFOp
991 //===----------------------------------------------------------------------===//
992 
993 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
994  // maximumf(x,x) -> x
995  if (getLhs() == getRhs())
996  return getRhs();
997 
998  // maximumf(x, -inf) -> x
999  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1000  return getLhs();
1001 
1002  return constFoldBinaryOp<FloatAttr>(
1003  adaptor.getOperands(),
1004  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1005 }
1006 
1007 //===----------------------------------------------------------------------===//
1008 // MaxNumFOp
1009 //===----------------------------------------------------------------------===//
1010 
1011 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1012  // maxnumf(x,x) -> x
1013  if (getLhs() == getRhs())
1014  return getRhs();
1015 
1016  // maxnumf(x, -inf) -> x
1017  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1018  return getLhs();
1019 
1020  return constFoldBinaryOp<FloatAttr>(
1021  adaptor.getOperands(),
1022  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // MaxSIOp
1027 //===----------------------------------------------------------------------===//
1028 
1029 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1030  // maxsi(x,x) -> x
1031  if (getLhs() == getRhs())
1032  return getRhs();
1033 
1034  if (APInt intValue;
1035  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1036  // maxsi(x,MAX_INT) -> MAX_INT
1037  if (intValue.isMaxSignedValue())
1038  return getRhs();
1039  // maxsi(x, MIN_INT) -> x
1040  if (intValue.isMinSignedValue())
1041  return getLhs();
1042  }
1043 
1044  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1045  [](const APInt &a, const APInt &b) {
1046  return llvm::APIntOps::smax(a, b);
1047  });
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // MaxUIOp
1052 //===----------------------------------------------------------------------===//
1053 
1054 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1055  // maxui(x,x) -> x
1056  if (getLhs() == getRhs())
1057  return getRhs();
1058 
1059  if (APInt intValue;
1060  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1061  // maxui(x,MAX_INT) -> MAX_INT
1062  if (intValue.isMaxValue())
1063  return getRhs();
1064  // maxui(x, MIN_INT) -> x
1065  if (intValue.isMinValue())
1066  return getLhs();
1067  }
1068 
1069  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1070  [](const APInt &a, const APInt &b) {
1071  return llvm::APIntOps::umax(a, b);
1072  });
1073 }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // MinimumFOp
1077 //===----------------------------------------------------------------------===//
1078 
1079 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1080  // minimumf(x,x) -> x
1081  if (getLhs() == getRhs())
1082  return getRhs();
1083 
1084  // minimumf(x, +inf) -> x
1085  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1086  return getLhs();
1087 
1088  return constFoldBinaryOp<FloatAttr>(
1089  adaptor.getOperands(),
1090  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1091 }
1092 
1093 //===----------------------------------------------------------------------===//
1094 // MinNumFOp
1095 //===----------------------------------------------------------------------===//
1096 
1097 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1098  // minnumf(x,x) -> x
1099  if (getLhs() == getRhs())
1100  return getRhs();
1101 
1102  // minnumf(x, +inf) -> x
1103  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1104  return getLhs();
1105 
1106  return constFoldBinaryOp<FloatAttr>(
1107  adaptor.getOperands(),
1108  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1109 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // MinSIOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1116  // minsi(x,x) -> x
1117  if (getLhs() == getRhs())
1118  return getRhs();
1119 
1120  if (APInt intValue;
1121  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1122  // minsi(x,MIN_INT) -> MIN_INT
1123  if (intValue.isMinSignedValue())
1124  return getRhs();
1125  // minsi(x, MAX_INT) -> x
1126  if (intValue.isMaxSignedValue())
1127  return getLhs();
1128  }
1129 
1130  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1131  [](const APInt &a, const APInt &b) {
1132  return llvm::APIntOps::smin(a, b);
1133  });
1134 }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // MinUIOp
1138 //===----------------------------------------------------------------------===//
1139 
1140 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1141  // minui(x,x) -> x
1142  if (getLhs() == getRhs())
1143  return getRhs();
1144 
1145  if (APInt intValue;
1146  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1147  // minui(x,MIN_INT) -> MIN_INT
1148  if (intValue.isMinValue())
1149  return getRhs();
1150  // minui(x, MAX_INT) -> x
1151  if (intValue.isMaxValue())
1152  return getLhs();
1153  }
1154 
1155  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1156  [](const APInt &a, const APInt &b) {
1157  return llvm::APIntOps::umin(a, b);
1158  });
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // MulFOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1166  // mulf(x, 1) -> x
1167  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1168  return getLhs();
1169 
1170  return constFoldBinaryOp<FloatAttr>(
1171  adaptor.getOperands(),
1172  [](const APFloat &a, const APFloat &b) { return a * b; });
1173 }
1174 
1175 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1176  MLIRContext *context) {
1177  patterns.add<MulFOfNegF>(context);
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // DivFOp
1182 //===----------------------------------------------------------------------===//
1183 
1184 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1185  // divf(x, 1) -> x
1186  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1187  return getLhs();
1188 
1189  return constFoldBinaryOp<FloatAttr>(
1190  adaptor.getOperands(),
1191  [](const APFloat &a, const APFloat &b) { return a / b; });
1192 }
1193 
1194 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1195  MLIRContext *context) {
1196  patterns.add<DivFOfNegF>(context);
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // RemFOp
1201 //===----------------------------------------------------------------------===//
1202 
1203 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1204  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1205  [](const APFloat &a, const APFloat &b) {
1206  APFloat result(a);
1207  (void)result.remainder(b);
1208  return result;
1209  });
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // Utility functions for verifying cast ops
1214 //===----------------------------------------------------------------------===//
1215 
1216 template <typename... Types>
1217 using type_list = std::tuple<Types...> *;
1218 
1219 /// Returns a non-null type only if the provided type is one of the allowed
1220 /// types or one of the allowed shaped types of the allowed types. Returns the
1221 /// element type if a valid shaped type is provided.
1222 template <typename... ShapedTypes, typename... ElementTypes>
1223 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1224  type_list<ElementTypes...>) {
1225  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1226  return {};
1227 
1228  auto underlyingType = getElementTypeOrSelf(type);
1229  if (!llvm::isa<ElementTypes...>(underlyingType))
1230  return {};
1231 
1232  return underlyingType;
1233 }
1234 
1235 /// Get allowed underlying types for vectors and tensors.
1236 template <typename... ElementTypes>
1237 static Type getTypeIfLike(Type type) {
1238  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1239  type_list<ElementTypes...>());
1240 }
1241 
1242 /// Get allowed underlying types for vectors, tensors, and memrefs.
1243 template <typename... ElementTypes>
1244 static Type getTypeIfLikeOrMemRef(Type type) {
1245  return getUnderlyingType(type,
1246  type_list<VectorType, TensorType, MemRefType>(),
1247  type_list<ElementTypes...>());
1248 }
1249 
1250 /// Return false if both types are ranked tensor with mismatching encoding.
1251 static bool hasSameEncoding(Type typeA, Type typeB) {
1252  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1253  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1254  if (!rankedTensorA || !rankedTensorB)
1255  return true;
1256  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1257 }
1258 
1259 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1260  if (inputs.size() != 1 || outputs.size() != 1)
1261  return false;
1262  if (!hasSameEncoding(inputs.front(), outputs.front()))
1263  return false;
1264  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1265 }
1266 
1267 //===----------------------------------------------------------------------===//
1268 // Verifiers for integer and floating point extension/truncation ops
1269 //===----------------------------------------------------------------------===//
1270 
1271 // Extend ops can only extend to a wider type.
1272 template <typename ValType, typename Op>
1273 static LogicalResult verifyExtOp(Op op) {
1274  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1275  Type dstType = getElementTypeOrSelf(op.getType());
1276 
1277  if (llvm::cast<ValType>(srcType).getWidth() >=
1278  llvm::cast<ValType>(dstType).getWidth())
1279  return op.emitError("result type ")
1280  << dstType << " must be wider than operand type " << srcType;
1281 
1282  return success();
1283 }
1284 
1285 // Truncate ops can only truncate to a shorter type.
1286 template <typename ValType, typename Op>
1287 static LogicalResult verifyTruncateOp(Op op) {
1288  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1289  Type dstType = getElementTypeOrSelf(op.getType());
1290 
1291  if (llvm::cast<ValType>(srcType).getWidth() <=
1292  llvm::cast<ValType>(dstType).getWidth())
1293  return op.emitError("result type ")
1294  << dstType << " must be shorter than operand type " << srcType;
1295 
1296  return success();
1297 }
1298 
1299 /// Validate a cast that changes the width of a type.
1300 template <template <typename> class WidthComparator, typename... ElementTypes>
1301 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1302  if (!areValidCastInputsAndOutputs(inputs, outputs))
1303  return false;
1304 
1305  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1306  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1307  if (!srcType || !dstType)
1308  return false;
1309 
1310  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1311  srcType.getIntOrFloatBitWidth());
1312 }
1313 
1314 /// Attempts to convert `sourceValue` to an APFloat value with
1315 /// `targetSemantics` and `roundingMode`, without any information loss.
1316 static FailureOr<APFloat> convertFloatValue(
1317  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1318  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1319  bool losesInfo = false;
1320  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1321  if (losesInfo || status != APFloat::opOK)
1322  return failure();
1323 
1324  return sourceValue;
1325 }
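// Editorial illustration: converting the f32 value 0.5 to f16 is exact, so
// the conversion succeeds and the enclosing cast can be folded; converting
// 0.1 to f16 loses information (losesInfo is set), so this returns failure()
// and the cast op is left in place.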
1326 
1327 //===----------------------------------------------------------------------===//
1328 // ExtUIOp
1329 //===----------------------------------------------------------------------===//
1330 
1331 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1332  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1333  getInMutable().assign(lhs.getIn());
1334  return getResult();
1335  }
1336 
1337  Type resType = getElementTypeOrSelf(getType());
1338  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1339  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1340  adaptor.getOperands(), getType(),
1341  [bitWidth](const APInt &a, bool &castStatus) {
1342  return a.zext(bitWidth);
1343  });
1344 }
1345 
1346 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1347  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1348 }
1349 
1350 LogicalResult arith::ExtUIOp::verify() {
1351  return verifyExtOp<IntegerType>(*this);
1352 }
1353 
1354 //===----------------------------------------------------------------------===//
1355 // ExtSIOp
1356 //===----------------------------------------------------------------------===//
1357 
1358 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1359  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1360  getInMutable().assign(lhs.getIn());
1361  return getResult();
1362  }
1363 
1364  Type resType = getElementTypeOrSelf(getType());
1365  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1366  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1367  adaptor.getOperands(), getType(),
1368  [bitWidth](const APInt &a, bool &castStatus) {
1369  return a.sext(bitWidth);
1370  });
1371 }
1372 
1373 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1374  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1375 }
1376 
1377 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1378  MLIRContext *context) {
1379  patterns.add<ExtSIOfExtUI>(context);
1380 }
1381 
1382 LogicalResult arith::ExtSIOp::verify() {
1383  return verifyExtOp<IntegerType>(*this);
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // ExtFOp
1388 //===----------------------------------------------------------------------===//
1389 
1390 /// Fold extension of float constants when there is no information loss due to
1391 /// the difference in fp semantics.
1392 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1393  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1394  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1395  return constFoldCastOp<FloatAttr, FloatAttr>(
1396  adaptor.getOperands(), getType(),
1397  [&targetSemantics](const APFloat &a, bool &castStatus) {
1398  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1399  if (failed(result)) {
1400  castStatus = false;
1401  return a;
1402  }
1403  return *result;
1404  });
1405 }
1406 
1407 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1408  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1409 }
1410 
1411 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1412 
1413 //===----------------------------------------------------------------------===//
1414 // TruncIOp
1415 //===----------------------------------------------------------------------===//
1416 
1417 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1418  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1419  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1420  Value src = getOperand().getDefiningOp()->getOperand(0);
1421  Type srcType = getElementTypeOrSelf(src.getType());
1422  Type dstType = getElementTypeOrSelf(getType());
1423  // trunci(zexti(a)) -> trunci(a)
1424  // trunci(sexti(a)) -> trunci(a)
1425  if (llvm::cast<IntegerType>(srcType).getWidth() >
1426  llvm::cast<IntegerType>(dstType).getWidth()) {
1427  setOperand(src);
1428  return getResult();
1429  }
1430 
1431  // trunci(zexti(a)) -> a
1432  // trunci(sexti(a)) -> a
1433  if (srcType == dstType)
1434  return src;
1435  }
1436 
1437  // trunci(trunci(a)) -> trunci(a)
1438  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1439  setOperand(getOperand().getDefiningOp()->getOperand(0));
1440  return getResult();
1441  }
1442 
1443  Type resType = getElementTypeOrSelf(getType());
1444  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1445  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1446  adaptor.getOperands(), getType(),
1447  [bitWidth](const APInt &a, bool &castStatus) {
1448  return a.trunc(bitWidth);
1449  });
1450 }
1451 
1452 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1453  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1454 }
1455 
1456 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1457  MLIRContext *context) {
1458  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1459  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1460  context);
1461 }
1462 
1463 LogicalResult arith::TruncIOp::verify() {
1464  return verifyTruncateOp<IntegerType>(*this);
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // TruncFOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1472 /// can be represented without precision loss.
1473 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1474  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1475  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1476  return constFoldCastOp<FloatAttr, FloatAttr>(
1477  adaptor.getOperands(), getType(),
1478  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1479  RoundingMode roundingMode =
1480  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1481  llvm::RoundingMode llvmRoundingMode =
1482  convertArithRoundingModeToLLVMIR(roundingMode);
1483  FailureOr<APFloat> result =
1484  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1485  if (failed(result)) {
1486  castStatus = false;
1487  return a;
1488  }
1489  return *result;
1490  });
1491 }
1492 
1493 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1494  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1495 }
1496 
1497 LogicalResult arith::TruncFOp::verify() {
1498  return verifyTruncateOp<FloatType>(*this);
1499 }
1500 
1501 //===----------------------------------------------------------------------===//
1502 // AndIOp
1503 //===----------------------------------------------------------------------===//
1504 
1505 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1506  MLIRContext *context) {
1507  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1508 }
1509 
1510 //===----------------------------------------------------------------------===//
1511 // OrIOp
1512 //===----------------------------------------------------------------------===//
1513 
1514 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1515  MLIRContext *context) {
1516  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // Verifiers for casts between integers and floats.
1521 //===----------------------------------------------------------------------===//
1522 
1523 template <typename From, typename To>
1524 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1525  if (!areValidCastInputsAndOutputs(inputs, outputs))
1526  return false;
1527 
1528  auto srcType = getTypeIfLike<From>(inputs.front());
1529  auto dstType = getTypeIfLike<To>(outputs.back());
1530 
1531  return srcType && dstType;
1532 }
1533 
1534 //===----------------------------------------------------------------------===//
1535 // UIToFPOp
1536 //===----------------------------------------------------------------------===//
1537 
1538 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1539  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1540 }
1541 
1542 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1543  Type resEleType = getElementTypeOrSelf(getType());
1544  return constFoldCastOp<IntegerAttr, FloatAttr>(
1545  adaptor.getOperands(), getType(),
1546  [&resEleType](const APInt &a, bool &castStatus) {
1547  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1548  APFloat apf(floatTy.getFloatSemantics(),
1549  APInt::getZero(floatTy.getWidth()));
1550  apf.convertFromAPInt(a, /*IsSigned=*/false,
1551  APFloat::rmNearestTiesToEven);
1552  return apf;
1553  });
1554 }
1555 
1556 //===----------------------------------------------------------------------===//
1557 // SIToFPOp
1558 //===----------------------------------------------------------------------===//
1559 
1560 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1561  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1562 }
1563 
1564 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1565  Type resEleType = getElementTypeOrSelf(getType());
1566  return constFoldCastOp<IntegerAttr, FloatAttr>(
1567  adaptor.getOperands(), getType(),
1568  [&resEleType](const APInt &a, bool &castStatus) {
1569  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1570  APFloat apf(floatTy.getFloatSemantics(),
1571  APInt::getZero(floatTy.getWidth()));
1572  apf.convertFromAPInt(a, /*IsSigned=*/true,
1573  APFloat::rmNearestTiesToEven);
1574  return apf;
1575  });
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // FPToUIOp
1580 //===----------------------------------------------------------------------===//
1581 
1582 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1583  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1584 }
1585 
1586 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1587  Type resType = getElementTypeOrSelf(getType());
1588  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1589  return constFoldCastOp<FloatAttr, IntegerAttr>(
1590  adaptor.getOperands(), getType(),
1591  [&bitWidth](const APFloat &a, bool &castStatus) {
1592  bool ignored;
1593  APSInt api(bitWidth, /*isUnsigned=*/true);
1594  castStatus = APFloat::opInvalidOp !=
1595  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1596  return api;
1597  });
1598 }
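// Illustrative fold: arith.fptoui of a constant 3.9 : f32 to i8 folds to
// 3 : i8 (rounding toward zero); a negative or otherwise out-of-range
// constant sets castStatus to false, so the op is left unfolded.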
1599 
1600 //===----------------------------------------------------------------------===//
1601 // FPToSIOp
1602 //===----------------------------------------------------------------------===//
1603 
1604 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1605  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1606 }
1607 
1608 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1609  Type resType = getElementTypeOrSelf(getType());
1610  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1611  return constFoldCastOp<FloatAttr, IntegerAttr>(
1612  adaptor.getOperands(), getType(),
1613  [&bitWidth](const APFloat &a, bool &castStatus) {
1614  bool ignored;
1615  APSInt api(bitWidth, /*isUnsigned=*/false);
1616  castStatus = APFloat::opInvalidOp !=
1617  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1618  return api;
1619  });
1620 }
1621 
1622 //===----------------------------------------------------------------------===//
1623 // IndexCastOp
1624 //===----------------------------------------------------------------------===//
1625 
1626 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1627  if (!areValidCastInputsAndOutputs(inputs, outputs))
1628  return false;
1629 
1630  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1631  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1632  if (!srcType || !dstType)
1633  return false;
1634 
1635  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1636  (srcType.isSignlessInteger() && dstType.isIndex());
1637 }
1638 
1639 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1640  TypeRange outputs) {
1641  return areIndexCastCompatible(inputs, outputs);
1642 }
1643 
1644 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1645  // index_cast(constant) -> constant
1646  unsigned resultBitwidth = 64; // Default for index integer attributes.
1647  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1648  resultBitwidth = intTy.getWidth();
1649 
1650  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1651  adaptor.getOperands(), getType(),
1652  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1653  return a.sextOrTrunc(resultBitwidth);
1654  });
1655 }
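// Illustrative fold (hypothetical SSA names):
//   %c = arith.constant 512 : index
//   %i = arith.index_cast %c : index to i16
// folds %i to a constant 512 : i16; values that do not fit the destination
// width are sign-extended or truncated by sextOrTrunc.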
1656 
1657 void arith::IndexCastOp::getCanonicalizationPatterns(
1658  RewritePatternSet &patterns, MLIRContext *context) {
1659  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1660 }
1661 
1662 //===----------------------------------------------------------------------===//
1663 // IndexCastUIOp
1664 //===----------------------------------------------------------------------===//
1665 
1666 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1667  TypeRange outputs) {
1668  return areIndexCastCompatible(inputs, outputs);
1669 }
1670 
1671 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1672  // index_castui(constant) -> constant
1673  unsigned resultBitwidth = 64; // Default for index integer attributes.
1674  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1675  resultBitwidth = intTy.getWidth();
1676 
1677  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1678  adaptor.getOperands(), getType(),
1679  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1680  return a.zextOrTrunc(resultBitwidth);
1681  });
1682 }
1683 
1684 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1685  RewritePatternSet &patterns, MLIRContext *context) {
1686  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1687 }
1688 
1689 //===----------------------------------------------------------------------===//
1690 // BitcastOp
1691 //===----------------------------------------------------------------------===//
1692 
1693 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1694  if (!areValidCastInputsAndOutputs(inputs, outputs))
1695  return false;
1696 
1697  auto srcType =
1698  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1699  auto dstType =
1700  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1701  if (!srcType || !dstType)
1702  return false;
1703 
1704  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1705 }
1706 
1707 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1708  auto resType = getType();
1709  auto operand = adaptor.getIn();
1710  if (!operand)
1711  return {};
1712 
1713  /// Bitcast dense elements.
1714  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1715  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1716  /// Other shaped types unhandled.
1717  if (llvm::isa<ShapedType>(resType))
1718  return {};
1719 
1720  /// Bitcast integer or float to integer or float.
1721  APInt bits = llvm::isa<FloatAttr>(operand)
1722  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1723  : llvm::cast<IntegerAttr>(operand).getValue();
1724 
1725  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1726  return FloatAttr::get(resType,
1727  APFloat(resFloatType.getFloatSemantics(), bits));
1728  return IntegerAttr::get(resType, bits);
1729 }
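// Illustrative fold: bitcasting a constant reinterprets its bits without a
// value conversion, e.g.
//   %c = arith.constant 1.000000e+00 : f32
//   %i = arith.bitcast %c : f32 to i32
// folds %i to a constant 1065353216 : i32 (0x3F800000, the IEEE-754
// encoding of 1.0f).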
1730 
1731 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1732  MLIRContext *context) {
1733  patterns.add<BitcastOfBitcast>(context);
1734 }
1735 
1736 //===----------------------------------------------------------------------===//
1737 // CmpIOp
1738 //===----------------------------------------------------------------------===//
1739 
1740 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1741 /// comparison predicates.
1742 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1743  const APInt &lhs, const APInt &rhs) {
1744  switch (predicate) {
1745  case arith::CmpIPredicate::eq:
1746  return lhs.eq(rhs);
1747  case arith::CmpIPredicate::ne:
1748  return lhs.ne(rhs);
1749  case arith::CmpIPredicate::slt:
1750  return lhs.slt(rhs);
1751  case arith::CmpIPredicate::sle:
1752  return lhs.sle(rhs);
1753  case arith::CmpIPredicate::sgt:
1754  return lhs.sgt(rhs);
1755  case arith::CmpIPredicate::sge:
1756  return lhs.sge(rhs);
1757  case arith::CmpIPredicate::ult:
1758  return lhs.ult(rhs);
1759  case arith::CmpIPredicate::ule:
1760  return lhs.ule(rhs);
1761  case arith::CmpIPredicate::ugt:
1762  return lhs.ugt(rhs);
1763  case arith::CmpIPredicate::uge:
1764  return lhs.uge(rhs);
1765  }
1766  llvm_unreachable("unknown cmpi predicate kind");
1767 }
1768 
1769 /// Returns true if the predicate is true for two equal operands.
1770 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1771  switch (predicate) {
1772  case arith::CmpIPredicate::eq:
1773  case arith::CmpIPredicate::sle:
1774  case arith::CmpIPredicate::sge:
1775  case arith::CmpIPredicate::ule:
1776  case arith::CmpIPredicate::uge:
1777  return true;
1778  case arith::CmpIPredicate::ne:
1779  case arith::CmpIPredicate::slt:
1780  case arith::CmpIPredicate::sgt:
1781  case arith::CmpIPredicate::ult:
1782  case arith::CmpIPredicate::ugt:
1783  return false;
1784  }
1785  llvm_unreachable("unknown cmpi predicate kind");
1786 }
1787 
1788 static std::optional<int64_t> getIntegerWidth(Type t) {
1789  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1790  return intType.getWidth();
1791  }
1792  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1793  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1794  }
1795  return std::nullopt;
1796 }
1797 
1798 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1799  // cmpi(pred, x, x)
1800  if (getLhs() == getRhs()) {
1801  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1802  return getBoolAttribute(getType(), val);
1803  }
1804 
1805  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1806  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1807  // extsi(%x : i1 -> iN) != 0 -> %x
1808  std::optional<int64_t> integerWidth =
1809  getIntegerWidth(extOp.getOperand().getType());
1810  if (integerWidth && integerWidth.value() == 1 &&
1811  getPredicate() == arith::CmpIPredicate::ne)
1812  return extOp.getOperand();
1813  }
1814  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1815  // extui(%x : i1 -> iN) != 0 -> %x
1816  std::optional<int64_t> integerWidth =
1817  getIntegerWidth(extOp.getOperand().getType());
1818  if (integerWidth && integerWidth.value() == 1 &&
1819  getPredicate() == arith::CmpIPredicate::ne)
1820  return extOp.getOperand();
1821  }
1822  }
1823 
1824  // Move constant to the right side.
1825  if (adaptor.getLhs() && !adaptor.getRhs()) {
1826  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1827  using Pred = CmpIPredicate;
1828  const std::pair<Pred, Pred> invPreds[] = {
1829  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1830  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1831  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1832  {Pred::ne, Pred::ne},
1833  };
1834  Pred origPred = getPredicate();
1835  for (auto pred : invPreds) {
1836  if (origPred == pred.first) {
1837  setPredicate(pred.second);
1838  Value lhs = getLhs();
1839  Value rhs = getRhs();
1840  getLhsMutable().assign(rhs);
1841  getRhsMutable().assign(lhs);
1842  return getResult();
1843  }
1844  }
1845  llvm_unreachable("unknown cmpi predicate kind");
1846  }
1847 
1848  // We are moving constants to the right side, so if the lhs is constant,
1849  // the rhs is guaranteed to be a constant as well.
1850  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1851  return constFoldBinaryOp<IntegerAttr>(
1852  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1853  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1854  return APInt(1,
1855  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1856  });
1857  }
1858 
1859  return {};
1860 }
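// Illustrative folds (hypothetical SSA names):
//   arith.cmpi sle, %x, %x : i32                 -> constant true
//   %e = arith.extsi %b : i1 to i32
//   arith.cmpi ne, %e, %c0_i32 : i32             -> %b
//   arith.cmpi sgt, %c5_i32, %x : i32  is rewritten to
//   arith.cmpi slt, %x, %c5_i32 : i32, moving the constant to the right.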
1861 
1862 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1863  MLIRContext *context) {
1864  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1865 }
1866 
1867 //===----------------------------------------------------------------------===//
1868 // CmpFOp
1869 //===----------------------------------------------------------------------===//
1870 
1871 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1872 /// comparison predicates.
1873 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1874  const APFloat &lhs, const APFloat &rhs) {
1875  auto cmpResult = lhs.compare(rhs);
1876  switch (predicate) {
1877  case arith::CmpFPredicate::AlwaysFalse:
1878  return false;
1879  case arith::CmpFPredicate::OEQ:
1880  return cmpResult == APFloat::cmpEqual;
1881  case arith::CmpFPredicate::OGT:
1882  return cmpResult == APFloat::cmpGreaterThan;
1883  case arith::CmpFPredicate::OGE:
1884  return cmpResult == APFloat::cmpGreaterThan ||
1885  cmpResult == APFloat::cmpEqual;
1886  case arith::CmpFPredicate::OLT:
1887  return cmpResult == APFloat::cmpLessThan;
1888  case arith::CmpFPredicate::OLE:
1889  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1890  case arith::CmpFPredicate::ONE:
1891  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1892  case arith::CmpFPredicate::ORD:
1893  return cmpResult != APFloat::cmpUnordered;
1894  case arith::CmpFPredicate::UEQ:
1895  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1896  case arith::CmpFPredicate::UGT:
1897  return cmpResult == APFloat::cmpUnordered ||
1898  cmpResult == APFloat::cmpGreaterThan;
1899  case arith::CmpFPredicate::UGE:
1900  return cmpResult == APFloat::cmpUnordered ||
1901  cmpResult == APFloat::cmpGreaterThan ||
1902  cmpResult == APFloat::cmpEqual;
1903  case arith::CmpFPredicate::ULT:
1904  return cmpResult == APFloat::cmpUnordered ||
1905  cmpResult == APFloat::cmpLessThan;
1906  case arith::CmpFPredicate::ULE:
1907  return cmpResult == APFloat::cmpUnordered ||
1908  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1909  case arith::CmpFPredicate::UNE:
1910  return cmpResult != APFloat::cmpEqual;
1911  case arith::CmpFPredicate::UNO:
1912  return cmpResult == APFloat::cmpUnordered;
1913  case arith::CmpFPredicate::AlwaysTrue:
1914  return true;
1915  }
1916  llvm_unreachable("unknown cmpf predicate kind");
1917 }
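// Illustrative semantics: if either operand is NaN, compare() returns
// cmpUnordered, so every ordered predicate (OEQ, OLT, ..., ORD) evaluates
// to false and every unordered predicate (UEQ, ULT, ..., UNO) to true.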
1918 
1919 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1920  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1921  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1922 
1923  // If one operand is NaN, making them both NaN does not change the result.
1924  if (lhs && lhs.getValue().isNaN())
1925  rhs = lhs;
1926  if (rhs && rhs.getValue().isNaN())
1927  lhs = rhs;
1928 
1929  if (!lhs || !rhs)
1930  return {};
1931 
1932  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1933  return BoolAttr::get(getContext(), val);
1934 }
1935 
1936 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1937 public:
1938  using OpRewritePattern<CmpFOp>::OpRewritePattern;
1939 
1940  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1941  bool isUnsigned) {
1942  using namespace arith;
1943  switch (pred) {
1944  case CmpFPredicate::UEQ:
1945  case CmpFPredicate::OEQ:
1946  return CmpIPredicate::eq;
1947  case CmpFPredicate::UGT:
1948  case CmpFPredicate::OGT:
1949  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1950  case CmpFPredicate::UGE:
1951  case CmpFPredicate::OGE:
1952  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1953  case CmpFPredicate::ULT:
1954  case CmpFPredicate::OLT:
1955  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1956  case CmpFPredicate::ULE:
1957  case CmpFPredicate::OLE:
1958  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1959  case CmpFPredicate::UNE:
1960  case CmpFPredicate::ONE:
1961  return CmpIPredicate::ne;
1962  default:
1963  llvm_unreachable("Unexpected predicate!");
1964  }
1965  }
1966 
1967  LogicalResult matchAndRewrite(CmpFOp op,
1968  PatternRewriter &rewriter) const override {
1969  FloatAttr flt;
1970  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1971  return failure();
1972 
1973  const APFloat &rhs = flt.getValue();
1974 
1975  // Don't attempt to fold a nan.
1976  if (rhs.isNaN())
1977  return failure();
1978 
1979  // Get the width of the mantissa. We don't want to hack on conversions that
1980  // might lose information from the integer, e.g. "i64 -> float"
1981  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1982  int mantissaWidth = floatTy.getFPMantissaWidth();
1983  if (mantissaWidth <= 0)
1984  return failure();
1985 
1986  bool isUnsigned;
1987  Value intVal;
1988 
1989  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1990  isUnsigned = false;
1991  intVal = si.getIn();
1992  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1993  isUnsigned = true;
1994  intVal = ui.getIn();
1995  } else {
1996  return failure();
1997  }
1998 
1999  // Check that the input is converted from an integer type that is small
2000  // enough that all of its bits are preserved.
2001  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2002  auto intWidth = intTy.getWidth();
2003 
2004  // Number of bits representing values, as opposed to the sign
2005  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2006 
2007  // The following test does NOT adjust intWidth downwards for signed inputs,
2008  // because the most negative value still requires all the mantissa bits
2009  // to distinguish it from one less than that value.
2010  if ((int)intWidth > mantissaWidth) {
2011  // Conversion would lose accuracy. Check if loss can impact comparison.
2012  int exponent = ilogb(rhs);
2013  if (exponent == APFloat::IEK_Inf) {
2014  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2015  if (maxExponent < (int)valueBits) {
2016  // Conversion could create infinity.
2017  return failure();
2018  }
2019  } else {
2020  // Note that if rhs is zero or NaN, then the exponent is negative
2021  // and the first condition is trivially false.
2022  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2023  // Conversion could affect comparison.
2024  return failure();
2025  }
2026  }
2027  }
2028 
2029  // Convert to equivalent cmpi predicate
2030  CmpIPredicate pred;
2031  switch (op.getPredicate()) {
2032  case CmpFPredicate::ORD:
2033  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2034  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2035  /*width=*/1);
2036  return success();
2037  case CmpFPredicate::UNO:
2038  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2039  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2040  /*width=*/1);
2041  return success();
2042  default:
2043  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2044  break;
2045  }
2046 
2047  if (!isUnsigned) {
2048  // If the rhs value is > SignedMax, fold the comparison. This handles
2049  // +INF and large values.
2050  APFloat signedMax(rhs.getSemantics());
2051  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2052  APFloat::rmNearestTiesToEven);
2053  if (signedMax < rhs) { // smax < 13123.0
2054  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2055  pred == CmpIPredicate::sle)
2056  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2057  /*width=*/1);
2058  else
2059  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2060  /*width=*/1);
2061  return success();
2062  }
2063  } else {
2064  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2065  // +INF and large values.
2066  APFloat unsignedMax(rhs.getSemantics());
2067  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2068  APFloat::rmNearestTiesToEven);
2069  if (unsignedMax < rhs) { // umax < 13123.0
2070  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2071  pred == CmpIPredicate::ule)
2072  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2073  /*width=*/1);
2074  else
2075  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2076  /*width=*/1);
2077  return success();
2078  }
2079  }
2080 
2081  if (!isUnsigned) {
2082  // See if the rhs value is < SignedMin.
2083  APFloat signedMin(rhs.getSemantics());
2084  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2085  APFloat::rmNearestTiesToEven);
2086  if (signedMin > rhs) { // smin > 12312.0
2087  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2088  pred == CmpIPredicate::sge)
2089  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2090  /*width=*/1);
2091  else
2092  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2093  /*width=*/1);
2094  return success();
2095  }
2096  } else {
2097  // See if the rhs value is < UnsignedMin.
2098  APFloat unsignedMin(rhs.getSemantics());
2099  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2100  APFloat::rmNearestTiesToEven);
2101  if (unsignedMin > rhs) { // umin > 12312.0
2102  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2103  pred == CmpIPredicate::uge)
2104  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2105  /*width=*/1);
2106  else
2107  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2108  /*width=*/1);
2109  return success();
2110  }
2111  }
2112 
2113  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2114  // [0, UMAX], but it may still be fractional. See if it is fractional by
2115  // casting the FP value to the integer value and back, checking for
2116  // equality. Don't do this for zero, because -0.0 is not fractional.
2117  bool ignored;
2118  APSInt rhsInt(intWidth, isUnsigned);
2119  if (APFloat::opInvalidOp ==
2120  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2121  // Undefined behavior invoked - the destination type can't represent
2122  // the input constant.
2123  return failure();
2124  }
2125 
2126  if (!rhs.isZero()) {
2127  APFloat apf(floatTy.getFloatSemantics(),
2128  APInt::getZero(floatTy.getWidth()));
2129  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2130 
2131  bool equal = apf == rhs;
2132  if (!equal) {
2133  // If we had a comparison against a fractional value, we have to adjust
2134  // the compare predicate and sometimes the value. rhsInt is rounded
2135  // towards zero at this point.
2136  switch (pred) {
2137  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2138  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2139  /*width=*/1);
2140  return success();
2141  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2142  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2143  /*width=*/1);
2144  return success();
2145  case CmpIPredicate::ule:
2146  // (float)int <= 4.4 --> int <= 4
2147  // (float)int <= -4.4 --> false
2148  if (rhs.isNegative()) {
2149  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2150  /*width=*/1);
2151  return success();
2152  }
2153  break;
2154  case CmpIPredicate::sle:
2155  // (float)int <= 4.4 --> int <= 4
2156  // (float)int <= -4.4 --> int < -4
2157  if (rhs.isNegative())
2158  pred = CmpIPredicate::slt;
2159  break;
2160  case CmpIPredicate::ult:
2161  // (float)int < -4.4 --> false
2162  // (float)int < 4.4 --> int <= 4
2163  if (rhs.isNegative()) {
2164  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2165  /*width=*/1);
2166  return success();
2167  }
2168  pred = CmpIPredicate::ule;
2169  break;
2170  case CmpIPredicate::slt:
2171  // (float)int < -4.4 --> int < -4
2172  // (float)int < 4.4 --> int <= 4
2173  if (!rhs.isNegative())
2174  pred = CmpIPredicate::sle;
2175  break;
2176  case CmpIPredicate::ugt:
2177  // (float)int > 4.4 --> int > 4
2178  // (float)int > -4.4 --> true
2179  if (rhs.isNegative()) {
2180  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2181  /*width=*/1);
2182  return success();
2183  }
2184  break;
2185  case CmpIPredicate::sgt:
2186  // (float)int > 4.4 --> int > 4
2187  // (float)int > -4.4 --> int >= -4
2188  if (rhs.isNegative())
2189  pred = CmpIPredicate::sge;
2190  break;
2191  case CmpIPredicate::uge:
2192  // (float)int >= -4.4 --> true
2193  // (float)int >= 4.4 --> int > 4
2194  if (rhs.isNegative()) {
2195  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2196  /*width=*/1);
2197  return success();
2198  }
2199  pred = CmpIPredicate::ugt;
2200  break;
2201  case CmpIPredicate::sge:
2202  // (float)int >= -4.4 --> int >= -4
2203  // (float)int >= 4.4 --> int > 4
2204  if (!rhs.isNegative())
2205  pred = CmpIPredicate::sgt;
2206  break;
2207  }
2208  }
2209  }
2210 
2211  // Lower this FP comparison into an appropriate integer version of the
2212  // comparison.
2213  rewriter.replaceOpWithNewOp<CmpIOp>(
2214  op, pred, intVal,
2215  rewriter.create<ConstantOp>(
2216  op.getLoc(), intVal.getType(),
2217  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2218  return success();
2219  }
2220 };
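// Illustrative rewrite performed by CmpFIntToFPConst (hypothetical SSA
// names, %x : i32):
//   %f = arith.sitofp %x : i32 to f64
//   %r = arith.cmpf olt, %f, %cst : f64      // %cst = 4.4 : f64
// becomes
//   %r = arith.cmpi sle, %x, %c4 : i32       // %c4 = arith.constant 4 : i32
// since every i32 value is exactly representable in f64 and comparing
// against the fractional bound 4.4 tightens olt to sle on the integers.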
2221 
2222 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2223  MLIRContext *context) {
2224  patterns.insert<CmpFIntToFPConst>(context);
2225 }
2226 
2227 //===----------------------------------------------------------------------===//
2228 // SelectOp
2229 //===----------------------------------------------------------------------===//
2230 
2231 // select %arg, %c1, %c0 => extui %arg
2232 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2233  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2234 
2235  LogicalResult matchAndRewrite(arith::SelectOp op,
2236  PatternRewriter &rewriter) const override {
2237  // Cannot extui i1 to i1, or i1 to f32
2238  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2239  return failure();
2240 
2241  // select %x, %c1, %c0 => extui %x
2242  if (matchPattern(op.getTrueValue(), m_One()) &&
2243  matchPattern(op.getFalseValue(), m_Zero())) {
2244  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2245  op.getCondition());
2246  return success();
2247  }
2248 
2249  // select %x, %c0, %c1 => extui (xor %x, true)
2250  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2251  matchPattern(op.getFalseValue(), m_One())) {
2252  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2253  op, op.getType(),
2254  rewriter.create<arith::XOrIOp>(
2255  op.getLoc(), op.getCondition(),
2256  rewriter.create<arith::ConstantIntOp>(
2257  op.getLoc(), 1, op.getCondition().getType())));
2258  return success();
2259  }
2260 
2261  return failure();
2262  }
2263 };
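// Illustrative rewrite (hypothetical SSA names, %cond : i1, i32 result):
//   arith.select %cond, %c1_i32, %c0_i32 : i32
// becomes
//   arith.extui %cond : i1 to i32
// and with the constants swapped the condition is first inverted with xori.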
2264 
2265 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2266  MLIRContext *context) {
2267  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2268  SelectI1ToNot, SelectToExtUI>(context);
2269 }
2270 
2271 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2272  Value trueVal = getTrueValue();
2273  Value falseVal = getFalseValue();
2274  if (trueVal == falseVal)
2275  return trueVal;
2276 
2277  Value condition = getCondition();
2278 
2279  // select true, %0, %1 => %0
2280  if (matchPattern(adaptor.getCondition(), m_One()))
2281  return trueVal;
2282 
2283  // select false, %0, %1 => %1
2284  if (matchPattern(adaptor.getCondition(), m_Zero()))
2285  return falseVal;
2286 
2287  // If either operand is fully poisoned, return the other.
2288  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2289  return falseVal;
2290 
2291  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2292  return trueVal;
2293 
2294  // select %x, true, false => %x
2295  if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
2296  matchPattern(adaptor.getFalseValue(), m_Zero()))
2297  return condition;
2298 
2299  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2300  auto pred = cmp.getPredicate();
2301  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2302  auto cmpLhs = cmp.getLhs();
2303  auto cmpRhs = cmp.getRhs();
2304 
2305  // %0 = arith.cmpi eq, %arg0, %arg1
2306  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2307 
2308  // %0 = arith.cmpi ne, %arg0, %arg1
2309  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2310 
2311  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2312  (cmpRhs == trueVal && cmpLhs == falseVal))
2313  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2314  }
2315  }
2316 
2317  // Constant-fold constant operands over non-splat constant condition.
2318  // select %cst_vec, %cst0, %cst1 => %cst2
2319  if (auto cond =
2320  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2321  if (auto lhs =
2322  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2323  if (auto rhs =
2324  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2325  SmallVector<Attribute> results;
2326  results.reserve(static_cast<size_t>(cond.getNumElements()));
2327  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2328  cond.value_end<BoolAttr>());
2329  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2330  lhs.value_end<Attribute>());
2331  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2332  rhs.value_end<Attribute>());
2333 
2334  for (auto [condVal, lhsVal, rhsVal] :
2335  llvm::zip_equal(condVals, lhsVals, rhsVals))
2336  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2337 
2338  return DenseElementsAttr::get(lhs.getType(), results);
2339  }
2340  }
2341  }
2342 
2343  return nullptr;
2344 }
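// Illustrative folds: a select whose two value operands are identical yields
// that operand; a constant true/false condition yields the corresponding
// branch; an i1 select of constant true/false collapses to the condition
// itself; and, for %p = arith.cmpi ne, %a, %b, the op
//   arith.select %p, %a, %b : i32   folds to %a.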
2345 
2346 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2347  Type conditionType, resultType;
2348  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2349  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2350  parser.parseOptionalAttrDict(result.attributes) ||
2351  parser.parseColonType(resultType))
2352  return failure();
2353 
2354  // Check for the explicit condition type if this is a masked tensor or vector.
2355  if (succeeded(parser.parseOptionalComma())) {
2356  conditionType = resultType;
2357  if (parser.parseType(resultType))
2358  return failure();
2359  } else {
2360  conditionType = parser.getBuilder().getI1Type();
2361  }
2362 
2363  result.addTypes(resultType);
2364  return parser.resolveOperands(operands,
2365  {conditionType, resultType, resultType},
2366  parser.getNameLoc(), result.operands);
2367 }
2368 
2369 void arith::SelectOp::print(OpAsmPrinter &p) {
2370  p << " " << getOperands();
2371  p.printOptionalAttrDict((*this)->getAttrs());
2372  p << " : ";
2373  if (ShapedType condType =
2374  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2375  p << condType << ", ";
2376  p << getType();
2377 }
2378 
2379 LogicalResult arith::SelectOp::verify() {
2380  Type conditionType = getCondition().getType();
2381  if (conditionType.isSignlessInteger(1))
2382  return success();
2383 
2384  // If the result type is a vector or tensor, the type can be a mask with the
2385  // same elements.
2386  Type resultType = getType();
2387  if (!llvm::isa<TensorType, VectorType>(resultType))
2388  return emitOpError() << "expected condition to be a signless i1, but got "
2389  << conditionType;
2390  Type shapedConditionType = getI1SameShape(resultType);
2391  if (conditionType != shapedConditionType) {
2392  return emitOpError() << "expected condition type to have the same shape "
2393  "as the result type, expected "
2394  << shapedConditionType << ", but got "
2395  << conditionType;
2396  }
2397  return success();
2398 }
2399 //===----------------------------------------------------------------------===//
2400 // ShLIOp
2401 //===----------------------------------------------------------------------===//
2402 
2403 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2404  // shli(x, 0) -> x
2405  if (matchPattern(adaptor.getRhs(), m_Zero()))
2406  return getLhs();
2407  // Don't fold if the shift amount is greater than or equal to the bit width.
2408  bool bounded = false;
2409  auto result = constFoldBinaryOp<IntegerAttr>(
2410  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2411  bounded = b.ult(b.getBitWidth());
2412  return a.shl(b);
2413  });
2414  return bounded ? result : Attribute();
2415 }
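// Illustrative folds: arith.shli %x, %c0 folds to %x, and a constant
// 1 : i8 shifted left by 3 folds to 8 : i8; a constant shift amount of 8 or
// more on an i8 operand is not folded (bounded stays false). The same bound
// check guards shrui and shrsi below.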
2416 
2417 //===----------------------------------------------------------------------===//
2418 // ShRUIOp
2419 //===----------------------------------------------------------------------===//
2420 
2421 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2422  // shrui(x, 0) -> x
2423  if (matchPattern(adaptor.getRhs(), m_Zero()))
2424  return getLhs();
2425  // Don't fold if the shift amount is greater than or equal to the bit width.
2426  bool bounded = false;
2427  auto result = constFoldBinaryOp<IntegerAttr>(
2428  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2429  bounded = b.ult(b.getBitWidth());
2430  return a.lshr(b);
2431  });
2432  return bounded ? result : Attribute();
2433 }
2434 
2435 //===----------------------------------------------------------------------===//
2436 // ShRSIOp
2437 //===----------------------------------------------------------------------===//
2438 
2439 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2440  // shrsi(x, 0) -> x
2441  if (matchPattern(adaptor.getRhs(), m_Zero()))
2442  return getLhs();
2443  // Don't fold if the shift amount is greater than or equal to the bit width.
2444  bool bounded = false;
2445  auto result = constFoldBinaryOp<IntegerAttr>(
2446  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2447  bounded = b.ult(b.getBitWidth());
2448  return a.ashr(b);
2449  });
2450  return bounded ? result : Attribute();
2451 }
2452 
2453 //===----------------------------------------------------------------------===//
2454 // Atomic Enum
2455 //===----------------------------------------------------------------------===//
2456 
2457 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2458 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2459  OpBuilder &builder, Location loc,
2460  bool useOnlyFiniteValue) {
2461  switch (kind) {
2462  case AtomicRMWKind::maximumf: {
2463  const llvm::fltSemantics &semantic =
2464  llvm::cast<FloatType>(resultType).getFloatSemantics();
2465  APFloat identity = useOnlyFiniteValue
2466  ? APFloat::getLargest(semantic, /*Negative=*/true)
2467  : APFloat::getInf(semantic, /*Negative=*/true);
2468  return builder.getFloatAttr(resultType, identity);
2469  }
2470  case AtomicRMWKind::addf:
2471  case AtomicRMWKind::addi:
2472  case AtomicRMWKind::maxu:
2473  case AtomicRMWKind::ori:
2474  return builder.getZeroAttr(resultType);
2475  case AtomicRMWKind::andi:
2476  return builder.getIntegerAttr(
2477  resultType,
2478  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2479  case AtomicRMWKind::maxs:
2480  return builder.getIntegerAttr(
2481  resultType, APInt::getSignedMinValue(
2482  llvm::cast<IntegerType>(resultType).getWidth()));
2483  case AtomicRMWKind::minimumf: {
2484  const llvm::fltSemantics &semantic =
2485  llvm::cast<FloatType>(resultType).getFloatSemantics();
2486  APFloat identity = useOnlyFiniteValue
2487  ? APFloat::getLargest(semantic, /*Negative=*/false)
2488  : APFloat::getInf(semantic, /*Negative=*/false);
2489 
2490  return builder.getFloatAttr(resultType, identity);
2491  }
2492  case AtomicRMWKind::mins:
2493  return builder.getIntegerAttr(
2494  resultType, APInt::getSignedMaxValue(
2495  llvm::cast<IntegerType>(resultType).getWidth()));
2496  case AtomicRMWKind::minu:
2497  return builder.getIntegerAttr(
2498  resultType,
2499  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2500  case AtomicRMWKind::muli:
2501  return builder.getIntegerAttr(resultType, 1);
2502  case AtomicRMWKind::mulf:
2503  return builder.getFloatAttr(resultType, 1);
2504  // TODO: Add remaining reduction operations.
2505  default:
2506  (void)emitOptionalError(loc, "Reduction operation type not supported");
2507  break;
2508  }
2509  return nullptr;
2510 }
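// Illustrative identity values, assuming i32 / f32 result types:
//   addi, addf, ori, maxu -> 0 / 0.0
//   muli, mulf            -> 1 / 1.0
//   andi                  -> 0xFFFFFFFF (all ones)
//   maxs -> INT32_MIN, mins -> INT32_MAX, minu -> UINT32_MAX
//   maximumf -> -inf (or the most negative finite value when
//   useOnlyFiniteValue is set), minimumf -> +inf (or the largest finite).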
2511 
2512 /// Return the identity numeric value associated with the given op.
2513 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2514  std::optional<AtomicRMWKind> maybeKind =
2515  llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2516  // Floating-point operations.
2517  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2518  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2519  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2520  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2521  // Integer operations.
2522  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2523  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2524  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2525  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2526  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2527  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2528  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2529  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2530  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2531  .Default([](Operation *op) { return std::nullopt; });
2532  if (!maybeKind) {
2533  op->emitError() << "Unknown neutral element for: " << *op;
2534  return std::nullopt;
2535  }
2536 
2537  bool useOnlyFiniteValue = false;
2538  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2539  if (fmfOpInterface) {
2540  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2541  useOnlyFiniteValue =
2542  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2543  }
2544 
2545  // Builder only used as helper for attribute creation.
2546  OpBuilder b(op->getContext());
2547  Type resultType = op->getResult(0).getType();
2548 
2549  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2550  useOnlyFiniteValue);
2551 }
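// Illustrative use: for an arith.maximumf op, getNeutralElement returns a
// FloatAttr of -inf, or the most negative finite value when the op carries
// the `ninf` fast-math flag; for arith.addi it returns a zero IntegerAttr;
// unhandled ops emit an error and yield std::nullopt.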
2552 
2553 /// Returns the identity value associated with an AtomicRMWKind op.
2554 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2555  OpBuilder &builder, Location loc,
2556  bool useOnlyFiniteValue) {
2557  auto attr =
2558  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2559  return builder.create<arith::ConstantOp>(loc, attr);
2560 }
2561 
2562 /// Return the value obtained by applying the reduction operation kind
2563 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2564 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2565  Location loc, Value lhs, Value rhs) {
2566  switch (op) {
2567  case AtomicRMWKind::addf:
2568  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2569  case AtomicRMWKind::addi:
2570  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2571  case AtomicRMWKind::mulf:
2572  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2573  case AtomicRMWKind::muli:
2574  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2575  case AtomicRMWKind::maximumf:
2576  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2577  case AtomicRMWKind::minimumf:
2578  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2579  case AtomicRMWKind::maxnumf:
2580  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2581  case AtomicRMWKind::minnumf:
2582  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2583  case AtomicRMWKind::maxs:
2584  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2585  case AtomicRMWKind::mins:
2586  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2587  case AtomicRMWKind::maxu:
2588  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2589  case AtomicRMWKind::minu:
2590  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2591  case AtomicRMWKind::ori:
2592  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2593  case AtomicRMWKind::andi:
2594  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2595  // TODO: Add remaining reduction operations.
2596  default:
2597  (void)emitOptionalError(loc, "Reduction operation type not supported");
2598  break;
2599  }
2600  return nullptr;
2601 }
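// Illustrative use (hypothetical values acc and elem of type f32):
//   getReductionOp(AtomicRMWKind::addf, builder, loc, acc, elem)
// materializes `%sum = arith.addf %acc, %elem : f32` at `loc`; unsupported
// kinds emit an error and return a null Value.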
2602 
2603 //===----------------------------------------------------------------------===//
2604 // TableGen'd op method definitions
2605 //===----------------------------------------------------------------------===//
2606 
2607 #define GET_OP_CLASSES
2608 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2609 
2610 //===----------------------------------------------------------------------===//
2611 // TableGen'd enum attribute definitions
2612 //===----------------------------------------------------------------------===//
2613 
2614 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"