ArithOps.cpp
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 
37 //===----------------------------------------------------------------------===//
38 // Pattern helpers
39 //===----------------------------------------------------------------------===//
40 
41 static IntegerAttr
42 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
43  Attribute rhs,
44  function_ref<APInt(const APInt &, const APInt &)> binFn) {
45  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47  APInt value = binFn(lhsVal, rhsVal);
48  return IntegerAttr::get(res.getType(), value);
49 }
50 
51 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
52  Attribute lhs, Attribute rhs) {
53  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54 }
55 
56 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
57  Attribute lhs, Attribute rhs) {
58  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59 }
60 
61 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
62  Attribute lhs, Attribute rhs) {
63  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64 }
65 
66 // Merge overflow flags from 2 ops, selecting the most conservative combination.
67 static IntegerOverflowFlagsAttr
68 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69  IntegerOverflowFlagsAttr val2) {
70  return IntegerOverflowFlagsAttr::get(val1.getContext(),
71  val1.getValue() & val2.getValue());
72 }
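// For example, merging (nsw | nuw) with nsw yields nsw: the bitwise AND keeps
// only the flags guaranteed by both original ops, which is the conservative
// choice for the combined op.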
73 
74 /// Invert an integer comparison predicate.
75 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76  switch (pred) {
77  case arith::CmpIPredicate::eq:
78  return arith::CmpIPredicate::ne;
79  case arith::CmpIPredicate::ne:
80  return arith::CmpIPredicate::eq;
81  case arith::CmpIPredicate::slt:
82  return arith::CmpIPredicate::sge;
83  case arith::CmpIPredicate::sle:
84  return arith::CmpIPredicate::sgt;
85  case arith::CmpIPredicate::sgt:
86  return arith::CmpIPredicate::sle;
87  case arith::CmpIPredicate::sge:
88  return arith::CmpIPredicate::slt;
89  case arith::CmpIPredicate::ult:
90  return arith::CmpIPredicate::uge;
91  case arith::CmpIPredicate::ule:
92  return arith::CmpIPredicate::ugt;
93  case arith::CmpIPredicate::ugt:
94  return arith::CmpIPredicate::ule;
95  case arith::CmpIPredicate::uge:
96  return arith::CmpIPredicate::ult;
97  }
98  llvm_unreachable("unknown cmpi predicate kind");
99 }
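// For example, invertPredicate(slt) returns sge: !(a <s b) is exactly (a >=s b).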
100 
101 /// Equivalent to
102 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103 ///
104 /// Not possible to implement as a chain of calls, as this would introduce a
105 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106 /// on the LLVM dialect and on translation to LLVM.
107 static llvm::RoundingMode
108 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109  switch (roundingMode) {
110  case RoundingMode::downward:
111  return llvm::RoundingMode::TowardNegative;
112  case RoundingMode::to_nearest_away:
113  return llvm::RoundingMode::NearestTiesToAway;
114  case RoundingMode::to_nearest_even:
115  return llvm::RoundingMode::NearestTiesToEven;
116  case RoundingMode::toward_zero:
117  return llvm::RoundingMode::TowardZero;
118  case RoundingMode::upward:
119  return llvm::RoundingMode::TowardPositive;
120  }
121  llvm_unreachable("Unhandled rounding mode");
122 }
123 
124 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125  return arith::CmpIPredicateAttr::get(pred.getContext(),
126  invertPredicate(pred.getValue()));
127 }
128 
129 static int64_t getScalarOrElementWidth(Type type) {
130  Type elemTy = getElementTypeOrSelf(type);
131  if (elemTy.isIntOrFloat())
132  return elemTy.getIntOrFloatBitWidth();
133 
134  return -1;
135 }
136 
137 static int64_t getScalarOrElementWidth(Value value) {
138  return getScalarOrElementWidth(value.getType());
139 }
140 
141 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142  APInt value;
143  if (matchPattern(attr, m_ConstantInt(&value)))
144  return value;
145 
146  return failure();
147 }
148 
149 static Attribute getBoolAttribute(Type type, bool value) {
150  auto boolAttr = BoolAttr::get(type.getContext(), value);
151  ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152  if (!shapedType)
153  return boolAttr;
154  return DenseElementsAttr::get(shapedType, boolAttr);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TableGen'd canonicalization patterns
159 //===----------------------------------------------------------------------===//
160 
161 namespace {
162 #include "ArithCanonicalization.inc"
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // Common helpers
167 //===----------------------------------------------------------------------===//
168 
169 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
170 static Type getI1SameShape(Type type) {
171  auto i1Type = IntegerType::get(type.getContext(), 1);
172  if (auto shapedType = dyn_cast<ShapedType>(type))
173  return shapedType.cloneWith(std::nullopt, i1Type);
174  if (llvm::isa<UnrankedTensorType>(type))
175  return UnrankedTensorType::get(i1Type);
176  return i1Type;
177 }
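// For example, this maps f32 to i1, vector<4xf32> to vector<4xi1>, and
// tensor<?x8xi32> to tensor<?x8xi1>; only the element type changes.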
178 
179 //===----------------------------------------------------------------------===//
180 // ConstantOp
181 //===----------------------------------------------------------------------===//
182 
183 void arith::ConstantOp::getAsmResultNames(
184  function_ref<void(Value, StringRef)> setNameFn) {
185  auto type = getType();
186  if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
187  auto intType = dyn_cast<IntegerType>(type);
188 
189  // Sugar i1 constants with 'true' and 'false'.
190  if (intType && intType.getWidth() == 1)
191  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
192 
193  // Otherwise, build a complex name with the value and type.
194  SmallString<32> specialNameBuffer;
195  llvm::raw_svector_ostream specialName(specialNameBuffer);
196  specialName << 'c' << intCst.getValue();
197  if (intType)
198  specialName << '_' << type;
199  setNameFn(getResult(), specialName.str());
200  } else {
201  setNameFn(getResult(), "cst");
202  }
203 }
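// As a result, constants get names such as %true, %c42_i32 (value 42 of type
// i32), or %cst for non-integer constants.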
204 
205 /// TODO: disallow arith.constant to return anything other than signless integer
206 /// or float like.
207 LogicalResult arith::ConstantOp::verify() {
208  auto type = getType();
209  // Integer values must be signless.
210  if (llvm::isa<IntegerType>(type) &&
211  !llvm::cast<IntegerType>(type).isSignless())
212  return emitOpError("integer return type must be signless");
213  // Any float or elements attribute is acceptable.
214  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
215  return emitOpError(
216  "value must be an integer, float, or elements attribute");
217  }
218 
219  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
220  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
221  // However, this would most likely require updating the lowerings to LLVM.
222  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
223  return emitOpError(
224  "initializing scalable vectors with elements attribute is not supported"
225  " unless it's a vector splat");
226  return success();
227 }
228 
229 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
230  // The value's type must be the same as the provided type.
231  auto typedAttr = dyn_cast<TypedAttr>(value);
232  if (!typedAttr || typedAttr.getType() != type)
233  return false;
234  // Integer values must be signless.
235  if (llvm::isa<IntegerType>(type) &&
236  !llvm::cast<IntegerType>(type).isSignless())
237  return false;
238  // Integer, float, and element attributes are buildable.
239  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
240 }
241 
242 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
243  Type type, Location loc) {
244  if (isBuildableWith(value, type))
245  return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
246  return nullptr;
247 }
248 
249 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
250 
251 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
252  int64_t value, unsigned width) {
253  auto type = builder.getIntegerType(width);
254  arith::ConstantOp::build(builder, result, type,
255  builder.getIntegerAttr(type, value));
256 }
257 
258 ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
259  Location location,
260  int64_t value,
261  unsigned width) {
262  mlir::OperationState state(location, getOperationName());
263  build(builder, state, value, width);
264  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
265  assert(result && "builder didn't return the right type");
266  return result;
267 }
268 
269 ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
270  int64_t value,
271  unsigned width) {
272  return create(builder, builder.getLoc(), value, width);
273 }
274 
275 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
276  Type type, int64_t value) {
277  arith::ConstantOp::build(builder, result, type,
278  builder.getIntegerAttr(type, value));
279 }
280 
281 ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
282  Location location, Type type,
283  int64_t value) {
284  mlir::OperationState state(location, getOperationName());
285  build(builder, state, type, value);
286  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
287  assert(result && "builder didn't return the right type");
288  return result;
289 }
290 
291 ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
292  Type type, int64_t value) {
293  return create(builder, builder.getLoc(), type, value);
294 }
295 
296 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
297  Type type, const APInt &value) {
298  arith::ConstantOp::build(builder, result, type,
299  builder.getIntegerAttr(type, value));
300 }
301 
302 ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
303  Location location, Type type,
304  const APInt &value) {
305  mlir::OperationState state(location, getOperationName());
306  build(builder, state, type, value);
307  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
308  assert(result && "builder didn't return the right type");
309  return result;
310 }
311 
312 ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder,
313  Type type,
314  const APInt &value) {
315  return create(builder, builder.getLoc(), type, value);
316 }
317 
318 bool arith::ConstantIntOp::classof(Operation *op) {
319  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
320  return constOp.getType().isSignlessInteger();
321  return false;
322 }
323 
324 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
325  FloatType type, const APFloat &value) {
326  arith::ConstantOp::build(builder, result, type,
327  builder.getFloatAttr(type, value));
328 }
329 
330 ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder,
331  Location location,
332  FloatType type,
333  const APFloat &value) {
334  mlir::OperationState state(location, getOperationName());
335  build(builder, state, type, value);
336  auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
337  assert(result && "builder didn't return the right type");
338  return result;
339 }
340 
341 ConstantFloatOp arith::ConstantFloatOp::create(ImplicitLocOpBuilder &builder,
342  FloatType type,
343  const APFloat &value) {
344  return create(builder, builder.getLoc(), type, value);
345 }
346 
347 bool arith::ConstantFloatOp::classof(Operation *op) {
348  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
349  return llvm::isa<FloatType>(constOp.getType());
350  return false;
351 }
352 
353 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
354  int64_t value) {
355  arith::ConstantOp::build(builder, result, builder.getIndexType(),
356  builder.getIndexAttr(value));
357 }
358 
359 ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder,
360  Location location,
361  int64_t value) {
362  mlir::OperationState state(location, getOperationName());
363  build(builder, state, value);
364  auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
365  assert(result && "builder didn't return the right type");
366  return result;
367 }
368 
369 ConstantIndexOp arith::ConstantIndexOp::create(ImplicitLocOpBuilder &builder,
370  int64_t value) {
371  return create(builder, builder.getLoc(), value);
372 }
373 
374 bool arith::ConstantIndexOp::classof(Operation *op) {
375  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
376  return constOp.getType().isIndex();
377  return false;
378 }
379 
380 Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
381  Type type) {
382  // TODO: Incorporate this check to `FloatAttr::get*`.
383  assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
384  "type doesn't have a zero representation");
385  TypedAttr zeroAttr = builder.getZeroAttr(type);
386  assert(zeroAttr && "unsupported type for zero attribute");
387  return arith::ConstantOp::create(builder, loc, zeroAttr);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // AddIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
395  // addi(x, 0) -> x
396  if (matchPattern(adaptor.getRhs(), m_Zero()))
397  return getLhs();
398 
399  // addi(subi(a, b), b) -> a
400  if (auto sub = getLhs().getDefiningOp<SubIOp>())
401  if (getRhs() == sub.getRhs())
402  return sub.getLhs();
403 
404  // addi(b, subi(a, b)) -> a
405  if (auto sub = getRhs().getDefiningOp<SubIOp>())
406  if (getLhs() == sub.getRhs())
407  return sub.getLhs();
408 
409  return constFoldBinaryOp<IntegerAttr>(
410  adaptor.getOperands(),
411  [](APInt a, const APInt &b) { return std::move(a) + b; });
412 }
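// Illustration of the addi(subi(a, b), b) -> a fold above:
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32
// folding %1 simply yields %a, with no new IR created.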
413 
414 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
415  MLIRContext *context) {
416  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
417  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // AddUIExtendedOp
422 //===----------------------------------------------------------------------===//
423 
424 std::optional<SmallVector<int64_t, 4>>
425 arith::AddUIExtendedOp::getShapeForUnroll() {
426  if (auto vt = dyn_cast<VectorType>(getType(0)))
427  return llvm::to_vector<4>(vt.getShape());
428  return std::nullopt;
429 }
430 
431 // Returns the overflow bit, assuming that `sum` is the result of unsigned
432 // addition of `operand` and another number.
433 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
434  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
435 }
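// The check relies on modular arithmetic: an unsigned add wrapped around iff
// the result is smaller than either operand. E.g. for i8, 200 + 100 wraps to
// 44, and 44 < 200 signals the carry.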
436 
437 LogicalResult
438 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
439  SmallVectorImpl<OpFoldResult> &results) {
440  Type overflowTy = getOverflow().getType();
441  // addui_extended(x, 0) -> x, false
442  if (matchPattern(getRhs(), m_Zero())) {
443  Builder builder(getContext());
444  auto falseValue = builder.getZeroAttr(overflowTy);
445 
446  results.push_back(getLhs());
447  results.push_back(falseValue);
448  return success();
449  }
450 
451  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
452  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
453  // operands. If that succeeds, calculate the overflow bit based on the sum
454  // and the first (constant) operand, `lhs`.
455  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
456  adaptor.getOperands(),
457  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
458  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
459  ArrayRef({sumAttr, adaptor.getLhs()}),
460  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
461  calculateUnsignedOverflow);
462  if (!overflowAttr)
463  return failure();
464 
465  results.push_back(sumAttr);
466  results.push_back(overflowAttr);
467  return success();
468  }
469 
470  return failure();
471 }
472 
473 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
474  RewritePatternSet &patterns, MLIRContext *context) {
475  patterns.add<AddUIExtendedToAddI>(context);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // SubIOp
480 //===----------------------------------------------------------------------===//
481 
482 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
483  // subi(x,x) -> 0
484  if (getOperand(0) == getOperand(1)) {
485  auto shapedType = dyn_cast<ShapedType>(getType());
486  // We can't generate a constant with a dynamic shaped tensor.
487  if (!shapedType || shapedType.hasStaticShape())
488  return Builder(getContext()).getZeroAttr(getType());
489  }
490  // subi(x,0) -> x
491  if (matchPattern(adaptor.getRhs(), m_Zero()))
492  return getLhs();
493 
494  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
495  // subi(addi(a, b), b) -> a
496  if (getRhs() == add.getRhs())
497  return add.getLhs();
498  // subi(addi(a, b), a) -> b
499  if (getRhs() == add.getLhs())
500  return add.getRhs();
501  }
502 
503  return constFoldBinaryOp<IntegerAttr>(
504  adaptor.getOperands(),
505  [](APInt a, const APInt &b) { return std::move(a) - b; });
506 }
507 
508 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
509  MLIRContext *context) {
510  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
511  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
512  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // MulIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
520  // muli(x, 0) -> 0
521  if (matchPattern(adaptor.getRhs(), m_Zero()))
522  return getRhs();
523  // muli(x, 1) -> x
524  if (matchPattern(adaptor.getRhs(), m_One()))
525  return getLhs();
526  // TODO: Handle the overflow case.
527 
528  // default folder
529  return constFoldBinaryOp<IntegerAttr>(
530  adaptor.getOperands(),
531  [](const APInt &a, const APInt &b) { return a * b; });
532 }
533 
534 void arith::MulIOp::getAsmResultNames(
535  function_ref<void(Value, StringRef)> setNameFn) {
536  if (!isa<IndexType>(getType()))
537  return;
538 
539  // Match vector.vscale by name to avoid depending on the vector dialect
540  // (which would be a circular dependency).
541  auto isVscale = [](Operation *op) {
542  return op && op->getName().getStringRef() == "vector.vscale";
543  };
544 
545  IntegerAttr baseValue;
546  auto isVscaleExpr = [&](Value a, Value b) {
547  return matchPattern(a, m_Constant(&baseValue)) &&
548  isVscale(b.getDefiningOp());
549  };
550 
551  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
552  return;
553 
554  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
555  SmallString<32> specialNameBuffer;
556  llvm::raw_svector_ostream specialName(specialNameBuffer);
557  specialName << 'c' << baseValue.getInt() << "_vscale";
558  setNameFn(getResult(), specialName.str());
559 }
560 
561 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
562  MLIRContext *context) {
563  patterns.add<MulIMulIConstant>(context);
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // MulSIExtendedOp
568 //===----------------------------------------------------------------------===//
569 
570 std::optional<SmallVector<int64_t, 4>>
571 arith::MulSIExtendedOp::getShapeForUnroll() {
572  if (auto vt = dyn_cast<VectorType>(getType(0)))
573  return llvm::to_vector<4>(vt.getShape());
574  return std::nullopt;
575 }
576 
577 LogicalResult
578 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
579  SmallVectorImpl<OpFoldResult> &results) {
580  // mulsi_extended(x, 0) -> 0, 0
581  if (matchPattern(adaptor.getRhs(), m_Zero())) {
582  Attribute zero = adaptor.getRhs();
583  results.push_back(zero);
584  results.push_back(zero);
585  return success();
586  }
587 
588  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
589  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
590  adaptor.getOperands(),
591  [](const APInt &a, const APInt &b) { return a * b; })) {
592  // Invoke the constant fold helper again to calculate the 'high' result.
593  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
594  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
595  return llvm::APIntOps::mulhs(a, b);
596  });
597  assert(highAttr && "Unexpected constant-folding failure");
598 
599  results.push_back(lowAttr);
600  results.push_back(highAttr);
601  return success();
602  }
603 
604  return failure();
605 }
606 
607 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
608  RewritePatternSet &patterns, MLIRContext *context) {
609  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // MulUIExtendedOp
614 //===----------------------------------------------------------------------===//
615 
616 std::optional<SmallVector<int64_t, 4>>
617 arith::MulUIExtendedOp::getShapeForUnroll() {
618  if (auto vt = dyn_cast<VectorType>(getType(0)))
619  return llvm::to_vector<4>(vt.getShape());
620  return std::nullopt;
621 }
622 
623 LogicalResult
624 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
625  SmallVectorImpl<OpFoldResult> &results) {
626  // mului_extended(x, 0) -> 0, 0
627  if (matchPattern(adaptor.getRhs(), m_Zero())) {
628  Attribute zero = adaptor.getRhs();
629  results.push_back(zero);
630  results.push_back(zero);
631  return success();
632  }
633 
634  // mului_extended(x, 1) -> x, 0
635  if (matchPattern(adaptor.getRhs(), m_One())) {
636  Builder builder(getContext());
637  Attribute zero = builder.getZeroAttr(getLhs().getType());
638  results.push_back(getLhs());
639  results.push_back(zero);
640  return success();
641  }
642 
643  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
644  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
645  adaptor.getOperands(),
646  [](const APInt &a, const APInt &b) { return a * b; })) {
647  // Invoke the constant fold helper again to calculate the 'high' result.
648  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
649  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
650  return llvm::APIntOps::mulhu(a, b);
651  });
652  assert(highAttr && "Unexpected constant-folding failure");
653 
654  results.push_back(lowAttr);
655  results.push_back(highAttr);
656  return success();
657  }
658 
659  return failure();
660 }
661 
662 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
663  RewritePatternSet &patterns, MLIRContext *context) {
664  patterns.add<MulUIExtendedToMulI>(context);
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // DivUIOp
669 //===----------------------------------------------------------------------===//
670 
671 /// Fold `(a * b) / b -> a`
672 static Value foldDivMul(Value lhs, Value rhs,
673  arith::IntegerOverflowFlags ovfFlags) {
674  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
675  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
676  return {};
677 
678  if (mul.getLhs() == rhs)
679  return mul.getRhs();
680 
681  if (mul.getRhs() == rhs)
682  return mul.getLhs();
683 
684  return {};
685 }
686 
687 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
688  // divui (x, 1) -> x.
689  if (matchPattern(adaptor.getRhs(), m_One()))
690  return getLhs();
691 
692  // (a * b) / b -> a
693  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
694  return val;
695 
696  // Don't fold if it would require a division by zero.
697  bool div0 = false;
698  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
699  [&](APInt a, const APInt &b) {
700  if (div0 || !b) {
701  div0 = true;
702  return a;
703  }
704  return a.udiv(b);
705  });
706 
707  return div0 ? Attribute() : result;
708 }
709 
710 /// Returns whether an unsigned division by `divisor` is speculatable.
711 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
712  // X / 0 => UB
713  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
714  return Speculation::Speculatable;
715 
716  return Speculation::NotSpeculatable;
717 }
718 
719 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
720  return getDivUISpeculatability(getRhs());
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // DivSIOp
725 //===----------------------------------------------------------------------===//
726 
727 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
728  // divsi (x, 1) -> x.
729  if (matchPattern(adaptor.getRhs(), m_One()))
730  return getLhs();
731 
732  // (a * b) / b -> a
733  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
734  return val;
735 
736  // Don't fold if it would overflow or if it requires a division by zero.
737  bool overflowOrDiv0 = false;
738  auto result = constFoldBinaryOp<IntegerAttr>(
739  adaptor.getOperands(), [&](APInt a, const APInt &b) {
740  if (overflowOrDiv0 || !b) {
741  overflowOrDiv0 = true;
742  return a;
743  }
744  return a.sdiv_ov(b, overflowOrDiv0);
745  });
746 
747  return overflowOrDiv0 ? Attribute() : result;
748 }
749 
750 /// Returns whether a signed division by `divisor` is speculatable. This
751 /// function conservatively assumes that all signed divisions by -1 are not
752 /// speculatable.
753 static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
754  // X / 0 => UB
755  // INT_MIN / -1 => UB
756  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
757  matchPattern(divisor, m_IntRangeWithoutNegOneS()))
758  return Speculation::Speculatable;
759 
760  return Speculation::NotSpeculatable;
761 }
762 
763 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
764  return getDivSISpeculatability(getRhs());
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // Ceil and floor division folding helpers
769 //===----------------------------------------------------------------------===//
770 
771 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
772  bool &overflow) {
773  // Returns (a-1)/b + 1
774  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
775  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
776  return val.sadd_ov(one, overflow);
777 }
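// Worked example for positive operands: a = 7, b = 3 gives (7 - 1) / 3 + 1
// = 2 + 1 = 3 = ceil(7 / 3); an exact multiple such as a = 6, b = 3 gives
// (6 - 1) / 3 + 1 = 1 + 1 = 2, so the formula is exact in both cases.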
778 
779 //===----------------------------------------------------------------------===//
780 // CeilDivUIOp
781 //===----------------------------------------------------------------------===//
782 
783 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
784  // ceildivui (x, 1) -> x.
785  if (matchPattern(adaptor.getRhs(), m_One()))
786  return getLhs();
787 
788  bool overflowOrDiv0 = false;
789  auto result = constFoldBinaryOp<IntegerAttr>(
790  adaptor.getOperands(), [&](APInt a, const APInt &b) {
791  if (overflowOrDiv0 || !b) {
792  overflowOrDiv0 = true;
793  return a;
794  }
795  APInt quotient = a.udiv(b);
796  if (!a.urem(b))
797  return quotient;
798  APInt one(a.getBitWidth(), 1, true);
799  return quotient.uadd_ov(one, overflowOrDiv0);
800  });
801 
802  return overflowOrDiv0 ? Attribute() : result;
803 }
804 
805 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
806  return getDivUISpeculatability(getRhs());
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // CeilDivSIOp
811 //===----------------------------------------------------------------------===//
812 
813 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
814  // ceildivsi (x, 1) -> x.
815  if (matchPattern(adaptor.getRhs(), m_One()))
816  return getLhs();
817 
818  // Don't fold if it would overflow or if it requires a division by zero.
819  // TODO: This hook won't fold operations where a = MININT, because
820  // negating MININT overflows. This can be improved.
821  bool overflowOrDiv0 = false;
822  auto result = constFoldBinaryOp<IntegerAttr>(
823  adaptor.getOperands(), [&](APInt a, const APInt &b) {
824  if (overflowOrDiv0 || !b) {
825  overflowOrDiv0 = true;
826  return a;
827  }
828  if (!a)
829  return a;
830  // After this point we know that neither a nor b is zero.
831  unsigned bits = a.getBitWidth();
832  APInt zero = APInt::getZero(bits);
833  bool aGtZero = a.sgt(zero);
834  bool bGtZero = b.sgt(zero);
835  if (aGtZero && bGtZero) {
836  // Both positive, return ceil(a, b).
837  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
838  }
839 
840  // No folding happens if any of the intermediate arithmetic operations
841  // overflows.
842  bool overflowNegA = false;
843  bool overflowNegB = false;
844  bool overflowDiv = false;
845  bool overflowNegRes = false;
846  if (!aGtZero && !bGtZero) {
847  // Both negative, return ceil(-a, -b).
848  APInt posA = zero.ssub_ov(a, overflowNegA);
849  APInt posB = zero.ssub_ov(b, overflowNegB);
850  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
851  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
852  return res;
853  }
854  if (!aGtZero && bGtZero) {
855  // A is negative, b is positive, return - ( -a / b).
856  APInt posA = zero.ssub_ov(a, overflowNegA);
857  APInt div = posA.sdiv_ov(b, overflowDiv);
858  APInt res = zero.ssub_ov(div, overflowNegRes);
859  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
860  return res;
861  }
862  // A is positive, b is negative, return - (a / -b).
863  APInt posB = zero.ssub_ov(b, overflowNegB);
864  APInt div = a.sdiv_ov(posB, overflowDiv);
865  APInt res = zero.ssub_ov(div, overflowNegRes);
866 
867  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
868  return res;
869  });
870 
871  return overflowOrDiv0 ? Attribute() : result;
872 }
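// Sanity check of the sign handling above: ceildivsi(-7, 2) folds to
// -(7 / 2) = -3 = ceil(-3.5), and ceildivsi(-7, -2) folds to
// signedCeilNonnegInputs(7, 2) = 4 = ceil(3.5).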
873 
874 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
875  return getDivSISpeculatability(getRhs());
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // FloorDivSIOp
880 //===----------------------------------------------------------------------===//
881 
882 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
883  // floordivsi (x, 1) -> x.
884  if (matchPattern(adaptor.getRhs(), m_One()))
885  return getLhs();
886 
887  // Don't fold if it would overflow or if it requires a division by zero.
888  bool overflowOrDiv = false;
889  auto result = constFoldBinaryOp<IntegerAttr>(
890  adaptor.getOperands(), [&](APInt a, const APInt &b) {
891  if (b.isZero()) {
892  overflowOrDiv = true;
893  return a;
894  }
895  return a.sfloordiv_ov(b, overflowOrDiv);
896  });
897 
898  return overflowOrDiv ? Attribute() : result;
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // RemUIOp
903 //===----------------------------------------------------------------------===//
904 
905 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
906  // remui (x, 1) -> 0.
907  if (matchPattern(adaptor.getRhs(), m_One()))
908  return Builder(getContext()).getZeroAttr(getType());
909 
910  // Don't fold if it would require a division by zero.
911  bool div0 = false;
912  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
913  [&](APInt a, const APInt &b) {
914  if (div0 || b.isZero()) {
915  div0 = true;
916  return a;
917  }
918  return a.urem(b);
919  });
920 
921  return div0 ? Attribute() : result;
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // RemSIOp
926 //===----------------------------------------------------------------------===//
927 
928 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
929  // remsi (x, 1) -> 0.
930  if (matchPattern(adaptor.getRhs(), m_One()))
931  return Builder(getContext()).getZeroAttr(getType());
932 
933  // Don't fold if it would require a division by zero.
934  bool div0 = false;
935  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
936  [&](APInt a, const APInt &b) {
937  if (div0 || b.isZero()) {
938  div0 = true;
939  return a;
940  }
941  return a.srem(b);
942  });
943 
944  return div0 ? Attribute() : result;
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // AndIOp
949 //===----------------------------------------------------------------------===//
950 
951 /// Fold `and(a, and(a, b))` to `and(a, b)`
952 static Value foldAndIofAndI(arith::AndIOp op) {
953  for (bool reversePrev : {false, true}) {
954  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
955  .getDefiningOp<arith::AndIOp>();
956  if (!prev)
957  continue;
958 
959  Value other = (reversePrev ? op.getLhs() : op.getRhs());
960  if (other != prev.getLhs() && other != prev.getRhs())
961  continue;
962 
963  return prev.getResult();
964  }
965  return {};
966 }
967 
968 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
969  /// and(x, 0) -> 0
970  if (matchPattern(adaptor.getRhs(), m_Zero()))
971  return getRhs();
972  /// and(x, allOnes) -> x
973  APInt intValue;
974  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
975  intValue.isAllOnes())
976  return getLhs();
977  /// and(x, not(x)) -> 0
978  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
979  m_ConstantInt(&intValue))) &&
980  intValue.isAllOnes())
981  return Builder(getContext()).getZeroAttr(getType());
982  /// and(not(x), x) -> 0
983  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
984  m_ConstantInt(&intValue))) &&
985  intValue.isAllOnes())
986  return Builder(getContext()).getZeroAttr(getType());
987 
988  /// and(a, and(a, b)) -> and(a, b)
989  if (Value result = foldAndIofAndI(*this))
990  return result;
991 
992  return constFoldBinaryOp<IntegerAttr>(
993  adaptor.getOperands(),
994  [](APInt a, const APInt &b) { return std::move(a) & b; });
995 }
996 
997 //===----------------------------------------------------------------------===//
998 // OrIOp
999 //===----------------------------------------------------------------------===//
1000 
1001 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1002  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1003  /// or(x, 0) -> x
1004  if (rhsVal.isZero())
1005  return getLhs();
1006  /// or(x, <all ones>) -> <all ones>
1007  if (rhsVal.isAllOnes())
1008  return adaptor.getRhs();
1009  }
1010 
1011  APInt intValue;
1012  /// or(x, xor(x, 1)) -> 1
1013  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1014  m_ConstantInt(&intValue))) &&
1015  intValue.isAllOnes())
1016  return getRhs().getDefiningOp<XOrIOp>().getRhs();
1017  /// or(xor(x, 1), x) -> 1
1018  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1019  m_ConstantInt(&intValue))) &&
1020  intValue.isAllOnes())
1021  return getLhs().getDefiningOp<XOrIOp>().getRhs();
1022 
1023  return constFoldBinaryOp<IntegerAttr>(
1024  adaptor.getOperands(),
1025  [](APInt a, const APInt &b) { return std::move(a) | b; });
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // XOrIOp
1030 //===----------------------------------------------------------------------===//
1031 
1032 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1033  /// xor(x, 0) -> x
1034  if (matchPattern(adaptor.getRhs(), m_Zero()))
1035  return getLhs();
1036  /// xor(x, x) -> 0
1037  if (getLhs() == getRhs())
1038  return Builder(getContext()).getZeroAttr(getType());
1039  /// xor(xor(x, a), a) -> x
1040  /// xor(xor(a, x), a) -> x
1041  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1042  if (prev.getRhs() == getRhs())
1043  return prev.getLhs();
1044  if (prev.getLhs() == getRhs())
1045  return prev.getRhs();
1046  }
1047  /// xor(a, xor(x, a)) -> x
1048  /// xor(a, xor(a, x)) -> x
1049  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1050  if (prev.getRhs() == getLhs())
1051  return prev.getLhs();
1052  if (prev.getLhs() == getLhs())
1053  return prev.getRhs();
1054  }
1055 
1056  return constFoldBinaryOp<IntegerAttr>(
1057  adaptor.getOperands(),
1058  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1059 }
1060 
1061 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1062  MLIRContext *context) {
1063  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1064 }
1065 
1066 //===----------------------------------------------------------------------===//
1067 // NegFOp
1068 //===----------------------------------------------------------------------===//
1069 
1070 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1071  /// negf(negf(x)) -> x
1072  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1073  return op.getOperand();
1074  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1075  [](const APFloat &a) { return -a; });
1076 }
1077 
1078 //===----------------------------------------------------------------------===//
1079 // AddFOp
1080 //===----------------------------------------------------------------------===//
1081 
1082 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1083  // addf(x, -0) -> x
1084  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1085  return getLhs();
1086 
1087  return constFoldBinaryOp<FloatAttr>(
1088  adaptor.getOperands(),
1089  [](const APFloat &a, const APFloat &b) { return a + b; });
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // SubFOp
1094 //===----------------------------------------------------------------------===//
1095 
1096 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1097  // subf(x, +0) -> x
1098  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1099  return getLhs();
1100 
1101  return constFoldBinaryOp<FloatAttr>(
1102  adaptor.getOperands(),
1103  [](const APFloat &a, const APFloat &b) { return a - b; });
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // MaximumFOp
1108 //===----------------------------------------------------------------------===//
1109 
1110 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1111  // maximumf(x,x) -> x
1112  if (getLhs() == getRhs())
1113  return getRhs();
1114 
1115  // maximumf(x, -inf) -> x
1116  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1117  return getLhs();
1118 
1119  return constFoldBinaryOp<FloatAttr>(
1120  adaptor.getOperands(),
1121  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1122 }
1123 
1124 //===----------------------------------------------------------------------===//
1125 // MaxNumFOp
1126 //===----------------------------------------------------------------------===//
1127 
1128 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1129  // maxnumf(x,x) -> x
1130  if (getLhs() == getRhs())
1131  return getRhs();
1132 
1133  // maxnumf(x, NaN) -> x
1134  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1135  return getLhs();
1136 
1137  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1138 }
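// Note the asymmetry with MaximumFOp above: llvm::maxnum follows IEEE-754
// maxNum semantics and returns the non-NaN operand when one operand is NaN,
// which is why maxnumf(x, NaN) can fold to x, while maximumf only folds away
// a -inf operand because llvm::maximum propagates NaN.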
1139 
1140 //===----------------------------------------------------------------------===//
1141 // MaxSIOp
1142 //===----------------------------------------------------------------------===//
1143 
1144 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1145  // maxsi(x,x) -> x
1146  if (getLhs() == getRhs())
1147  return getRhs();
1148 
1149  if (APInt intValue;
1150  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1151  // maxsi(x,MAX_INT) -> MAX_INT
1152  if (intValue.isMaxSignedValue())
1153  return getRhs();
1154  // maxsi(x, MIN_INT) -> x
1155  if (intValue.isMinSignedValue())
1156  return getLhs();
1157  }
1158 
1159  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1160  [](const APInt &a, const APInt &b) {
1161  return llvm::APIntOps::smax(a, b);
1162  });
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // MaxUIOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1170  // maxui(x,x) -> x
1171  if (getLhs() == getRhs())
1172  return getRhs();
1173 
1174  if (APInt intValue;
1175  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1176  // maxui(x,MAX_INT) -> MAX_INT
1177  if (intValue.isMaxValue())
1178  return getRhs();
1179  // maxui(x, MIN_INT) -> x
1180  if (intValue.isMinValue())
1181  return getLhs();
1182  }
1183 
1184  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1185  [](const APInt &a, const APInt &b) {
1186  return llvm::APIntOps::umax(a, b);
1187  });
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // MinimumFOp
1192 //===----------------------------------------------------------------------===//
1193 
1194 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1195  // minimumf(x,x) -> x
1196  if (getLhs() == getRhs())
1197  return getRhs();
1198 
1199  // minimumf(x, +inf) -> x
1200  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1201  return getLhs();
1202 
1203  return constFoldBinaryOp<FloatAttr>(
1204  adaptor.getOperands(),
1205  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1206 }
1207 
1208 //===----------------------------------------------------------------------===//
1209 // MinNumFOp
1210 //===----------------------------------------------------------------------===//
1211 
1212 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1213  // minnumf(x,x) -> x
1214  if (getLhs() == getRhs())
1215  return getRhs();
1216 
1217  // minnumf(x, NaN) -> x
1218  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1219  return getLhs();
1220 
1221  return constFoldBinaryOp<FloatAttr>(
1222  adaptor.getOperands(),
1223  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // MinSIOp
1228 //===----------------------------------------------------------------------===//
1229 
1230 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1231  // minsi(x,x) -> x
1232  if (getLhs() == getRhs())
1233  return getRhs();
1234 
1235  if (APInt intValue;
1236  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1237  // minsi(x,MIN_INT) -> MIN_INT
1238  if (intValue.isMinSignedValue())
1239  return getRhs();
1240  // minsi(x, MAX_INT) -> x
1241  if (intValue.isMaxSignedValue())
1242  return getLhs();
1243  }
1244 
1245  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1246  [](const APInt &a, const APInt &b) {
1247  return llvm::APIntOps::smin(a, b);
1248  });
1249 }
1250 
1251 //===----------------------------------------------------------------------===//
1252 // MinUIOp
1253 //===----------------------------------------------------------------------===//
1254 
1255 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1256  // minui(x,x) -> x
1257  if (getLhs() == getRhs())
1258  return getRhs();
1259 
1260  if (APInt intValue;
1261  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1262  // minui(x,MIN_INT) -> MIN_INT
1263  if (intValue.isMinValue())
1264  return getRhs();
1265  // minui(x, MAX_INT) -> x
1266  if (intValue.isMaxValue())
1267  return getLhs();
1268  }
1269 
1270  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1271  [](const APInt &a, const APInt &b) {
1272  return llvm::APIntOps::umin(a, b);
1273  });
1274 }
1275 
1276 //===----------------------------------------------------------------------===//
1277 // MulFOp
1278 //===----------------------------------------------------------------------===//
1279 
1280 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1281  // mulf(x, 1) -> x
1282  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1283  return getLhs();
1284 
1285  return constFoldBinaryOp<FloatAttr>(
1286  adaptor.getOperands(),
1287  [](const APFloat &a, const APFloat &b) { return a * b; });
1288 }
1289 
1290 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1291  MLIRContext *context) {
1292  patterns.add<MulFOfNegF>(context);
1293 }
1294 
1295 //===----------------------------------------------------------------------===//
1296 // DivFOp
1297 //===----------------------------------------------------------------------===//
1298 
1299 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1300  // divf(x, 1) -> x
1301  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1302  return getLhs();
1303 
1304  return constFoldBinaryOp<FloatAttr>(
1305  adaptor.getOperands(),
1306  [](const APFloat &a, const APFloat &b) { return a / b; });
1307 }
1308 
1309 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1310  MLIRContext *context) {
1311  patterns.add<DivFOfNegF>(context);
1312 }
1313 
1314 //===----------------------------------------------------------------------===//
1315 // RemFOp
1316 //===----------------------------------------------------------------------===//
1317 
1318 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1319  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1320  [](const APFloat &a, const APFloat &b) {
1321  APFloat result(a);
1322  // APFloat::mod() offers the remainder
1323  // behavior we want, i.e. the result has
1324  // the sign of the LHS operand.
1325  (void)result.mod(b);
1326  return result;
1327  });
1328 }
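// For example, remf applied to the constants -7.5 and 2.0 folds to -1.5:
// the magnitude is that of fmod, and the sign follows the LHS, matching the
// comment above.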
1329 
1330 //===----------------------------------------------------------------------===//
1331 // Utility functions for verifying cast ops
1332 //===----------------------------------------------------------------------===//
1333 
1334 template <typename... Types>
1335 using type_list = std::tuple<Types...> *;
1336 
1337 /// Returns a non-null type only if the provided type is one of the allowed
1338 /// types or one of the allowed shaped types of the allowed types. Returns the
1339 /// element type if a valid shaped type is provided.
1340 template <typename... ShapedTypes, typename... ElementTypes>
1341 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1342  type_list<ElementTypes...>) {
1343  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1344  return {};
1345 
1346  auto underlyingType = getElementTypeOrSelf(type);
1347  if (!llvm::isa<ElementTypes...>(underlyingType))
1348  return {};
1349 
1350  return underlyingType;
1351 }
1352 
1353 /// Get allowed underlying types for vectors and tensors.
1354 template <typename... ElementTypes>
1355 static Type getTypeIfLike(Type type) {
1356  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1357  type_list<ElementTypes...>());
1358 }
1359 
1360 /// Get allowed underlying types for vectors, tensors, and memrefs.
1361 template <typename... ElementTypes>
1362 static Type getTypeIfLikeOrMemRef(Type type) {
1363  return getUnderlyingType(type,
1364  type_list<VectorType, TensorType, MemRefType>(),
1365  type_list<ElementTypes...>());
1366 }
1367 
1368 /// Return false if both types are ranked tensor with mismatching encoding.
1369 static bool hasSameEncoding(Type typeA, Type typeB) {
1370  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1371  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1372  if (!rankedTensorA || !rankedTensorB)
1373  return true;
1374  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1375 }
1376 
1377 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1378  if (inputs.size() != 1 || outputs.size() != 1)
1379  return false;
1380  if (!hasSameEncoding(inputs.front(), outputs.front()))
1381  return false;
1382  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1383 }
1384 
1385 //===----------------------------------------------------------------------===//
1386 // Verifiers for integer and floating point extension/truncation ops
1387 //===----------------------------------------------------------------------===//
1388 
1389 // Extend ops can only extend to a wider type.
1390 template <typename ValType, typename Op>
1391 static LogicalResult verifyExtOp(Op op) {
1392  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1393  Type dstType = getElementTypeOrSelf(op.getType());
1394 
1395  if (llvm::cast<ValType>(srcType).getWidth() >=
1396  llvm::cast<ValType>(dstType).getWidth())
1397  return op.emitError("result type ")
1398  << dstType << " must be wider than operand type " << srcType;
1399 
1400  return success();
1401 }
1402 
1403 // Truncate ops can only truncate to a shorter type.
1404 template <typename ValType, typename Op>
1405 static LogicalResult verifyTruncateOp(Op op) {
1406  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1407  Type dstType = getElementTypeOrSelf(op.getType());
1408 
1409  if (llvm::cast<ValType>(srcType).getWidth() <=
1410  llvm::cast<ValType>(dstType).getWidth())
1411  return op.emitError("result type ")
1412  << dstType << " must be shorter than operand type " << srcType;
1413 
1414  return success();
1415 }
1416 
1417 /// Validate a cast that changes the width of a type.
1418 template <template <typename> class WidthComparator, typename... ElementTypes>
1419 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1420  if (!areValidCastInputsAndOutputs(inputs, outputs))
1421  return false;
1422 
1423  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1424  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1425  if (!srcType || !dstType)
1426  return false;
1427 
1428  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1429  srcType.getIntOrFloatBitWidth());
1430 }
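// The comparator selects the direction of the cast: extension ops instantiate
// this as checkWidthChangeCast<std::greater, ...> (destination must be wider),
// and truncation ops as checkWidthChangeCast<std::less, ...> (destination must
// be narrower), as seen in the areCastCompatible hooks below.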
1431 
1432 /// Attempts to convert `sourceValue` to an APFloat value with
1433 /// `targetSemantics` and `roundingMode`, without any information loss.
1434 static FailureOr<APFloat> convertFloatValue(
1435  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1436  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1437  bool losesInfo = false;
1438  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1439  if (losesInfo || status != APFloat::opOK)
1440  return failure();
1441 
1442  return sourceValue;
1443 }
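// For example, extending the f32 value 1.5 to f64 converts exactly and can be
// folded, whereas truncating the f64 value 0.1 to f32 sets losesInfo and the
// fold is rejected; this is what makes the ExtF/TruncF constant folds below
// safe.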
1444 
1445 //===----------------------------------------------------------------------===//
1446 // ExtUIOp
1447 //===----------------------------------------------------------------------===//
1448 
1449 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1450  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1451  getInMutable().assign(lhs.getIn());
1452  return getResult();
1453  }
1454 
1455  Type resType = getElementTypeOrSelf(getType());
1456  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1457  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1458  adaptor.getOperands(), getType(),
1459  [bitWidth](const APInt &a, bool &castStatus) {
1460  return a.zext(bitWidth);
1461  });
1462 }
1463 
1464 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1465  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1466 }
1467 
1468 LogicalResult arith::ExtUIOp::verify() {
1469  return verifyExtOp<IntegerType>(*this);
1470 }
1471 
1472 //===----------------------------------------------------------------------===//
1473 // ExtSIOp
1474 //===----------------------------------------------------------------------===//
1475 
1476 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1477  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1478  getInMutable().assign(lhs.getIn());
1479  return getResult();
1480  }
1481 
1482  Type resType = getElementTypeOrSelf(getType());
1483  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1484  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1485  adaptor.getOperands(), getType(),
1486  [bitWidth](const APInt &a, bool &castStatus) {
1487  return a.sext(bitWidth);
1488  });
1489 }
1490 
1491 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1492  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1493 }
1494 
1495 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1496  MLIRContext *context) {
1497  patterns.add<ExtSIOfExtUI>(context);
1498 }
1499 
1500 LogicalResult arith::ExtSIOp::verify() {
1501  return verifyExtOp<IntegerType>(*this);
1502 }
1503 
1504 //===----------------------------------------------------------------------===//
1505 // ExtFOp
1506 //===----------------------------------------------------------------------===//
1507 
1508 /// Fold extension of float constants when there is no information loss due to
1509 /// the difference in fp semantics.
1510 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1511  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1512  if (truncFOp.getOperand().getType() == getType()) {
1513  arith::FastMathFlags truncFMF =
1514  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1515  bool isTruncContract =
1516  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1517  arith::FastMathFlags extFMF =
1518  getFastmath().value_or(arith::FastMathFlags::none);
1519  bool isExtContract =
1520  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1521  if (isTruncContract && isExtContract) {
1522  return truncFOp.getOperand();
1523  }
1524  }
1525  }
1526 
1527  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1528  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1529  return constFoldCastOp<FloatAttr, FloatAttr>(
1530  adaptor.getOperands(), getType(),
1531  [&targetSemantics](const APFloat &a, bool &castStatus) {
1532  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1533  if (failed(result)) {
1534  castStatus = false;
1535  return a;
1536  }
1537  return *result;
1538  });
1539 }
1540 
1541 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1542  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1543 }
1544 
1545 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1546 
1547 //===----------------------------------------------------------------------===//
1548 // ScalingExtFOp
1549 //===----------------------------------------------------------------------===//
1550 
1551 bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1552  TypeRange outputs) {
1553  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1554 }
1555 
1556 LogicalResult arith::ScalingExtFOp::verify() {
1557  return verifyExtOp<FloatType>(*this);
1558 }
1559 
1560 //===----------------------------------------------------------------------===//
1561 // TruncIOp
1562 //===----------------------------------------------------------------------===//
1563 
1564 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1565  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1566  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1567  Value src = getOperand().getDefiningOp()->getOperand(0);
1568  Type srcType = getElementTypeOrSelf(src.getType());
1569  Type dstType = getElementTypeOrSelf(getType());
1570  // trunci(zexti(a)) -> trunci(a)
1571  // trunci(sexti(a)) -> trunci(a)
1572  if (llvm::cast<IntegerType>(srcType).getWidth() >
1573  llvm::cast<IntegerType>(dstType).getWidth()) {
1574  setOperand(src);
1575  return getResult();
1576  }
1577 
1578  // trunci(zexti(a)) -> a
1579  // trunci(sexti(a)) -> a
1580  if (srcType == dstType)
1581  return src;
1582  }
1583 
1584  // trunci(trunci(a)) -> trunci(a)
1585  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1586  setOperand(getOperand().getDefiningOp()->getOperand(0));
1587  return getResult();
1588  }
1589 
1590  Type resType = getElementTypeOrSelf(getType());
1591  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1592  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1593  adaptor.getOperands(), getType(),
1594  [bitWidth](const APInt &a, bool &castStatus) {
1595  return a.trunc(bitWidth);
1596  });
1597 }
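// Example of the narrowing rewrite above: with %a : i32,
//   %0 = arith.extui %a : i32 to i64
//   %1 = arith.trunci %0 : i64 to i16
// becomes %1 = arith.trunci %a : i32 to i16, skipping the extension entirely.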
1598 
1599 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1600  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1601 }
1602 
1603 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1604  MLIRContext *context) {
1605  patterns
1606  .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1607  context);
1608 }
1609 
1610 LogicalResult arith::TruncIOp::verify() {
1611  return verifyTruncateOp<IntegerType>(*this);
1612 }
1613 
1614 //===----------------------------------------------------------------------===//
1615 // TruncFOp
1616 //===----------------------------------------------------------------------===//
1617 
1618 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1619 /// can be represented without precision loss.
1620 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1621  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1622  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1623  Value src = extOp.getIn();
1624  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1625  auto intermediateType =
1626  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1627  // Check if the srcType is representable in the intermediateType.
1628  if (llvm::APFloatBase::isRepresentableBy(
1629  srcType.getFloatSemantics(),
1630  intermediateType.getFloatSemantics())) {
1631  // truncf(extf(a)) -> truncf(a)
1632  if (srcType.getWidth() > resElemType.getWidth()) {
1633  setOperand(src);
1634  return getResult();
1635  }
1636 
1637  // truncf(extf(a)) -> a
1638  if (srcType == resElemType)
1639  return src;
1640  }
1641  }
1642 
1643  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1644  return constFoldCastOp<FloatAttr, FloatAttr>(
1645  adaptor.getOperands(), getType(),
1646  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1647  RoundingMode roundingMode =
1648  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1649  llvm::RoundingMode llvmRoundingMode =
1650  convertArithRoundingModeToLLVMIR(roundingMode);
1651  FailureOr<APFloat> result =
1652  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1653  if (failed(result)) {
1654  castStatus = false;
1655  return a;
1656  }
1657  return *result;
1658  });
1659 }
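
// Editorial example (illustrative, not part of the upstream source). Since f32
// is exactly representable in f64, the folds above rewrite
//   %e = arith.extf %a : f32 to f64
//   %t = arith.truncf %e : f64 to f16   // -> arith.truncf %a : f32 to f16
//   %u = arith.truncf %e : f64 to f32   // -> replaced by %a
// Constant folding honors the optional rounding mode and, per the doc comment
// above, only succeeds when the value converts without information loss.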
1660 
1661 void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1662  MLIRContext *context) {
1663  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1664 }
1665 
1666 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1667  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1668 }
1669 
1670 LogicalResult arith::TruncFOp::verify() {
1671  return verifyTruncateOp<FloatType>(*this);
1672 }
1673 
1674 //===----------------------------------------------------------------------===//
1675 // ScalingTruncFOp
1676 //===----------------------------------------------------------------------===//
1677 
1678 bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1679  TypeRange outputs) {
1680  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1681 }
1682 
1683 LogicalResult arith::ScalingTruncFOp::verify() {
1684  return verifyTruncateOp<FloatType>(*this);
1685 }
1686 
1687 //===----------------------------------------------------------------------===//
1688 // AndIOp
1689 //===----------------------------------------------------------------------===//
1690 
1691 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1692  MLIRContext *context) {
1693  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1694 }
1695 
1696 //===----------------------------------------------------------------------===//
1697 // OrIOp
1698 //===----------------------------------------------------------------------===//
1699 
1700 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1701  MLIRContext *context) {
1702  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1703 }
1704 
1705 //===----------------------------------------------------------------------===//
1706 // Verifiers for casts between integers and floats.
1707 //===----------------------------------------------------------------------===//
1708 
1709 template <typename From, typename To>
1710 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1711  if (!areValidCastInputsAndOutputs(inputs, outputs))
1712  return false;
1713 
1714  auto srcType = getTypeIfLike<From>(inputs.front());
1715  auto dstType = getTypeIfLike<To>(outputs.back());
1716 
1717  return srcType && dstType;
1718 }
1719 
1720 //===----------------------------------------------------------------------===//
1721 // UIToFPOp
1722 //===----------------------------------------------------------------------===//
1723 
1724 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1725  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1726 }
1727 
1728 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1729  Type resEleType = getElementTypeOrSelf(getType());
1730  return constFoldCastOp<IntegerAttr, FloatAttr>(
1731  adaptor.getOperands(), getType(),
1732  [&resEleType](const APInt &a, bool &castStatus) {
1733  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1734  APFloat apf(floatTy.getFloatSemantics(),
1735  APInt::getZero(floatTy.getWidth()));
1736  apf.convertFromAPInt(a, /*IsSigned=*/false,
1737  APFloat::rmNearestTiesToEven);
1738  return apf;
1739  });
1740 }
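
// Editorial example (illustrative, not part of the upstream source):
//   %c = arith.constant 42 : i32
//   %f = arith.uitofp %c : i32 to f32   // folds to a 42.0 : f32 constant
// The APInt is converted as an unsigned value with round-to-nearest-even, so
// integers too wide for the target mantissa are rounded rather than rejected.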
1741 
1742 //===----------------------------------------------------------------------===//
1743 // SIToFPOp
1744 //===----------------------------------------------------------------------===//
1745 
1746 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1747  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1748 }
1749 
1750 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1751  Type resEleType = getElementTypeOrSelf(getType());
1752  return constFoldCastOp<IntegerAttr, FloatAttr>(
1753  adaptor.getOperands(), getType(),
1754  [&resEleType](const APInt &a, bool &castStatus) {
1755  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1756  APFloat apf(floatTy.getFloatSemantics(),
1757  APInt::getZero(floatTy.getWidth()));
1758  apf.convertFromAPInt(a, /*IsSigned=*/true,
1759  APFloat::rmNearestTiesToEven);
1760  return apf;
1761  });
1762 }
1763 
1764 //===----------------------------------------------------------------------===//
1765 // FPToUIOp
1766 //===----------------------------------------------------------------------===//
1767 
1768 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1769  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1770 }
1771 
1772 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1773  Type resType = getElementTypeOrSelf(getType());
1774  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1775  return constFoldCastOp<FloatAttr, IntegerAttr>(
1776  adaptor.getOperands(), getType(),
1777  [&bitWidth](const APFloat &a, bool &castStatus) {
1778  bool ignored;
1779  APSInt api(bitWidth, /*isUnsigned=*/true);
1780  castStatus = APFloat::opInvalidOp !=
1781  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1782  return api;
1783  });
1784 }
1785 
1786 //===----------------------------------------------------------------------===//
1787 // FPToSIOp
1788 //===----------------------------------------------------------------------===//
1789 
1790 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1791  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1792 }
1793 
1794 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1795  Type resType = getElementTypeOrSelf(getType());
1796  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1797  return constFoldCastOp<FloatAttr, IntegerAttr>(
1798  adaptor.getOperands(), getType(),
1799  [&bitWidth](const APFloat &a, bool &castStatus) {
1800  bool ignored;
1801  APSInt api(bitWidth, /*isUnsigned=*/false);
1802  castStatus = APFloat::opInvalidOp !=
1803  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1804  return api;
1805  });
1806 }
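
// Editorial example (illustrative, not part of the upstream source):
//   %c = arith.constant -3.75 : f32
//   %i = arith.fptosi %c : f32 to i32   // folds to arith.constant -3 : i32
// Conversion rounds toward zero; NaN or out-of-range inputs report an invalid
// operation and are left unfolded (likewise for fptoui above).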
1807 
1808 //===----------------------------------------------------------------------===//
1809 // IndexCastOp
1810 //===----------------------------------------------------------------------===//
1811 
1812 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1813  if (!areValidCastInputsAndOutputs(inputs, outputs))
1814  return false;
1815 
1816  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1817  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1818  if (!srcType || !dstType)
1819  return false;
1820 
1821  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1822  (srcType.isSignlessInteger() && dstType.isIndex());
1823 }
1824 
1825 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1826  TypeRange outputs) {
1827  return areIndexCastCompatible(inputs, outputs);
1828 }
1829 
1830 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1831  // index_cast(constant) -> constant
1832  unsigned resultBitwidth = 64; // Default for index integer attributes.
1833  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1834  resultBitwidth = intTy.getWidth();
1835 
1836  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1837  adaptor.getOperands(), getType(),
1838  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1839  return a.sextOrTrunc(resultBitwidth);
1840  });
1841 }
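
// Editorial example (illustrative, not part of the upstream source):
//   %c = arith.constant 7 : index
//   %i = arith.index_cast %c : index to i32   // folds to arith.constant 7 : i32
// Index integer attributes default to 64 bits, so the fold is simply a
// sign-extend-or-truncate to the destination width.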
1842 
1843 void arith::IndexCastOp::getCanonicalizationPatterns(
1844  RewritePatternSet &patterns, MLIRContext *context) {
1845  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1846 }
1847 
1848 //===----------------------------------------------------------------------===//
1849 // IndexCastUIOp
1850 //===----------------------------------------------------------------------===//
1851 
1852 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1853  TypeRange outputs) {
1854  return areIndexCastCompatible(inputs, outputs);
1855 }
1856 
1857 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1858  // index_castui(constant) -> constant
1859  unsigned resultBitwidth = 64; // Default for index integer attributes.
1860  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1861  resultBitwidth = intTy.getWidth();
1862 
1863  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1864  adaptor.getOperands(), getType(),
1865  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1866  return a.zextOrTrunc(resultBitwidth);
1867  });
1868 }
1869 
1870 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1871  RewritePatternSet &patterns, MLIRContext *context) {
1872  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1873 }
1874 
1875 //===----------------------------------------------------------------------===//
1876 // BitcastOp
1877 //===----------------------------------------------------------------------===//
1878 
1879 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1880  if (!areValidCastInputsAndOutputs(inputs, outputs))
1881  return false;
1882 
1883  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1884  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1885  if (!srcType || !dstType)
1886  return false;
1887 
1888  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1889 }
1890 
1891 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1892  auto resType = getType();
1893  auto operand = adaptor.getIn();
1894  if (!operand)
1895  return {};
1896 
1897  /// Bitcast dense elements.
1898  if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1899  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1900  /// Other shaped types unhandled.
1901  if (llvm::isa<ShapedType>(resType))
1902  return {};
1903 
1904  /// Bitcast poison.
1905  if (llvm::isa<ub::PoisonAttr>(operand))
1906  return ub::PoisonAttr::get(getContext());
1907 
1908  /// Bitcast integer or float to integer or float.
1909  APInt bits = llvm::isa<FloatAttr>(operand)
1910  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1911  : llvm::cast<IntegerAttr>(operand).getValue();
1912  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1913  "trying to fold on broken IR: operands have incompatible types");
1914 
1915  if (auto resFloatType = dyn_cast<FloatType>(resType))
1916  return FloatAttr::get(resType,
1917  APFloat(resFloatType.getFloatSemantics(), bits));
1918  return IntegerAttr::get(resType, bits);
1919 }
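
// Editorial example (illustrative, not part of the upstream source):
//   %c = arith.constant 1.0 : f32
//   %i = arith.bitcast %c : f32 to i32   // folds to arith.constant 1065353216 : i32
// (0x3F800000 is the IEEE-754 encoding of 1.0f). Dense element attributes are
// bitcast element-wise, and a poison operand folds to poison.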
1920 
1921 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1922  MLIRContext *context) {
1923  patterns.add<BitcastOfBitcast>(context);
1924 }
1925 
1926 //===----------------------------------------------------------------------===//
1927 // CmpIOp
1928 //===----------------------------------------------------------------------===//
1929 
1930 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1931 /// comparison predicates.
1932 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1933  const APInt &lhs, const APInt &rhs) {
1934  switch (predicate) {
1935  case arith::CmpIPredicate::eq:
1936  return lhs.eq(rhs);
1937  case arith::CmpIPredicate::ne:
1938  return lhs.ne(rhs);
1939  case arith::CmpIPredicate::slt:
1940  return lhs.slt(rhs);
1941  case arith::CmpIPredicate::sle:
1942  return lhs.sle(rhs);
1943  case arith::CmpIPredicate::sgt:
1944  return lhs.sgt(rhs);
1945  case arith::CmpIPredicate::sge:
1946  return lhs.sge(rhs);
1947  case arith::CmpIPredicate::ult:
1948  return lhs.ult(rhs);
1949  case arith::CmpIPredicate::ule:
1950  return lhs.ule(rhs);
1951  case arith::CmpIPredicate::ugt:
1952  return lhs.ugt(rhs);
1953  case arith::CmpIPredicate::uge:
1954  return lhs.uge(rhs);
1955  }
1956  llvm_unreachable("unknown cmpi predicate kind");
1957 }
1958 
1959 /// Returns true if the predicate is true for two equal operands.
1960 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1961  switch (predicate) {
1962  case arith::CmpIPredicate::eq:
1963  case arith::CmpIPredicate::sle:
1964  case arith::CmpIPredicate::sge:
1965  case arith::CmpIPredicate::ule:
1966  case arith::CmpIPredicate::uge:
1967  return true;
1968  case arith::CmpIPredicate::ne:
1969  case arith::CmpIPredicate::slt:
1970  case arith::CmpIPredicate::sgt:
1971  case arith::CmpIPredicate::ult:
1972  case arith::CmpIPredicate::ugt:
1973  return false;
1974  }
1975  llvm_unreachable("unknown cmpi predicate kind");
1976 }
1977 
1978 static std::optional<int64_t> getIntegerWidth(Type t) {
1979  if (auto intType = dyn_cast<IntegerType>(t)) {
1980  return intType.getWidth();
1981  }
1982  if (auto vectorIntType = dyn_cast<VectorType>(t)) {
1983  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1984  }
1985  return std::nullopt;
1986 }
1987 
1988 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1989  // cmpi(pred, x, x)
1990  if (getLhs() == getRhs()) {
1991  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1992  return getBoolAttribute(getType(), val);
1993  }
1994 
1995  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1996  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1997  // extsi(%x : i1 -> iN) != 0 -> %x
1998  std::optional<int64_t> integerWidth =
1999  getIntegerWidth(extOp.getOperand().getType());
2000  if (integerWidth && integerWidth.value() == 1 &&
2001  getPredicate() == arith::CmpIPredicate::ne)
2002  return extOp.getOperand();
2003  }
2004  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2005  // extui(%x : i1 -> iN) != 0 -> %x
2006  std::optional<int64_t> integerWidth =
2007  getIntegerWidth(extOp.getOperand().getType());
2008  if (integerWidth && integerWidth.value() == 1 &&
2009  getPredicate() == arith::CmpIPredicate::ne)
2010  return extOp.getOperand();
2011  }
2012 
2013  // arith.cmpi ne, %val, %zero : i1 -> %val
2014  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2015  getPredicate() == arith::CmpIPredicate::ne)
2016  return getLhs();
2017  }
2018 
2019  if (matchPattern(adaptor.getRhs(), m_One())) {
2020  // arith.cmpi eq, %val, %one : i1 -> %val
2021  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2022  getPredicate() == arith::CmpIPredicate::eq)
2023  return getLhs();
2024  }
2025 
2026  // Move constant to the right side.
2027  if (adaptor.getLhs() && !adaptor.getRhs()) {
2028  // Do not use invertPredicate, as it will change eq to ne and vice versa.
2029  using Pred = CmpIPredicate;
2030  const std::pair<Pred, Pred> invPreds[] = {
2031  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2032  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2033  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2034  {Pred::ne, Pred::ne},
2035  };
2036  Pred origPred = getPredicate();
2037  for (auto pred : invPreds) {
2038  if (origPred == pred.first) {
2039  setPredicate(pred.second);
2040  Value lhs = getLhs();
2041  Value rhs = getRhs();
2042  getLhsMutable().assign(rhs);
2043  getRhsMutable().assign(lhs);
2044  return getResult();
2045  }
2046  }
2047  llvm_unreachable("unknown cmpi predicate kind");
2048  }
2049 
2050  // Constants are canonicalized to the right side above, so if the lhs is a
2051  // constant here, the rhs is guaranteed to be a constant as well.
2052  if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2053  return constFoldBinaryOp<IntegerAttr>(
2054  adaptor.getOperands(), getI1SameShape(lhs.getType()),
2055  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2056  return APInt(1,
2057  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2058  });
2059  }
2060 
2061  return {};
2062 }
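
// Editorial example (illustrative, not part of the upstream source):
//   arith.cmpi sle, %x, %x : i32       // folds to a true constant
//   arith.cmpi slt, %c5, %x : i32      // becomes arith.cmpi sgt, %x, %c5 : i32
// With two constant operands the predicate is evaluated directly, e.g.
// cmpi ult, 3, 7 folds to true.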
2063 
2064 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2065  MLIRContext *context) {
2066  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2067 }
2068 
2069 //===----------------------------------------------------------------------===//
2070 // CmpFOp
2071 //===----------------------------------------------------------------------===//
2072 
2073 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2074 /// comparison predicates.
2075 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2076  const APFloat &lhs, const APFloat &rhs) {
2077  auto cmpResult = lhs.compare(rhs);
2078  switch (predicate) {
2079  case arith::CmpFPredicate::AlwaysFalse:
2080  return false;
2081  case arith::CmpFPredicate::OEQ:
2082  return cmpResult == APFloat::cmpEqual;
2083  case arith::CmpFPredicate::OGT:
2084  return cmpResult == APFloat::cmpGreaterThan;
2085  case arith::CmpFPredicate::OGE:
2086  return cmpResult == APFloat::cmpGreaterThan ||
2087  cmpResult == APFloat::cmpEqual;
2088  case arith::CmpFPredicate::OLT:
2089  return cmpResult == APFloat::cmpLessThan;
2090  case arith::CmpFPredicate::OLE:
2091  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2092  case arith::CmpFPredicate::ONE:
2093  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2094  case arith::CmpFPredicate::ORD:
2095  return cmpResult != APFloat::cmpUnordered;
2096  case arith::CmpFPredicate::UEQ:
2097  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2098  case arith::CmpFPredicate::UGT:
2099  return cmpResult == APFloat::cmpUnordered ||
2100  cmpResult == APFloat::cmpGreaterThan;
2101  case arith::CmpFPredicate::UGE:
2102  return cmpResult == APFloat::cmpUnordered ||
2103  cmpResult == APFloat::cmpGreaterThan ||
2104  cmpResult == APFloat::cmpEqual;
2105  case arith::CmpFPredicate::ULT:
2106  return cmpResult == APFloat::cmpUnordered ||
2107  cmpResult == APFloat::cmpLessThan;
2108  case arith::CmpFPredicate::ULE:
2109  return cmpResult == APFloat::cmpUnordered ||
2110  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2111  case arith::CmpFPredicate::UNE:
2112  return cmpResult != APFloat::cmpEqual;
2113  case arith::CmpFPredicate::UNO:
2114  return cmpResult == APFloat::cmpUnordered;
2115  case arith::CmpFPredicate::AlwaysTrue:
2116  return true;
2117  }
2118  llvm_unreachable("unknown cmpf predicate kind");
2119 }
2120 
2121 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2122  auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2123  auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2124 
2125  // If one operand is NaN, making them both NaN does not change the result.
2126  if (lhs && lhs.getValue().isNaN())
2127  rhs = lhs;
2128  if (rhs && rhs.getValue().isNaN())
2129  lhs = rhs;
2130 
2131  if (!lhs || !rhs)
2132  return {};
2133 
2134  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2135  return BoolAttr::get(getContext(), val);
2136 }
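
// Editorial example (illustrative, not part of the upstream source):
//   arith.cmpf olt, %c1, %c2 : f32    // with constants 1.0 and 2.0, folds to true
//   arith.cmpf uno, %x, %cnan : f32   // one constant NaN operand is enough:
//                                     // uno folds to true, ordered predicates to false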
2137 
2138 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2139 public:
2140  using OpRewritePattern<CmpFOp>::OpRewritePattern;
2141 
2142  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2143  bool isUnsigned) {
2144  using namespace arith;
2145  switch (pred) {
2146  case CmpFPredicate::UEQ:
2147  case CmpFPredicate::OEQ:
2148  return CmpIPredicate::eq;
2149  case CmpFPredicate::UGT:
2150  case CmpFPredicate::OGT:
2151  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2152  case CmpFPredicate::UGE:
2153  case CmpFPredicate::OGE:
2154  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2155  case CmpFPredicate::ULT:
2156  case CmpFPredicate::OLT:
2157  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2158  case CmpFPredicate::ULE:
2159  case CmpFPredicate::OLE:
2160  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2161  case CmpFPredicate::UNE:
2162  case CmpFPredicate::ONE:
2163  return CmpIPredicate::ne;
2164  default:
2165  llvm_unreachable("Unexpected predicate!");
2166  }
2167  }
2168 
2169  LogicalResult matchAndRewrite(CmpFOp op,
2170  PatternRewriter &rewriter) const override {
2171  FloatAttr flt;
2172  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2173  return failure();
2174 
2175  const APFloat &rhs = flt.getValue();
2176 
2177  // Don't attempt to fold a nan.
2178  if (rhs.isNaN())
2179  return failure();
2180 
2181  // Get the width of the mantissa. We don't want to touch conversions that
2182  // might lose information from the integer, e.g. "i64 -> float".
2183  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2184  int mantissaWidth = floatTy.getFPMantissaWidth();
2185  if (mantissaWidth <= 0)
2186  return failure();
2187 
2188  bool isUnsigned;
2189  Value intVal;
2190 
2191  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2192  isUnsigned = false;
2193  intVal = si.getIn();
2194  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2195  isUnsigned = true;
2196  intVal = ui.getIn();
2197  } else {
2198  return failure();
2199  }
2200 
2201  // Check that the input is converted from an integer type that is small
2202  // enough that the conversion preserves all of its bits.
2203  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2204  auto intWidth = intTy.getWidth();
2205 
2206  // Number of bits representing values, as opposed to the sign bit.
2207  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2208 
2209  // Following test does NOT adjust intWidth downwards for signed inputs,
2210  // because the most negative value still requires all the mantissa bits
2211  // to distinguish it from one less than that value.
2212  if ((int)intWidth > mantissaWidth) {
2213  // Conversion would lose accuracy. Check if loss can impact comparison.
2214  int exponent = ilogb(rhs);
2215  if (exponent == APFloat::IEK_Inf) {
2216  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2217  if (maxExponent < (int)valueBits) {
2218  // Conversion could create infinity.
2219  return failure();
2220  }
2221  } else {
2222  // Note that if rhs is zero or NaN, then the exponent is negative
2223  // and the first condition is trivially false.
2224  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2225  // Conversion could affect comparison.
2226  return failure();
2227  }
2228  }
2229  }
2230 
2231  // Convert to equivalent cmpi predicate
2232  CmpIPredicate pred;
2233  switch (op.getPredicate()) {
2234  case CmpFPredicate::ORD:
2235  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2236  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2237  /*width=*/1);
2238  return success();
2239  case CmpFPredicate::UNO:
2240  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2241  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2242  /*width=*/1);
2243  return success();
2244  default:
2245  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2246  break;
2247  }
2248 
2249  if (!isUnsigned) {
2250  // If the rhs value is > SignedMax, fold the comparison. This handles
2251  // +INF and large values.
2252  APFloat signedMax(rhs.getSemantics());
2253  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2254  APFloat::rmNearestTiesToEven);
2255  if (signedMax < rhs) { // smax < 13123.0
2256  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2257  pred == CmpIPredicate::sle)
2258  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2259  /*width=*/1);
2260  else
2261  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2262  /*width=*/1);
2263  return success();
2264  }
2265  } else {
2266  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2267  // +INF and large values.
2268  APFloat unsignedMax(rhs.getSemantics());
2269  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2270  APFloat::rmNearestTiesToEven);
2271  if (unsignedMax < rhs) { // umax < 13123.0
2272  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2273  pred == CmpIPredicate::ule)
2274  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2275  /*width=*/1);
2276  else
2277  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2278  /*width=*/1);
2279  return success();
2280  }
2281  }
2282 
2283  if (!isUnsigned) {
2284  // See if the rhs value is < SignedMin.
2285  APFloat signedMin(rhs.getSemantics());
2286  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2287  APFloat::rmNearestTiesToEven);
2288  if (signedMin > rhs) { // smin > 12312.0
2289  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2290  pred == CmpIPredicate::sge)
2291  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2292  /*width=*/1);
2293  else
2294  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2295  /*width=*/1);
2296  return success();
2297  }
2298  } else {
2299  // See if the rhs value is < UnsignedMin.
2300  APFloat unsignedMin(rhs.getSemantics());
2301  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2302  APFloat::rmNearestTiesToEven);
2303  if (unsignedMin > rhs) { // umin > 12312.0
2304  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2305  pred == CmpIPredicate::uge)
2306  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2307  /*width=*/1);
2308  else
2309  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2310  /*width=*/1);
2311  return success();
2312  }
2313  }
2314 
2315  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2316  // [0, UMAX], but it may still be fractional. See if it is fractional by
2317  // casting the FP value to the integer value and back, checking for
2318  // equality. Don't do this for zero, because -0.0 is not fractional.
2319  bool ignored;
2320  APSInt rhsInt(intWidth, isUnsigned);
2321  if (APFloat::opInvalidOp ==
2322  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2323  // Undefined behavior invoked - the destination type can't represent
2324  // the input constant.
2325  return failure();
2326  }
2327 
2328  if (!rhs.isZero()) {
2329  APFloat apf(floatTy.getFloatSemantics(),
2330  APInt::getZero(floatTy.getWidth()));
2331  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2332 
2333  bool equal = apf == rhs;
2334  if (!equal) {
2335  // If we had a comparison against a fractional value, we have to adjust
2336  // the compare predicate and sometimes the value. rhsInt is rounded
2337  // towards zero at this point.
2338  switch (pred) {
2339  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2340  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2341  /*width=*/1);
2342  return success();
2343  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2344  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2345  /*width=*/1);
2346  return success();
2347  case CmpIPredicate::ule:
2348  // (float)int <= 4.4 --> int <= 4
2349  // (float)int <= -4.4 --> false
2350  if (rhs.isNegative()) {
2351  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2352  /*width=*/1);
2353  return success();
2354  }
2355  break;
2356  case CmpIPredicate::sle:
2357  // (float)int <= 4.4 --> int <= 4
2358  // (float)int <= -4.4 --> int < -4
2359  if (rhs.isNegative())
2360  pred = CmpIPredicate::slt;
2361  break;
2362  case CmpIPredicate::ult:
2363  // (float)int < -4.4 --> false
2364  // (float)int < 4.4 --> int <= 4
2365  if (rhs.isNegative()) {
2366  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2367  /*width=*/1);
2368  return success();
2369  }
2370  pred = CmpIPredicate::ule;
2371  break;
2372  case CmpIPredicate::slt:
2373  // (float)int < -4.4 --> int < -4
2374  // (float)int < 4.4 --> int <= 4
2375  if (!rhs.isNegative())
2376  pred = CmpIPredicate::sle;
2377  break;
2378  case CmpIPredicate::ugt:
2379  // (float)int > 4.4 --> int > 4
2380  // (float)int > -4.4 --> true
2381  if (rhs.isNegative()) {
2382  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2383  /*width=*/1);
2384  return success();
2385  }
2386  break;
2387  case CmpIPredicate::sgt:
2388  // (float)int > 4.4 --> int > 4
2389  // (float)int > -4.4 --> int >= -4
2390  if (rhs.isNegative())
2391  pred = CmpIPredicate::sge;
2392  break;
2393  case CmpIPredicate::uge:
2394  // (float)int >= -4.4 --> true
2395  // (float)int >= 4.4 --> int > 4
2396  if (rhs.isNegative()) {
2397  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2398  /*width=*/1);
2399  return success();
2400  }
2401  pred = CmpIPredicate::ugt;
2402  break;
2403  case CmpIPredicate::sge:
2404  // (float)int >= -4.4 --> int >= -4
2405  // (float)int >= 4.4 --> int > 4
2406  if (!rhs.isNegative())
2407  pred = CmpIPredicate::sgt;
2408  break;
2409  }
2410  }
2411  }
2412 
2413  // Lower this FP comparison into an appropriate integer version of the
2414  // comparison.
2415  rewriter.replaceOpWithNewOp<CmpIOp>(
2416  op, pred, intVal,
2417  ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2418  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2419  return success();
2420  }
2421 };
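
// Editorial example (illustrative, not part of the upstream source) of the
// CmpFIntToFPConst rewrite:
//   %f = arith.sitofp %i : i32 to f32
//   %r = arith.cmpf ogt, %f, %cst : f32   // %cst = 4.4 : f32
// becomes a compare on the original integer against the rounded bound:
//   %r = arith.cmpi sgt, %i, %c4 : i32    // (float)int > 4.4  -->  int > 4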
2422 
2423 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2424  MLIRContext *context) {
2425  patterns.insert<CmpFIntToFPConst>(context);
2426 }
2427 
2428 //===----------------------------------------------------------------------===//
2429 // SelectOp
2430 //===----------------------------------------------------------------------===//
2431 
2432 // select %arg, %c1, %c0 => extui %arg
2433 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2434  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2435 
2436  LogicalResult matchAndRewrite(arith::SelectOp op,
2437  PatternRewriter &rewriter) const override {
2438  // Cannot extui i1 to i1, or i1 to f32
2439  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2440  return failure();
2441 
2442  // select %x, %c1, %c0 => extui %x
2443  if (matchPattern(op.getTrueValue(), m_One()) &&
2444  matchPattern(op.getFalseValue(), m_Zero())) {
2445  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2446  op.getCondition());
2447  return success();
2448  }
2449 
2450  // select %x, %c0, %c1 => extui (xor %x, true)
2451  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2452  matchPattern(op.getFalseValue(), m_One())) {
2453  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2454  op, op.getType(),
2455  arith::XOrIOp::create(
2456  rewriter, op.getLoc(), op.getCondition(),
2457  arith::ConstantIntOp::create(rewriter, op.getLoc(),
2458  op.getCondition().getType(), 1)));
2459  return success();
2460  }
2461 
2462  return failure();
2463  }
2464 };
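
// Editorial example (illustrative, not part of the upstream source):
//   %r = arith.select %cond, %c1_i32, %c0_i32 : i32
// becomes
//   %r = arith.extui %cond : i1 to i32
// and with the constants swapped the condition is first negated via xor with true.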
2465 
2466 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2467  MLIRContext *context) {
2468  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2469  SelectI1ToNot, SelectToExtUI>(context);
2470 }
2471 
2472 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2473  Value trueVal = getTrueValue();
2474  Value falseVal = getFalseValue();
2475  if (trueVal == falseVal)
2476  return trueVal;
2477 
2478  Value condition = getCondition();
2479 
2480  // select true, %0, %1 => %0
2481  if (matchPattern(adaptor.getCondition(), m_One()))
2482  return trueVal;
2483 
2484  // select false, %0, %1 => %1
2485  if (matchPattern(adaptor.getCondition(), m_Zero()))
2486  return falseVal;
2487 
2488  // If either operand is fully poisoned, return the other.
2489  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2490  return falseVal;
2491 
2492  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2493  return trueVal;
2494 
2495  // select %x, true, false => %x
2496  if (getType().isSignlessInteger(1) &&
2497  matchPattern(adaptor.getTrueValue(), m_One()) &&
2498  matchPattern(adaptor.getFalseValue(), m_Zero()))
2499  return condition;
2500 
2501  if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2502  auto pred = cmp.getPredicate();
2503  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2504  auto cmpLhs = cmp.getLhs();
2505  auto cmpRhs = cmp.getRhs();
2506 
2507  // %0 = arith.cmpi eq, %arg0, %arg1
2508  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2509 
2510  // %0 = arith.cmpi ne, %arg0, %arg1
2511  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2512 
2513  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2514  (cmpRhs == trueVal && cmpLhs == falseVal))
2515  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2516  }
2517  }
2518 
2519  // Constant-fold constant operands over non-splat constant condition.
2520  // select %cst_vec, %cst0, %cst1 => %cst2
2521  if (auto cond =
2522  dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2523  if (auto lhs =
2524  dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2525  if (auto rhs =
2526  dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2527  SmallVector<Attribute> results;
2528  results.reserve(static_cast<size_t>(cond.getNumElements()));
2529  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2530  cond.value_end<BoolAttr>());
2531  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2532  lhs.value_end<Attribute>());
2533  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2534  rhs.value_end<Attribute>());
2535 
2536  for (auto [condVal, lhsVal, rhsVal] :
2537  llvm::zip_equal(condVals, lhsVals, rhsVals))
2538  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2539 
2540  return DenseElementsAttr::get(lhs.getType(), results);
2541  }
2542  }
2543  }
2544 
2545  return nullptr;
2546 }
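
// Editorial example (illustrative, not part of the upstream source):
//   arith.select %true, %a, %b : i32          // -> %a
//   arith.select %c, %x, %x : i32             // -> %x
//   arith.select %c, %true, %false : i1       // -> %c
// A non-splat constant vector condition with constant branches folds
// element-wise into a new dense constant.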
2547 
2548 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2549  Type conditionType, resultType;
2550  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2551  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2552  parser.parseOptionalAttrDict(result.attributes) ||
2553  parser.parseColonType(resultType))
2554  return failure();
2555 
2556  // Check for the explicit condition type if this is a masked tensor or vector.
2557  if (succeeded(parser.parseOptionalComma())) {
2558  conditionType = resultType;
2559  if (parser.parseType(resultType))
2560  return failure();
2561  } else {
2562  conditionType = parser.getBuilder().getI1Type();
2563  }
2564 
2565  result.addTypes(resultType);
2566  return parser.resolveOperands(operands,
2567  {conditionType, resultType, resultType},
2568  parser.getNameLoc(), result.operands);
2569 }
2570 
2571 void arith::SelectOp::print(OpAsmPrinter &p) {
2572  p << " " << getOperands();
2573  p.printOptionalAttrDict((*this)->getAttrs());
2574  p << " : ";
2575  if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2576  p << condType << ", ";
2577  p << getType();
2578 }
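
// Editorial example (illustrative, not part of the upstream source) of the
// custom assembly handled by parse/print above:
//   %r = arith.select %cond, %a, %b : i64
//   %v = arith.select %mask, %x, %y : vector<4xi1>, vector<4xf32>
// The optional leading type is the condition (mask) type; the trailing type is
// always the result type.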
2579 
2580 LogicalResult arith::SelectOp::verify() {
2581  Type conditionType = getCondition().getType();
2582  if (conditionType.isSignlessInteger(1))
2583  return success();
2584 
2585  // If the result type is a vector or tensor, the type can be a mask with the
2586  // same elements.
2587  Type resultType = getType();
2588  if (!llvm::isa<TensorType, VectorType>(resultType))
2589  return emitOpError() << "expected condition to be a signless i1, but got "
2590  << conditionType;
2591  Type shapedConditionType = getI1SameShape(resultType);
2592  if (conditionType != shapedConditionType) {
2593  return emitOpError() << "expected condition type to have the same shape "
2594  "as the result type, expected "
2595  << shapedConditionType << ", but got "
2596  << conditionType;
2597  }
2598  return success();
2599 }
2600 //===----------------------------------------------------------------------===//
2601 // ShLIOp
2602 //===----------------------------------------------------------------------===//
2603 
2604 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2605  // shli(x, 0) -> x
2606  if (matchPattern(adaptor.getRhs(), m_Zero()))
2607  return getLhs();
2608  // Don't fold if shifting by more than or equal to the bit width.
2609  bool bounded = false;
2610  auto result = constFoldBinaryOp<IntegerAttr>(
2611  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2612  bounded = b.ult(b.getBitWidth());
2613  return a.shl(b);
2614  });
2615  return bounded ? result : Attribute();
2616 }
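
// Editorial example (illustrative, not part of the upstream source):
//   arith.shli %x, %c0 : i32    // folds to %x
//   arith.shli %c1, %c3 : i8    // constants fold: 1 << 3 == 8
// A constant shift amount >= the bit width is deliberately not folded (see the
// bound check above); shrui and shrsi below use the same guard.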
2617 
2618 //===----------------------------------------------------------------------===//
2619 // ShRUIOp
2620 //===----------------------------------------------------------------------===//
2621 
2622 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2623  // shrui(x, 0) -> x
2624  if (matchPattern(adaptor.getRhs(), m_Zero()))
2625  return getLhs();
2626  // Don't fold if shifting by more than or equal to the bit width.
2627  bool bounded = false;
2628  auto result = constFoldBinaryOp<IntegerAttr>(
2629  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2630  bounded = b.ult(b.getBitWidth());
2631  return a.lshr(b);
2632  });
2633  return bounded ? result : Attribute();
2634 }
2635 
2636 //===----------------------------------------------------------------------===//
2637 // ShRSIOp
2638 //===----------------------------------------------------------------------===//
2639 
2640 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2641  // shrsi(x, 0) -> x
2642  if (matchPattern(adaptor.getRhs(), m_Zero()))
2643  return getLhs();
2644  // Don't fold if shifting by more than or equal to the bit width.
2645  bool bounded = false;
2646  auto result = constFoldBinaryOp<IntegerAttr>(
2647  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2648  bounded = b.ult(b.getBitWidth());
2649  return a.ashr(b);
2650  });
2651  return bounded ? result : Attribute();
2652 }
2653 
2654 //===----------------------------------------------------------------------===//
2655 // Atomic Enum
2656 //===----------------------------------------------------------------------===//
2657 
2658 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2659 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2660  OpBuilder &builder, Location loc,
2661  bool useOnlyFiniteValue) {
2662  switch (kind) {
2663  case AtomicRMWKind::maximumf: {
2664  const llvm::fltSemantics &semantic =
2665  llvm::cast<FloatType>(resultType).getFloatSemantics();
2666  APFloat identity = useOnlyFiniteValue
2667  ? APFloat::getLargest(semantic, /*Negative=*/true)
2668  : APFloat::getInf(semantic, /*Negative=*/true);
2669  return builder.getFloatAttr(resultType, identity);
2670  }
2671  case AtomicRMWKind::maxnumf: {
2672  const llvm::fltSemantics &semantic =
2673  llvm::cast<FloatType>(resultType).getFloatSemantics();
2674  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2675  return builder.getFloatAttr(resultType, identity);
2676  }
2677  case AtomicRMWKind::addf:
2678  case AtomicRMWKind::addi:
2679  case AtomicRMWKind::maxu:
2680  case AtomicRMWKind::ori:
2681  return builder.getZeroAttr(resultType);
2682  case AtomicRMWKind::andi:
2683  return builder.getIntegerAttr(
2684  resultType,
2685  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2686  case AtomicRMWKind::maxs:
2687  return builder.getIntegerAttr(
2688  resultType, APInt::getSignedMinValue(
2689  llvm::cast<IntegerType>(resultType).getWidth()));
2690  case AtomicRMWKind::minimumf: {
2691  const llvm::fltSemantics &semantic =
2692  llvm::cast<FloatType>(resultType).getFloatSemantics();
2693  APFloat identity = useOnlyFiniteValue
2694  ? APFloat::getLargest(semantic, /*Negative=*/false)
2695  : APFloat::getInf(semantic, /*Negative=*/false);
2696 
2697  return builder.getFloatAttr(resultType, identity);
2698  }
2699  case AtomicRMWKind::minnumf: {
2700  const llvm::fltSemantics &semantic =
2701  llvm::cast<FloatType>(resultType).getFloatSemantics();
2702  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2703  return builder.getFloatAttr(resultType, identity);
2704  }
2705  case AtomicRMWKind::mins:
2706  return builder.getIntegerAttr(
2707  resultType, APInt::getSignedMaxValue(
2708  llvm::cast<IntegerType>(resultType).getWidth()));
2709  case AtomicRMWKind::minu:
2710  return builder.getIntegerAttr(
2711  resultType,
2712  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2713  case AtomicRMWKind::muli:
2714  return builder.getIntegerAttr(resultType, 1);
2715  case AtomicRMWKind::mulf:
2716  return builder.getFloatAttr(resultType, 1);
2717  // TODO: Add remaining reduction operations.
2718  default:
2719  (void)emitOptionalError(loc, "Reduction operation type not supported");
2720  break;
2721  }
2722  return nullptr;
2723 }
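
// Editorial sketch (illustrative, not part of the upstream source), assuming an
// OpBuilder `b` and a Location `loc` are in scope:
//   TypedAttr id = arith::getIdentityValueAttr(AtomicRMWKind::maximumf,
//                                              b.getF32Type(), b, loc,
//                                              /*useOnlyFiniteValue=*/false);
// yields an f32 attribute holding -inf; with useOnlyFiniteValue=true it would
// be the most negative finite f32 instead.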
2724 
2725 /// Return the identity numeric value associated with the given op.
2726 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2727  std::optional<AtomicRMWKind> maybeKind =
2728  TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2729  // Floating-point operations.
2730  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2731  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2732  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2733  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2734  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2735  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2736  // Integer operations.
2737  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2738  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2739  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2740  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2741  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2742  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2743  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2744  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2745  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2746  .Default([](Operation *op) { return std::nullopt; });
2747  if (!maybeKind) {
2748  return std::nullopt;
2749  }
2750 
2751  bool useOnlyFiniteValue = false;
2752  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2753  if (fmfOpInterface) {
2754  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2755  useOnlyFiniteValue =
2756  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2757  }
2758 
2759  // Builder only used as helper for attribute creation.
2760  OpBuilder b(op->getContext());
2761  Type resultType = op->getResult(0).getType();
2762 
2763  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2764  useOnlyFiniteValue);
2765 }
2766 
2767 /// Returns the identity value associated with an AtomicRMWKind op.
2768 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2769  OpBuilder &builder, Location loc,
2770  bool useOnlyFiniteValue) {
2771  auto attr =
2772  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2773  return arith::ConstantOp::create(builder, loc, attr);
2774 }
2775 
2776 /// Return the value obtained by applying the reduction operation kind
2777 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2778 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2779  Location loc, Value lhs, Value rhs) {
2780  switch (op) {
2781  case AtomicRMWKind::addf:
2782  return arith::AddFOp::create(builder, loc, lhs, rhs);
2783  case AtomicRMWKind::addi:
2784  return arith::AddIOp::create(builder, loc, lhs, rhs);
2785  case AtomicRMWKind::mulf:
2786  return arith::MulFOp::create(builder, loc, lhs, rhs);
2787  case AtomicRMWKind::muli:
2788  return arith::MulIOp::create(builder, loc, lhs, rhs);
2789  case AtomicRMWKind::maximumf:
2790  return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2791  case AtomicRMWKind::minimumf:
2792  return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2793  case AtomicRMWKind::maxnumf:
2794  return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2795  case AtomicRMWKind::minnumf:
2796  return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2797  case AtomicRMWKind::maxs:
2798  return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2799  case AtomicRMWKind::mins:
2800  return arith::MinSIOp::create(builder, loc, lhs, rhs);
2801  case AtomicRMWKind::maxu:
2802  return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2803  case AtomicRMWKind::minu:
2804  return arith::MinUIOp::create(builder, loc, lhs, rhs);
2805  case AtomicRMWKind::ori:
2806  return arith::OrIOp::create(builder, loc, lhs, rhs);
2807  case AtomicRMWKind::andi:
2808  return arith::AndIOp::create(builder, loc, lhs, rhs);
2809  // TODO: Add remaining reduction operations.
2810  default:
2811  (void)emitOptionalError(loc, "Reduction operation type not supported");
2812  break;
2813  }
2814  return nullptr;
2815 }
2816 
2817 //===----------------------------------------------------------------------===//
2818 // TableGen'd op method definitions
2819 //===----------------------------------------------------------------------===//
2820 
2821 #define GET_OP_CLASSES
2822 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2823 
2824 //===----------------------------------------------------------------------===//
2825 // TableGen'd enum attribute definitions
2826 //===----------------------------------------------------------------------===//
2827 
2828 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:711
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1419
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1355
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1960
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition: ArithOps.cpp:672
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1369
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1341
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:771
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:753
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:141
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:149
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1434
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1812
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1710
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1391
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:129
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:952
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1362
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:170
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1978
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1377
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1335
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:433
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1405
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1224::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2169
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:2142
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:621
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:654
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:830
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:831
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
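Sketch of walking from a Value back to its producer; a null defining op means the value is a block argument:

if (Operation *def = value.getDefiningOp()) {
  if (auto cst = dyn_cast<arith::ConstantOp>(def)) {
    // `value` is produced by an arith.constant whose result type is
    // value.getType().
  }
}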
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition: ArithOps.cpp:330
static bool classof(Operation *op)
Definition: ArithOps.cpp:347
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:324
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:353
static bool classof(Operation *op)
Definition: ArithOps.cpp:374
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:251
static bool classof(Operation *op)
Definition: ArithOps.cpp:318
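A sketch of the three constant specializations in use; builder and loc are assumed, and the APFloat/width values are purely illustrative:

arith::ConstantFloatOp oneF32 =
    arith::ConstantFloatOp::create(builder, loc, builder.getF32Type(),
                                   APFloat(1.0f));
arith::ConstantIndexOp zeroIdx =
    arith::ConstantIndexOp::create(builder, loc, /*value=*/0);
arith::ConstantIntOp answerI8 =
    arith::ConstantIntOp::create(builder, loc, /*value=*/42, /*width=*/8);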
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
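A sketch of how a ConditionallySpeculatable hook typically returns these constants, loosely modeled on the divide-by-zero handling of division ops; "MyDivOp" and its getRhs() accessor are assumptions:

Speculation::Speculatability MyDivOp::getSpeculatability() {
  // Hoisting a division is only safe if the divisor cannot be zero.
  if (matchPattern(getRhs(), m_IntRangeWithoutZeroU()))
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}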
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2726
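Sketch of querying a combiner op for its identity element, e.g. to seed a reduction; combiner is assumed to be something like an arith.addf or arith.muli operation:

if (std::optional<TypedAttr> neutral = arith::getNeutralElement(combiner)) {
  // e.g. 0.0 for arith.addf, 1 for arith.muli; the attribute can be
  // materialized as an arith.constant to initialize an accumulator.
}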
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1932
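Sketch: folding a comparison of two compile-time APInt values with applyCmpPredicate:

APInt a(/*numBits=*/32, 3), b(/*numBits=*/32, 7);
bool isLess = arith::applyCmpPredicate(arith::CmpIPredicate::slt, a, b);
// isLess == true, so such a cmpi folds to a constant true.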
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2659
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2778
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2768
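These three helpers are commonly used together when lowering a reduction: the identity value seeds the accumulator and getReductionOp emits the combining arithmetic. A sketch with builder, loc, elemTy, acc, and x assumed:

// Neutral starting value for a floating-point sum (0.0 for addf).
Value init = arith::getIdentityValue(arith::AtomicRMWKind::addf, elemTy,
                                     builder, loc);
// One combining step of the reduction: acc = acc + x.
Value combined = arith::getReductionOp(arith::AtomicRMWKind::addf, builder,
                                       loc, acc, x);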
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
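Sketch: inverting a predicate, e.g. when folding a negated comparison into a single cmpi:

arith::CmpIPredicate inverted =
    arith::invertPredicate(arith::CmpIPredicate::slt);
// inverted == arith::CmpIPredicate::sge, i.e. !(a < b) is (a >= b).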
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition: ArithOps.cpp:380
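Sketch: materializing a typed zero, e.g. as the neutral value of an additive reduction; builder and loc are assumed:

Value zeroI32 = arith::getZeroConstant(builder, loc, builder.getI32Type());
Value zeroF64 = arith::getZeroConstant(builder, loc, builder.getF64Type());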
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
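Sketch of the fold-time queries these matchers support; value is an assumed Value operand:

APInt cst;
if (matchPattern(value, m_ConstantInt(&cst))) {
  // `value` is a constant integer (or splat); `cst` now holds its value.
}
if (matchPattern(value, m_Zero())) {
  // `value` is the integer constant 0 (or a zero splat), e.g. x + 0 -> x.
}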
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition: Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:462
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:427
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float one.
Definition: Matchers.h:414
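Sketch of how the float matchers are typically used when simplifying floating-point arithmetic; rhs is an assumed operand:

if (matchPattern(rhs, m_NegZeroFloat())) {
  // x + (-0.0) is always x, so an addf with this operand can fold away.
}
if (matchPattern(rhs, m_OneFloat())) {
  // x * 1.0 is x, so a mulf with this operand can fold away.
}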
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition: Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2436
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
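A minimal OpRewritePattern sketch in the spirit of the select canonicalizations; the specific rewrite (select(c, x, x) -> x) and the pattern name are chosen for illustration only:

struct CollapseSelectOfSameValue : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // select(c, x, x) always yields x, regardless of the condition.
    if (op.getTrueValue() != op.getFalseValue())
      return failure();
    rewriter.replaceOp(op, op.getTrueValue());
    return success();
  }
};

// Registration inside a getCanonicalizationPatterns hook would look like:
//   patterns.add<CollapseSelectOfSameValue>(context);
// replaceOpWithNewOp<OpTy>(op, ...) is the variant to use when the replacement
// requires building a new operation.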
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes