MLIR  21.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
namespace {
// TableGen'd canonicalization patterns. They are included inside an anonymous
// namespace so the generated pattern classes have internal linkage in this
// translation unit.
#include "ArithCanonicalization.inc"
} // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // Integer values must be signless.
211  if (llvm::isa<IntegerType>(type) &&
212  !llvm::cast<IntegerType>(type).isSignless())
213  return emitOpError("integer return type must be signless");
214  // Any float or elements attribute are acceptable.
215  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216  return emitOpError(
217  "value must be an integer, float, or elements attribute");
218  }
219 
220  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
221  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
222  // However, this would most likely require updating the lowerings to LLVM.
223  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224  return emitOpError(
225  "intializing scalable vectors with elements attribute is not supported"
226  " unless it's a vector splat");
227  return success();
228 }
229 
230 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
231  // The value's type must be the same as the provided type.
232  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233  if (!typedAttr || typedAttr.getType() != type)
234  return false;
235  // Integer values must be signless.
236  if (llvm::isa<IntegerType>(type) &&
237  !llvm::cast<IntegerType>(type).isSignless())
238  return false;
239  // Integer, float, and element attributes are buildable.
240  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
241 }
242 
243 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
244  Type type, Location loc) {
245  if (isBuildableWith(value, type))
246  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
247  return nullptr;
248 }
249 
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
251 
253  int64_t value, unsigned width) {
254  auto type = builder.getIntegerType(width);
255  arith::ConstantOp::build(builder, result, type,
256  builder.getIntegerAttr(type, value));
257 }
258 
260  int64_t value, Type type) {
261  assert(type.isSignlessInteger() &&
262  "ConstantIntOp can only have signless integer type values");
263  arith::ConstantOp::build(builder, result, type,
264  builder.getIntegerAttr(type, value));
265 }
266 
268  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
269  return constOp.getType().isSignlessInteger();
270  return false;
271 }
272 
274  const APFloat &value, FloatType type) {
275  arith::ConstantOp::build(builder, result, type,
276  builder.getFloatAttr(type, value));
277 }
278 
280  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
281  return llvm::isa<FloatType>(constOp.getType());
282  return false;
283 }
284 
286  int64_t value) {
287  arith::ConstantOp::build(builder, result, builder.getIndexType(),
288  builder.getIndexAttr(value));
289 }
290 
292  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
293  return constOp.getType().isIndex();
294  return false;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // AddIOp
299 //===----------------------------------------------------------------------===//
300 
301 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
302  // addi(x, 0) -> x
303  if (matchPattern(adaptor.getRhs(), m_Zero()))
304  return getLhs();
305 
306  // addi(subi(a, b), b) -> a
307  if (auto sub = getLhs().getDefiningOp<SubIOp>())
308  if (getRhs() == sub.getRhs())
309  return sub.getLhs();
310 
311  // addi(b, subi(a, b)) -> a
312  if (auto sub = getRhs().getDefiningOp<SubIOp>())
313  if (getLhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  return constFoldBinaryOp<IntegerAttr>(
317  adaptor.getOperands(),
318  [](APInt a, const APInt &b) { return std::move(a) + b; });
319 }
320 
321 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
322  MLIRContext *context) {
323  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
324  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // AddUIExtendedOp
329 //===----------------------------------------------------------------------===//
330 
331 std::optional<SmallVector<int64_t, 4>>
332 arith::AddUIExtendedOp::getShapeForUnroll() {
333  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
334  return llvm::to_vector<4>(vt.getShape());
335  return std::nullopt;
336 }
337 
338 // Returns the overflow bit, assuming that `sum` is the result of unsigned
339 // addition of `operand` and another number.
340 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
341  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
342 }
343 
344 LogicalResult
345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
347  Type overflowTy = getOverflow().getType();
348  // addui_extended(x, 0) -> x, false
349  if (matchPattern(getRhs(), m_Zero())) {
350  Builder builder(getContext());
351  auto falseValue = builder.getZeroAttr(overflowTy);
352 
353  results.push_back(getLhs());
354  results.push_back(falseValue);
355  return success();
356  }
357 
358  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
359  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
360  // operands. If that succeeds, calculate the overflow bit based on the sum
361  // and the first (constant) operand, `lhs`.
362  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
363  adaptor.getOperands(),
364  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
365  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
366  ArrayRef({sumAttr, adaptor.getLhs()}),
367  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
369  if (!overflowAttr)
370  return failure();
371 
372  results.push_back(sumAttr);
373  results.push_back(overflowAttr);
374  return success();
375  }
376 
377  return failure();
378 }
379 
380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
382  patterns.add<AddUIExtendedToAddI>(context);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // SubIOp
387 //===----------------------------------------------------------------------===//
388 
389 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
390  // subi(x,x) -> 0
391  if (getOperand(0) == getOperand(1)) {
392  auto shapedType = dyn_cast<ShapedType>(getType());
393  // We can't generate a constant with a dynamic shaped tensor.
394  if (!shapedType || shapedType.hasStaticShape())
395  return Builder(getContext()).getZeroAttr(getType());
396  }
397  // subi(x,0) -> x
398  if (matchPattern(adaptor.getRhs(), m_Zero()))
399  return getLhs();
400 
401  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
402  // subi(addi(a, b), b) -> a
403  if (getRhs() == add.getRhs())
404  return add.getLhs();
405  // subi(addi(a, b), a) -> b
406  if (getRhs() == add.getLhs())
407  return add.getRhs();
408  }
409 
410  return constFoldBinaryOp<IntegerAttr>(
411  adaptor.getOperands(),
412  [](APInt a, const APInt &b) { return std::move(a) - b; });
413 }
414 
415 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
416  MLIRContext *context) {
417  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
418  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
419  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // MulIOp
424 //===----------------------------------------------------------------------===//
425 
426 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
427  // muli(x, 0) -> 0
428  if (matchPattern(adaptor.getRhs(), m_Zero()))
429  return getRhs();
430  // muli(x, 1) -> x
431  if (matchPattern(adaptor.getRhs(), m_One()))
432  return getLhs();
433  // TODO: Handle the overflow case.
434 
435  // default folder
436  return constFoldBinaryOp<IntegerAttr>(
437  adaptor.getOperands(),
438  [](const APInt &a, const APInt &b) { return a * b; });
439 }
440 
441 void arith::MulIOp::getAsmResultNames(
442  function_ref<void(Value, StringRef)> setNameFn) {
443  if (!isa<IndexType>(getType()))
444  return;
445 
446  // Match vector.vscale by name to avoid depending on the vector dialect (which
447  // is a circular dependency).
448  auto isVscale = [](Operation *op) {
449  return op && op->getName().getStringRef() == "vector.vscale";
450  };
451 
452  IntegerAttr baseValue;
453  auto isVscaleExpr = [&](Value a, Value b) {
454  return matchPattern(a, m_Constant(&baseValue)) &&
455  isVscale(b.getDefiningOp());
456  };
457 
458  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
459  return;
460 
461  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
462  SmallString<32> specialNameBuffer;
463  llvm::raw_svector_ostream specialName(specialNameBuffer);
464  specialName << 'c' << baseValue.getInt() << "_vscale";
465  setNameFn(getResult(), specialName.str());
466 }
467 
468 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
469  MLIRContext *context) {
470  patterns.add<MulIMulIConstant>(context);
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // MulSIExtendedOp
475 //===----------------------------------------------------------------------===//
476 
477 std::optional<SmallVector<int64_t, 4>>
478 arith::MulSIExtendedOp::getShapeForUnroll() {
479  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
480  return llvm::to_vector<4>(vt.getShape());
481  return std::nullopt;
482 }
483 
484 LogicalResult
485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
487  // mulsi_extended(x, 0) -> 0, 0
488  if (matchPattern(adaptor.getRhs(), m_Zero())) {
489  Attribute zero = adaptor.getRhs();
490  results.push_back(zero);
491  results.push_back(zero);
492  return success();
493  }
494 
495  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
496  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
497  adaptor.getOperands(),
498  [](const APInt &a, const APInt &b) { return a * b; })) {
499  // Invoke the constant fold helper again to calculate the 'high' result.
500  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
501  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
502  return llvm::APIntOps::mulhs(a, b);
503  });
504  assert(highAttr && "Unexpected constant-folding failure");
505 
506  results.push_back(lowAttr);
507  results.push_back(highAttr);
508  return success();
509  }
510 
511  return failure();
512 }
513 
514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // MulUIExtendedOp
521 //===----------------------------------------------------------------------===//
522 
523 std::optional<SmallVector<int64_t, 4>>
524 arith::MulUIExtendedOp::getShapeForUnroll() {
525  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
526  return llvm::to_vector<4>(vt.getShape());
527  return std::nullopt;
528 }
529 
530 LogicalResult
531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
533  // mului_extended(x, 0) -> 0, 0
534  if (matchPattern(adaptor.getRhs(), m_Zero())) {
535  Attribute zero = adaptor.getRhs();
536  results.push_back(zero);
537  results.push_back(zero);
538  return success();
539  }
540 
541  // mului_extended(x, 1) -> x, 0
542  if (matchPattern(adaptor.getRhs(), m_One())) {
543  Builder builder(getContext());
544  Attribute zero = builder.getZeroAttr(getLhs().getType());
545  results.push_back(getLhs());
546  results.push_back(zero);
547  return success();
548  }
549 
550  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
551  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
552  adaptor.getOperands(),
553  [](const APInt &a, const APInt &b) { return a * b; })) {
554  // Invoke the constant fold helper again to calculate the 'high' result.
555  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
556  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
557  return llvm::APIntOps::mulhu(a, b);
558  });
559  assert(highAttr && "Unexpected constant-folding failure");
560 
561  results.push_back(lowAttr);
562  results.push_back(highAttr);
563  return success();
564  }
565 
566  return failure();
567 }
568 
569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571  patterns.add<MulUIExtendedToMulI>(context);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // DivUIOp
576 //===----------------------------------------------------------------------===//
577 
578 /// Fold `(a * b) / b -> a`
579 static Value foldDivMul(Value lhs, Value rhs,
580  arith::IntegerOverflowFlags ovfFlags) {
581  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
582  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
583  return {};
584 
585  if (mul.getLhs() == rhs)
586  return mul.getRhs();
587 
588  if (mul.getRhs() == rhs)
589  return mul.getLhs();
590 
591  return {};
592 }
593 
594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
595  // divui (x, 1) -> x.
596  if (matchPattern(adaptor.getRhs(), m_One()))
597  return getLhs();
598 
599  // (a * b) / b -> a
600  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
601  return val;
602 
603  // Don't fold if it would require a division by zero.
604  bool div0 = false;
605  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
606  [&](APInt a, const APInt &b) {
607  if (div0 || !b) {
608  div0 = true;
609  return a;
610  }
611  return a.udiv(b);
612  });
613 
614  return div0 ? Attribute() : result;
615 }
616 
617 /// Returns whether an unsigned division by `divisor` is speculatable.
619  // X / 0 => UB
620  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
622 
624 }
625 
626 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
627  return getDivUISpeculatability(getRhs());
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // DivSIOp
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
635  // divsi (x, 1) -> x.
636  if (matchPattern(adaptor.getRhs(), m_One()))
637  return getLhs();
638 
639  // (a * b) / b -> a
640  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
641  return val;
642 
643  // Don't fold if it would overflow or if it requires a division by zero.
644  bool overflowOrDiv0 = false;
645  auto result = constFoldBinaryOp<IntegerAttr>(
646  adaptor.getOperands(), [&](APInt a, const APInt &b) {
647  if (overflowOrDiv0 || !b) {
648  overflowOrDiv0 = true;
649  return a;
650  }
651  return a.sdiv_ov(b, overflowOrDiv0);
652  });
653 
654  return overflowOrDiv0 ? Attribute() : result;
655 }
656 
657 /// Returns whether a signed division by `divisor` is speculatable. This
658 /// function conservatively assumes that all signed division by -1 are not
659 /// speculatable.
661  // X / 0 => UB
662  // INT_MIN / -1 => UB
663  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
666 
668 }
669 
670 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
671  return getDivSISpeculatability(getRhs());
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // Ceil and floor division folding helpers
676 //===----------------------------------------------------------------------===//
677 
678 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
679  bool &overflow) {
680  // Returns (a-1)/b + 1
681  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
682  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
683  return val.sadd_ov(one, overflow);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // CeilDivUIOp
688 //===----------------------------------------------------------------------===//
689 
690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
691  // ceildivui (x, 1) -> x.
692  if (matchPattern(adaptor.getRhs(), m_One()))
693  return getLhs();
694 
695  bool overflowOrDiv0 = false;
696  auto result = constFoldBinaryOp<IntegerAttr>(
697  adaptor.getOperands(), [&](APInt a, const APInt &b) {
698  if (overflowOrDiv0 || !b) {
699  overflowOrDiv0 = true;
700  return a;
701  }
702  APInt quotient = a.udiv(b);
703  if (!a.urem(b))
704  return quotient;
705  APInt one(a.getBitWidth(), 1, true);
706  return quotient.uadd_ov(one, overflowOrDiv0);
707  });
708 
709  return overflowOrDiv0 ? Attribute() : result;
710 }
711 
712 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
713  return getDivUISpeculatability(getRhs());
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // CeilDivSIOp
718 //===----------------------------------------------------------------------===//
719 
/// Folds arith.ceildivsi: division by one, and elementwise constant folding of
/// the signed quotient rounded toward positive infinity. The fold is abandoned
/// if any element divides by zero or any intermediate negation/division
/// overflows.
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // ceildivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // TODO: This hook won't fold operations where a = MININT, because
  // negating MININT overflows. This can be improved.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        // Sticky rejection: once an element was rejected, skip the rest so
        // the flag stays set and the whole fold is abandoned.
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // ceil(0 / b) == 0 for any nonzero b.
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        // No folding happens if any of the intermediate arithmetic operations
        // overflows.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          // The exact quotient is negative, so truncating the magnitude and
          // negating it rounds toward zero, which is upward here.
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A is positive, b is negative, return - (a / -b).
        // Same reasoning: negative exact quotient, truncation rounds upward.
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);

        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
780 
781 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
782  return getDivSISpeculatability(getRhs());
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // FloorDivSIOp
787 //===----------------------------------------------------------------------===//
788 
789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
790  // floordivsi (x, 1) -> x.
791  if (matchPattern(adaptor.getRhs(), m_One()))
792  return getLhs();
793 
794  // Don't fold if it would overflow or if it requires a division by zero.
795  bool overflowOrDiv = false;
796  auto result = constFoldBinaryOp<IntegerAttr>(
797  adaptor.getOperands(), [&](APInt a, const APInt &b) {
798  if (b.isZero()) {
799  overflowOrDiv = true;
800  return a;
801  }
802  return a.sfloordiv_ov(b, overflowOrDiv);
803  });
804 
805  return overflowOrDiv ? Attribute() : result;
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // RemUIOp
810 //===----------------------------------------------------------------------===//
811 
812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
813  // remui (x, 1) -> 0.
814  if (matchPattern(adaptor.getRhs(), m_One()))
815  return Builder(getContext()).getZeroAttr(getType());
816 
817  // Don't fold if it would require a division by zero.
818  bool div0 = false;
819  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
820  [&](APInt a, const APInt &b) {
821  if (div0 || b.isZero()) {
822  div0 = true;
823  return a;
824  }
825  return a.urem(b);
826  });
827 
828  return div0 ? Attribute() : result;
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // RemSIOp
833 //===----------------------------------------------------------------------===//
834 
835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
836  // remsi (x, 1) -> 0.
837  if (matchPattern(adaptor.getRhs(), m_One()))
838  return Builder(getContext()).getZeroAttr(getType());
839 
840  // Don't fold if it would require a division by zero.
841  bool div0 = false;
842  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
843  [&](APInt a, const APInt &b) {
844  if (div0 || b.isZero()) {
845  div0 = true;
846  return a;
847  }
848  return a.srem(b);
849  });
850 
851  return div0 ? Attribute() : result;
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // AndIOp
856 //===----------------------------------------------------------------------===//
857 
858 /// Fold `and(a, and(a, b))` to `and(a, b)`
859 static Value foldAndIofAndI(arith::AndIOp op) {
860  for (bool reversePrev : {false, true}) {
861  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862  .getDefiningOp<arith::AndIOp>();
863  if (!prev)
864  continue;
865 
866  Value other = (reversePrev ? op.getLhs() : op.getRhs());
867  if (other != prev.getLhs() && other != prev.getRhs())
868  continue;
869 
870  return prev.getResult();
871  }
872  return {};
873 }
874 
875 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
876  /// and(x, 0) -> 0
877  if (matchPattern(adaptor.getRhs(), m_Zero()))
878  return getRhs();
879  /// and(x, allOnes) -> x
880  APInt intValue;
881  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
882  intValue.isAllOnes())
883  return getLhs();
884  /// and(x, not(x)) -> 0
885  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
886  m_ConstantInt(&intValue))) &&
887  intValue.isAllOnes())
888  return Builder(getContext()).getZeroAttr(getType());
889  /// and(not(x), x) -> 0
890  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
891  m_ConstantInt(&intValue))) &&
892  intValue.isAllOnes())
893  return Builder(getContext()).getZeroAttr(getType());
894 
895  /// and(a, and(a, b)) -> and(a, b)
896  if (Value result = foldAndIofAndI(*this))
897  return result;
898 
899  return constFoldBinaryOp<IntegerAttr>(
900  adaptor.getOperands(),
901  [](APInt a, const APInt &b) { return std::move(a) & b; });
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // OrIOp
906 //===----------------------------------------------------------------------===//
907 
908 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
909  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
910  /// or(x, 0) -> x
911  if (rhsVal.isZero())
912  return getLhs();
913  /// or(x, <all ones>) -> <all ones>
914  if (rhsVal.isAllOnes())
915  return adaptor.getRhs();
916  }
917 
918  APInt intValue;
919  /// or(x, xor(x, 1)) -> 1
920  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
921  m_ConstantInt(&intValue))) &&
922  intValue.isAllOnes())
923  return getRhs().getDefiningOp<XOrIOp>().getRhs();
924  /// or(xor(x, 1), x) -> 1
925  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
926  m_ConstantInt(&intValue))) &&
927  intValue.isAllOnes())
928  return getLhs().getDefiningOp<XOrIOp>().getRhs();
929 
930  return constFoldBinaryOp<IntegerAttr>(
931  adaptor.getOperands(),
932  [](APInt a, const APInt &b) { return std::move(a) | b; });
933 }
934 
935 //===----------------------------------------------------------------------===//
936 // XOrIOp
937 //===----------------------------------------------------------------------===//
938 
939 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
940  /// xor(x, 0) -> x
941  if (matchPattern(adaptor.getRhs(), m_Zero()))
942  return getLhs();
943  /// xor(x, x) -> 0
944  if (getLhs() == getRhs())
945  return Builder(getContext()).getZeroAttr(getType());
946  /// xor(xor(x, a), a) -> x
947  /// xor(xor(a, x), a) -> x
948  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
949  if (prev.getRhs() == getRhs())
950  return prev.getLhs();
951  if (prev.getLhs() == getRhs())
952  return prev.getRhs();
953  }
954  /// xor(a, xor(x, a)) -> x
955  /// xor(a, xor(a, x)) -> x
956  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
957  if (prev.getRhs() == getLhs())
958  return prev.getLhs();
959  if (prev.getLhs() == getLhs())
960  return prev.getRhs();
961  }
962 
963  return constFoldBinaryOp<IntegerAttr>(
964  adaptor.getOperands(),
965  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
966 }
967 
968 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
969  MLIRContext *context) {
970  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // NegFOp
975 //===----------------------------------------------------------------------===//
976 
977 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
978  /// negf(negf(x)) -> x
979  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
980  return op.getOperand();
981  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
982  [](const APFloat &a) { return -a; });
983 }
984 
985 //===----------------------------------------------------------------------===//
986 // AddFOp
987 //===----------------------------------------------------------------------===//
988 
989 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
990  // addf(x, -0) -> x
991  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
992  return getLhs();
993 
994  return constFoldBinaryOp<FloatAttr>(
995  adaptor.getOperands(),
996  [](const APFloat &a, const APFloat &b) { return a + b; });
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // SubFOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1004  // subf(x, +0) -> x
1005  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1006  return getLhs();
1007 
1008  return constFoldBinaryOp<FloatAttr>(
1009  adaptor.getOperands(),
1010  [](const APFloat &a, const APFloat &b) { return a - b; });
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // MaximumFOp
1015 //===----------------------------------------------------------------------===//
1016 
1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1018  // maximumf(x,x) -> x
1019  if (getLhs() == getRhs())
1020  return getRhs();
1021 
1022  // maximumf(x, -inf) -> x
1023  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1024  return getLhs();
1025 
1026  return constFoldBinaryOp<FloatAttr>(
1027  adaptor.getOperands(),
1028  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // MaxNumFOp
1033 //===----------------------------------------------------------------------===//
1034 
1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1036  // maxnumf(x,x) -> x
1037  if (getLhs() == getRhs())
1038  return getRhs();
1039 
1040  // maxnumf(x, NaN) -> x
1041  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1042  return getLhs();
1043 
1044  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // MaxSIOp
1049 //===----------------------------------------------------------------------===//
1050 
1051 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1052  // maxsi(x,x) -> x
1053  if (getLhs() == getRhs())
1054  return getRhs();
1055 
1056  if (APInt intValue;
1057  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1058  // maxsi(x,MAX_INT) -> MAX_INT
1059  if (intValue.isMaxSignedValue())
1060  return getRhs();
1061  // maxsi(x, MIN_INT) -> x
1062  if (intValue.isMinSignedValue())
1063  return getLhs();
1064  }
1065 
1066  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1067  [](const APInt &a, const APInt &b) {
1068  return llvm::APIntOps::smax(a, b);
1069  });
1070 }
1071 
1072 //===----------------------------------------------------------------------===//
1073 // MaxUIOp
1074 //===----------------------------------------------------------------------===//
1075 
1076 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1077  // maxui(x,x) -> x
1078  if (getLhs() == getRhs())
1079  return getRhs();
1080 
1081  if (APInt intValue;
1082  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1083  // maxui(x,MAX_INT) -> MAX_INT
1084  if (intValue.isMaxValue())
1085  return getRhs();
1086  // maxui(x, MIN_INT) -> x
1087  if (intValue.isMinValue())
1088  return getLhs();
1089  }
1090 
1091  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1092  [](const APInt &a, const APInt &b) {
1093  return llvm::APIntOps::umax(a, b);
1094  });
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // MinimumFOp
1099 //===----------------------------------------------------------------------===//
1100 
1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1102  // minimumf(x,x) -> x
1103  if (getLhs() == getRhs())
1104  return getRhs();
1105 
1106  // minimumf(x, +inf) -> x
1107  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1108  return getLhs();
1109 
1110  return constFoldBinaryOp<FloatAttr>(
1111  adaptor.getOperands(),
1112  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1113 }
1114 
1115 //===----------------------------------------------------------------------===//
1116 // MinNumFOp
1117 //===----------------------------------------------------------------------===//
1118 
1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1120  // minnumf(x,x) -> x
1121  if (getLhs() == getRhs())
1122  return getRhs();
1123 
1124  // minnumf(x, NaN) -> x
1125  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1126  return getLhs();
1127 
1128  return constFoldBinaryOp<FloatAttr>(
1129  adaptor.getOperands(),
1130  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // MinSIOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1138  // minsi(x,x) -> x
1139  if (getLhs() == getRhs())
1140  return getRhs();
1141 
1142  if (APInt intValue;
1143  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1144  // minsi(x,MIN_INT) -> MIN_INT
1145  if (intValue.isMinSignedValue())
1146  return getRhs();
1147  // minsi(x, MAX_INT) -> x
1148  if (intValue.isMaxSignedValue())
1149  return getLhs();
1150  }
1151 
1152  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1153  [](const APInt &a, const APInt &b) {
1154  return llvm::APIntOps::smin(a, b);
1155  });
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // MinUIOp
1160 //===----------------------------------------------------------------------===//
1161 
1162 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1163  // minui(x,x) -> x
1164  if (getLhs() == getRhs())
1165  return getRhs();
1166 
1167  if (APInt intValue;
1168  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1169  // minui(x,MIN_INT) -> MIN_INT
1170  if (intValue.isMinValue())
1171  return getRhs();
1172  // minui(x, MAX_INT) -> x
1173  if (intValue.isMaxValue())
1174  return getLhs();
1175  }
1176 
1177  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1178  [](const APInt &a, const APInt &b) {
1179  return llvm::APIntOps::umin(a, b);
1180  });
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // MulFOp
1185 //===----------------------------------------------------------------------===//
1186 
1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1188  // mulf(x, 1) -> x
1189  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1190  return getLhs();
1191 
1192  return constFoldBinaryOp<FloatAttr>(
1193  adaptor.getOperands(),
1194  [](const APFloat &a, const APFloat &b) { return a * b; });
1195 }
1196 
1197 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1198  MLIRContext *context) {
1199  patterns.add<MulFOfNegF>(context);
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // DivFOp
1204 //===----------------------------------------------------------------------===//
1205 
1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1207  // divf(x, 1) -> x
1208  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1209  return getLhs();
1210 
1211  return constFoldBinaryOp<FloatAttr>(
1212  adaptor.getOperands(),
1213  [](const APFloat &a, const APFloat &b) { return a / b; });
1214 }
1215 
1216 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1217  MLIRContext *context) {
1218  patterns.add<DivFOfNegF>(context);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // RemFOp
1223 //===----------------------------------------------------------------------===//
1224 
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1226  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1227  [](const APFloat &a, const APFloat &b) {
1228  APFloat result(a);
1229  // APFloat::mod() offers the remainder
1230  // behavior we want, i.e. the result has
1231  // the sign of LHS operand.
1232  (void)result.mod(b);
1233  return result;
1234  });
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // Utility functions for verifying cast ops
1239 //===----------------------------------------------------------------------===//
1240 
/// A tag type carrying a pack of types; only its type identity is ever used,
/// so a pointer-to-tuple keeps it trivially constructible from nullptr.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
1243 
1244 /// Returns a non-null type only if the provided type is one of the allowed
1245 /// types or one of the allowed shaped types of the allowed types. Returns the
1246 /// element type if a valid shaped type is provided.
1247 template <typename... ShapedTypes, typename... ElementTypes>
1250  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1251  return {};
1252 
1253  auto underlyingType = getElementTypeOrSelf(type);
1254  if (!llvm::isa<ElementTypes...>(underlyingType))
1255  return {};
1256 
1257  return underlyingType;
1258 }
1259 
1260 /// Get allowed underlying types for vectors and tensors.
1261 template <typename... ElementTypes>
1262 static Type getTypeIfLike(Type type) {
1265 }
1266 
1267 /// Get allowed underlying types for vectors, tensors, and memrefs.
1268 template <typename... ElementTypes>
1270  return getUnderlyingType(type,
1273 }
1274 
1275 /// Return false if both types are ranked tensor with mismatching encoding.
1276 static bool hasSameEncoding(Type typeA, Type typeB) {
1277  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1278  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1279  if (!rankedTensorA || !rankedTensorB)
1280  return true;
1281  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1282 }
1283 
1284 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1285  if (inputs.size() != 1 || outputs.size() != 1)
1286  return false;
1287  if (!hasSameEncoding(inputs.front(), outputs.front()))
1288  return false;
1289  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // Verifiers for integer and floating point extension/truncation ops
1294 //===----------------------------------------------------------------------===//
1295 
1296 // Extend ops can only extend to a wider type.
1297 template <typename ValType, typename Op>
1298 static LogicalResult verifyExtOp(Op op) {
1299  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1300  Type dstType = getElementTypeOrSelf(op.getType());
1301 
1302  if (llvm::cast<ValType>(srcType).getWidth() >=
1303  llvm::cast<ValType>(dstType).getWidth())
1304  return op.emitError("result type ")
1305  << dstType << " must be wider than operand type " << srcType;
1306 
1307  return success();
1308 }
1309 
1310 // Truncate ops can only truncate to a shorter type.
1311 template <typename ValType, typename Op>
1312 static LogicalResult verifyTruncateOp(Op op) {
1313  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1314  Type dstType = getElementTypeOrSelf(op.getType());
1315 
1316  if (llvm::cast<ValType>(srcType).getWidth() <=
1317  llvm::cast<ValType>(dstType).getWidth())
1318  return op.emitError("result type ")
1319  << dstType << " must be shorter than operand type " << srcType;
1320 
1321  return success();
1322 }
1323 
1324 /// Validate a cast that changes the width of a type.
1325 template <template <typename> class WidthComparator, typename... ElementTypes>
1326 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1327  if (!areValidCastInputsAndOutputs(inputs, outputs))
1328  return false;
1329 
1330  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1331  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1332  if (!srcType || !dstType)
1333  return false;
1334 
1335  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1336  srcType.getIntOrFloatBitWidth());
1337 }
1338 
1339 /// Attempts to convert `sourceValue` to an APFloat value with
1340 /// `targetSemantics` and `roundingMode`, without any information loss.
1341 static FailureOr<APFloat> convertFloatValue(
1342  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1343  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344  bool losesInfo = false;
1345  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346  if (losesInfo || status != APFloat::opOK)
1347  return failure();
1348 
1349  return sourceValue;
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // ExtUIOp
1354 //===----------------------------------------------------------------------===//
1355 
1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1357  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1358  getInMutable().assign(lhs.getIn());
1359  return getResult();
1360  }
1361 
1362  Type resType = getElementTypeOrSelf(getType());
1363  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1364  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1365  adaptor.getOperands(), getType(),
1366  [bitWidth](const APInt &a, bool &castStatus) {
1367  return a.zext(bitWidth);
1368  });
1369 }
1370 
1371 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1372  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1373 }
1374 
1375 LogicalResult arith::ExtUIOp::verify() {
1376  return verifyExtOp<IntegerType>(*this);
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // ExtSIOp
1381 //===----------------------------------------------------------------------===//
1382 
1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1384  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1385  getInMutable().assign(lhs.getIn());
1386  return getResult();
1387  }
1388 
1389  Type resType = getElementTypeOrSelf(getType());
1390  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1391  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1392  adaptor.getOperands(), getType(),
1393  [bitWidth](const APInt &a, bool &castStatus) {
1394  return a.sext(bitWidth);
1395  });
1396 }
1397 
1398 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1399  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1400 }
1401 
1402 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1403  MLIRContext *context) {
1404  patterns.add<ExtSIOfExtUI>(context);
1405 }
1406 
1407 LogicalResult arith::ExtSIOp::verify() {
1408  return verifyExtOp<IntegerType>(*this);
1409 }
1410 
1411 //===----------------------------------------------------------------------===//
1412 // ExtFOp
1413 //===----------------------------------------------------------------------===//
1414 
1415 /// Fold extension of float constants when there is no information loss due the
1416 /// difference in fp semantics.
1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1418  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1419  if (truncFOp.getOperand().getType() == getType()) {
1420  arith::FastMathFlags truncFMF =
1421  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1422  bool isTruncContract =
1423  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1424  arith::FastMathFlags extFMF =
1425  getFastmath().value_or(arith::FastMathFlags::none);
1426  bool isExtContract =
1427  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1428  if (isTruncContract && isExtContract) {
1429  return truncFOp.getOperand();
1430  }
1431  }
1432  }
1433 
1434  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1435  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1436  return constFoldCastOp<FloatAttr, FloatAttr>(
1437  adaptor.getOperands(), getType(),
1438  [&targetSemantics](const APFloat &a, bool &castStatus) {
1439  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1440  if (failed(result)) {
1441  castStatus = false;
1442  return a;
1443  }
1444  return *result;
1445  });
1446 }
1447 
1448 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1449  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1450 }
1451 
1452 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1453 
1454 //===----------------------------------------------------------------------===//
1455 // ScalingExtFOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1459  TypeRange outputs) {
1460  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1461 }
1462 
1463 LogicalResult arith::ScalingExtFOp::verify() {
1464  return verifyExtOp<FloatType>(*this);
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // TruncIOp
1469 //===----------------------------------------------------------------------===//
1470 
1471 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1472  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1473  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1474  Value src = getOperand().getDefiningOp()->getOperand(0);
1475  Type srcType = getElementTypeOrSelf(src.getType());
1476  Type dstType = getElementTypeOrSelf(getType());
1477  // trunci(zexti(a)) -> trunci(a)
1478  // trunci(sexti(a)) -> trunci(a)
1479  if (llvm::cast<IntegerType>(srcType).getWidth() >
1480  llvm::cast<IntegerType>(dstType).getWidth()) {
1481  setOperand(src);
1482  return getResult();
1483  }
1484 
1485  // trunci(zexti(a)) -> a
1486  // trunci(sexti(a)) -> a
1487  if (srcType == dstType)
1488  return src;
1489  }
1490 
1491  // trunci(trunci(a)) -> trunci(a))
1492  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1493  setOperand(getOperand().getDefiningOp()->getOperand(0));
1494  return getResult();
1495  }
1496 
1497  Type resType = getElementTypeOrSelf(getType());
1498  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1499  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1500  adaptor.getOperands(), getType(),
1501  [bitWidth](const APInt &a, bool &castStatus) {
1502  return a.trunc(bitWidth);
1503  });
1504 }
1505 
1506 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1507  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1508 }
1509 
1510 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1511  MLIRContext *context) {
1512  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1513  TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1514  context);
1515 }
1516 
1517 LogicalResult arith::TruncIOp::verify() {
1518  return verifyTruncateOp<IntegerType>(*this);
1519 }
1520 
1521 //===----------------------------------------------------------------------===//
1522 // TruncFOp
1523 //===----------------------------------------------------------------------===//
1524 
1525 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1526 /// can be represented without precision loss.
1527 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1528  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1529  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1530  Value src = extOp.getIn();
1531  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1532  auto intermediateType =
1533  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1534  // Check if the srcType is representable in the intermediateType.
1535  if (llvm::APFloatBase::isRepresentableBy(
1536  srcType.getFloatSemantics(),
1537  intermediateType.getFloatSemantics())) {
1538  // truncf(extf(a)) -> truncf(a)
1539  if (srcType.getWidth() > resElemType.getWidth()) {
1540  setOperand(src);
1541  return getResult();
1542  }
1543 
1544  // truncf(extf(a)) -> a
1545  if (srcType == resElemType)
1546  return src;
1547  }
1548  }
1549 
1550  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1551  return constFoldCastOp<FloatAttr, FloatAttr>(
1552  adaptor.getOperands(), getType(),
1553  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1554  RoundingMode roundingMode =
1555  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1556  llvm::RoundingMode llvmRoundingMode =
1557  convertArithRoundingModeToLLVMIR(roundingMode);
1558  FailureOr<APFloat> result =
1559  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1560  if (failed(result)) {
1561  castStatus = false;
1562  return a;
1563  }
1564  return *result;
1565  });
1566 }
1567 
1568 void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1569  MLIRContext *context) {
1570  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1571 }
1572 
1573 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1574  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1575 }
1576 
1577 LogicalResult arith::TruncFOp::verify() {
1578  return verifyTruncateOp<FloatType>(*this);
1579 }
1580 
1581 //===----------------------------------------------------------------------===//
1582 // ScalingTruncFOp
1583 //===----------------------------------------------------------------------===//
1584 
1585 bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1586  TypeRange outputs) {
1587  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1588 }
1589 
1590 LogicalResult arith::ScalingTruncFOp::verify() {
1591  return verifyTruncateOp<FloatType>(*this);
1592 }
1593 
1594 //===----------------------------------------------------------------------===//
1595 // AndIOp
1596 //===----------------------------------------------------------------------===//
1597 
1598 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1599  MLIRContext *context) {
1600  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1601 }
1602 
1603 //===----------------------------------------------------------------------===//
1604 // OrIOp
1605 //===----------------------------------------------------------------------===//
1606 
1607 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1608  MLIRContext *context) {
1609  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // Verifiers for casts between integers and floats.
1614 //===----------------------------------------------------------------------===//
1615 
1616 template <typename From, typename To>
1617 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1618  if (!areValidCastInputsAndOutputs(inputs, outputs))
1619  return false;
1620 
1621  auto srcType = getTypeIfLike<From>(inputs.front());
1622  auto dstType = getTypeIfLike<To>(outputs.back());
1623 
1624  return srcType && dstType;
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // UIToFPOp
1629 //===----------------------------------------------------------------------===//
1630 
1631 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1632  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1633 }
1634 
1635 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1636  Type resEleType = getElementTypeOrSelf(getType());
1637  return constFoldCastOp<IntegerAttr, FloatAttr>(
1638  adaptor.getOperands(), getType(),
1639  [&resEleType](const APInt &a, bool &castStatus) {
1640  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1641  APFloat apf(floatTy.getFloatSemantics(),
1642  APInt::getZero(floatTy.getWidth()));
1643  apf.convertFromAPInt(a, /*IsSigned=*/false,
1644  APFloat::rmNearestTiesToEven);
1645  return apf;
1646  });
1647 }
1648 
1649 //===----------------------------------------------------------------------===//
1650 // SIToFPOp
1651 //===----------------------------------------------------------------------===//
1652 
1653 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1654  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1655 }
1656 
1657 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1658  Type resEleType = getElementTypeOrSelf(getType());
1659  return constFoldCastOp<IntegerAttr, FloatAttr>(
1660  adaptor.getOperands(), getType(),
1661  [&resEleType](const APInt &a, bool &castStatus) {
1662  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1663  APFloat apf(floatTy.getFloatSemantics(),
1664  APInt::getZero(floatTy.getWidth()));
1665  apf.convertFromAPInt(a, /*IsSigned=*/true,
1666  APFloat::rmNearestTiesToEven);
1667  return apf;
1668  });
1669 }
1670 
1671 //===----------------------------------------------------------------------===//
1672 // FPToUIOp
1673 //===----------------------------------------------------------------------===//
1674 
1675 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1676  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1677 }
1678 
1679 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1680  Type resType = getElementTypeOrSelf(getType());
1681  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1682  return constFoldCastOp<FloatAttr, IntegerAttr>(
1683  adaptor.getOperands(), getType(),
1684  [&bitWidth](const APFloat &a, bool &castStatus) {
1685  bool ignored;
1686  APSInt api(bitWidth, /*isUnsigned=*/true);
1687  castStatus = APFloat::opInvalidOp !=
1688  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1689  return api;
1690  });
1691 }
1692 
1693 //===----------------------------------------------------------------------===//
1694 // FPToSIOp
1695 //===----------------------------------------------------------------------===//
1696 
1697 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1698  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1699 }
1700 
1701 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1702  Type resType = getElementTypeOrSelf(getType());
1703  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1704  return constFoldCastOp<FloatAttr, IntegerAttr>(
1705  adaptor.getOperands(), getType(),
1706  [&bitWidth](const APFloat &a, bool &castStatus) {
1707  bool ignored;
1708  APSInt api(bitWidth, /*isUnsigned=*/false);
1709  castStatus = APFloat::opInvalidOp !=
1710  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1711  return api;
1712  });
1713 }
1714 
1715 //===----------------------------------------------------------------------===//
1716 // IndexCastOp
1717 //===----------------------------------------------------------------------===//
1718 
1719 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1720  if (!areValidCastInputsAndOutputs(inputs, outputs))
1721  return false;
1722 
1723  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1724  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1725  if (!srcType || !dstType)
1726  return false;
1727 
1728  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1729  (srcType.isSignlessInteger() && dstType.isIndex());
1730 }
1731 
1732 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1733  TypeRange outputs) {
1734  return areIndexCastCompatible(inputs, outputs);
1735 }
1736 
1737 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1738  // index_cast(constant) -> constant
1739  unsigned resultBitwidth = 64; // Default for index integer attributes.
1740  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1741  resultBitwidth = intTy.getWidth();
1742 
1743  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1744  adaptor.getOperands(), getType(),
1745  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1746  return a.sextOrTrunc(resultBitwidth);
1747  });
1748 }
1749 
1750 void arith::IndexCastOp::getCanonicalizationPatterns(
1751  RewritePatternSet &patterns, MLIRContext *context) {
1752  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1753 }
1754 
1755 //===----------------------------------------------------------------------===//
1756 // IndexCastUIOp
1757 //===----------------------------------------------------------------------===//
1758 
1759 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1760  TypeRange outputs) {
1761  return areIndexCastCompatible(inputs, outputs);
1762 }
1763 
1764 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1765  // index_castui(constant) -> constant
1766  unsigned resultBitwidth = 64; // Default for index integer attributes.
1767  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1768  resultBitwidth = intTy.getWidth();
1769 
1770  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1771  adaptor.getOperands(), getType(),
1772  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1773  return a.zextOrTrunc(resultBitwidth);
1774  });
1775 }
1776 
1777 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1778  RewritePatternSet &patterns, MLIRContext *context) {
1779  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1780 }
1781 
1782 //===----------------------------------------------------------------------===//
1783 // BitcastOp
1784 //===----------------------------------------------------------------------===//
1785 
1786 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1787  if (!areValidCastInputsAndOutputs(inputs, outputs))
1788  return false;
1789 
1790  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1791  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1792  if (!srcType || !dstType)
1793  return false;
1794 
1795  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1796 }
1797 
1798 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1799  auto resType = getType();
1800  auto operand = adaptor.getIn();
1801  if (!operand)
1802  return {};
1803 
1804  /// Bitcast dense elements.
1805  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1806  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1807  /// Other shaped types unhandled.
1808  if (llvm::isa<ShapedType>(resType))
1809  return {};
1810 
1811  /// Bitcast poison.
1812  if (llvm::isa<ub::PoisonAttr>(operand))
1813  return ub::PoisonAttr::get(getContext());
1814 
1815  /// Bitcast integer or float to integer or float.
1816  APInt bits = llvm::isa<FloatAttr>(operand)
1817  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1818  : llvm::cast<IntegerAttr>(operand).getValue();
1819  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1820  "trying to fold on broken IR: operands have incompatible types");
1821 
1822  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1823  return FloatAttr::get(resType,
1824  APFloat(resFloatType.getFloatSemantics(), bits));
1825  return IntegerAttr::get(resType, bits);
1826 }
1827 
1828 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1829  MLIRContext *context) {
1830  patterns.add<BitcastOfBitcast>(context);
1831 }
1832 
1833 //===----------------------------------------------------------------------===//
1834 // CmpIOp
1835 //===----------------------------------------------------------------------===//
1836 
1837 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1838 /// comparison predicates.
1839 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1840  const APInt &lhs, const APInt &rhs) {
1841  switch (predicate) {
1842  case arith::CmpIPredicate::eq:
1843  return lhs.eq(rhs);
1844  case arith::CmpIPredicate::ne:
1845  return lhs.ne(rhs);
1846  case arith::CmpIPredicate::slt:
1847  return lhs.slt(rhs);
1848  case arith::CmpIPredicate::sle:
1849  return lhs.sle(rhs);
1850  case arith::CmpIPredicate::sgt:
1851  return lhs.sgt(rhs);
1852  case arith::CmpIPredicate::sge:
1853  return lhs.sge(rhs);
1854  case arith::CmpIPredicate::ult:
1855  return lhs.ult(rhs);
1856  case arith::CmpIPredicate::ule:
1857  return lhs.ule(rhs);
1858  case arith::CmpIPredicate::ugt:
1859  return lhs.ugt(rhs);
1860  case arith::CmpIPredicate::uge:
1861  return lhs.uge(rhs);
1862  }
1863  llvm_unreachable("unknown cmpi predicate kind");
1864 }
1865 
1866 /// Returns true if the predicate is true for two equal operands.
1867 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1868  switch (predicate) {
1869  case arith::CmpIPredicate::eq:
1870  case arith::CmpIPredicate::sle:
1871  case arith::CmpIPredicate::sge:
1872  case arith::CmpIPredicate::ule:
1873  case arith::CmpIPredicate::uge:
1874  return true;
1875  case arith::CmpIPredicate::ne:
1876  case arith::CmpIPredicate::slt:
1877  case arith::CmpIPredicate::sgt:
1878  case arith::CmpIPredicate::ult:
1879  case arith::CmpIPredicate::ugt:
1880  return false;
1881  }
1882  llvm_unreachable("unknown cmpi predicate kind");
1883 }
1884 
1885 static std::optional<int64_t> getIntegerWidth(Type t) {
1886  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1887  return intType.getWidth();
1888  }
1889  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1890  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1891  }
1892  return std::nullopt;
1893 }
1894 
1895 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1896  // cmpi(pred, x, x)
1897  if (getLhs() == getRhs()) {
1898  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1899  return getBoolAttribute(getType(), val);
1900  }
1901 
1902  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1903  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1904  // extsi(%x : i1 -> iN) != 0 -> %x
1905  std::optional<int64_t> integerWidth =
1906  getIntegerWidth(extOp.getOperand().getType());
1907  if (integerWidth && integerWidth.value() == 1 &&
1908  getPredicate() == arith::CmpIPredicate::ne)
1909  return extOp.getOperand();
1910  }
1911  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1912  // extui(%x : i1 -> iN) != 0 -> %x
1913  std::optional<int64_t> integerWidth =
1914  getIntegerWidth(extOp.getOperand().getType());
1915  if (integerWidth && integerWidth.value() == 1 &&
1916  getPredicate() == arith::CmpIPredicate::ne)
1917  return extOp.getOperand();
1918  }
1919 
1920  // arith.cmpi ne, %val, %zero : i1 -> %val
1921  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1922  getPredicate() == arith::CmpIPredicate::ne)
1923  return getLhs();
1924  }
1925 
1926  if (matchPattern(adaptor.getRhs(), m_One())) {
1927  // arith.cmpi eq, %val, %one : i1 -> %val
1928  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1929  getPredicate() == arith::CmpIPredicate::eq)
1930  return getLhs();
1931  }
1932 
1933  // Move constant to the right side.
1934  if (adaptor.getLhs() && !adaptor.getRhs()) {
1935  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1936  using Pred = CmpIPredicate;
1937  const std::pair<Pred, Pred> invPreds[] = {
1938  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1939  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1940  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1941  {Pred::ne, Pred::ne},
1942  };
1943  Pred origPred = getPredicate();
1944  for (auto pred : invPreds) {
1945  if (origPred == pred.first) {
1946  setPredicate(pred.second);
1947  Value lhs = getLhs();
1948  Value rhs = getRhs();
1949  getLhsMutable().assign(rhs);
1950  getRhsMutable().assign(lhs);
1951  return getResult();
1952  }
1953  }
1954  llvm_unreachable("unknown cmpi predicate kind");
1955  }
1956 
1957  // We are moving constants to the right side; So if lhs is constant rhs is
1958  // guaranteed to be a constant.
1959  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1960  return constFoldBinaryOp<IntegerAttr>(
1961  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1962  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1963  return APInt(1,
1964  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1965  });
1966  }
1967 
1968  return {};
1969 }
1970 
1971 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1972  MLIRContext *context) {
1973  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1974 }
1975 
1976 //===----------------------------------------------------------------------===//
1977 // CmpFOp
1978 //===----------------------------------------------------------------------===//
1979 
1980 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1981 /// comparison predicates.
1982 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1983  const APFloat &lhs, const APFloat &rhs) {
1984  auto cmpResult = lhs.compare(rhs);
1985  switch (predicate) {
1986  case arith::CmpFPredicate::AlwaysFalse:
1987  return false;
1988  case arith::CmpFPredicate::OEQ:
1989  return cmpResult == APFloat::cmpEqual;
1990  case arith::CmpFPredicate::OGT:
1991  return cmpResult == APFloat::cmpGreaterThan;
1992  case arith::CmpFPredicate::OGE:
1993  return cmpResult == APFloat::cmpGreaterThan ||
1994  cmpResult == APFloat::cmpEqual;
1995  case arith::CmpFPredicate::OLT:
1996  return cmpResult == APFloat::cmpLessThan;
1997  case arith::CmpFPredicate::OLE:
1998  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1999  case arith::CmpFPredicate::ONE:
2000  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2001  case arith::CmpFPredicate::ORD:
2002  return cmpResult != APFloat::cmpUnordered;
2003  case arith::CmpFPredicate::UEQ:
2004  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2005  case arith::CmpFPredicate::UGT:
2006  return cmpResult == APFloat::cmpUnordered ||
2007  cmpResult == APFloat::cmpGreaterThan;
2008  case arith::CmpFPredicate::UGE:
2009  return cmpResult == APFloat::cmpUnordered ||
2010  cmpResult == APFloat::cmpGreaterThan ||
2011  cmpResult == APFloat::cmpEqual;
2012  case arith::CmpFPredicate::ULT:
2013  return cmpResult == APFloat::cmpUnordered ||
2014  cmpResult == APFloat::cmpLessThan;
2015  case arith::CmpFPredicate::ULE:
2016  return cmpResult == APFloat::cmpUnordered ||
2017  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2018  case arith::CmpFPredicate::UNE:
2019  return cmpResult != APFloat::cmpEqual;
2020  case arith::CmpFPredicate::UNO:
2021  return cmpResult == APFloat::cmpUnordered;
2022  case arith::CmpFPredicate::AlwaysTrue:
2023  return true;
2024  }
2025  llvm_unreachable("unknown cmpf predicate kind");
2026 }
2027 
2028 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2029  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2030  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2031 
2032  // If one operand is NaN, making them both NaN does not change the result.
2033  if (lhs && lhs.getValue().isNaN())
2034  rhs = lhs;
2035  if (rhs && rhs.getValue().isNaN())
2036  lhs = rhs;
2037 
2038  if (!lhs || !rhs)
2039  return {};
2040 
2041  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2042  return BoolAttr::get(getContext(), val);
2043 }
2044 
2045 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2046 public:
2048 
2049  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2050  bool isUnsigned) {
2051  using namespace arith;
2052  switch (pred) {
2053  case CmpFPredicate::UEQ:
2054  case CmpFPredicate::OEQ:
2055  return CmpIPredicate::eq;
2056  case CmpFPredicate::UGT:
2057  case CmpFPredicate::OGT:
2058  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2059  case CmpFPredicate::UGE:
2060  case CmpFPredicate::OGE:
2061  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2062  case CmpFPredicate::ULT:
2063  case CmpFPredicate::OLT:
2064  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2065  case CmpFPredicate::ULE:
2066  case CmpFPredicate::OLE:
2067  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2068  case CmpFPredicate::UNE:
2069  case CmpFPredicate::ONE:
2070  return CmpIPredicate::ne;
2071  default:
2072  llvm_unreachable("Unexpected predicate!");
2073  }
2074  }
2075 
2076  LogicalResult matchAndRewrite(CmpFOp op,
2077  PatternRewriter &rewriter) const override {
2078  FloatAttr flt;
2079  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2080  return failure();
2081 
2082  const APFloat &rhs = flt.getValue();
2083 
2084  // Don't attempt to fold a nan.
2085  if (rhs.isNaN())
2086  return failure();
2087 
2088  // Get the width of the mantissa. We don't want to hack on conversions that
2089  // might lose information from the integer, e.g. "i64 -> float"
2090  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2091  int mantissaWidth = floatTy.getFPMantissaWidth();
2092  if (mantissaWidth <= 0)
2093  return failure();
2094 
2095  bool isUnsigned;
2096  Value intVal;
2097 
2098  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2099  isUnsigned = false;
2100  intVal = si.getIn();
2101  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2102  isUnsigned = true;
2103  intVal = ui.getIn();
2104  } else {
2105  return failure();
2106  }
2107 
2108  // Check to see that the input is converted from an integer type that is
2109  // small enough that preserves all bits.
2110  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2111  auto intWidth = intTy.getWidth();
2112 
2113  // Number of bits representing values, as opposed to the sign
2114  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2115 
2116  // Following test does NOT adjust intWidth downwards for signed inputs,
2117  // because the most negative value still requires all the mantissa bits
2118  // to distinguish it from one less than that value.
2119  if ((int)intWidth > mantissaWidth) {
2120  // Conversion would lose accuracy. Check if loss can impact comparison.
2121  int exponent = ilogb(rhs);
2122  if (exponent == APFloat::IEK_Inf) {
2123  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2124  if (maxExponent < (int)valueBits) {
2125  // Conversion could create infinity.
2126  return failure();
2127  }
2128  } else {
2129  // Note that if rhs is zero or NaN, then Exp is negative
2130  // and first condition is trivially false.
2131  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2132  // Conversion could affect comparison.
2133  return failure();
2134  }
2135  }
2136  }
2137 
2138  // Convert to equivalent cmpi predicate
2139  CmpIPredicate pred;
2140  switch (op.getPredicate()) {
2141  case CmpFPredicate::ORD:
2142  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2143  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2144  /*width=*/1);
2145  return success();
2146  case CmpFPredicate::UNO:
2147  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2148  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2149  /*width=*/1);
2150  return success();
2151  default:
2152  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2153  break;
2154  }
2155 
2156  if (!isUnsigned) {
2157  // If the rhs value is > SignedMax, fold the comparison. This handles
2158  // +INF and large values.
2159  APFloat signedMax(rhs.getSemantics());
2160  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2161  APFloat::rmNearestTiesToEven);
2162  if (signedMax < rhs) { // smax < 13123.0
2163  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2164  pred == CmpIPredicate::sle)
2165  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2166  /*width=*/1);
2167  else
2168  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2169  /*width=*/1);
2170  return success();
2171  }
2172  } else {
2173  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2174  // +INF and large values.
2175  APFloat unsignedMax(rhs.getSemantics());
2176  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2177  APFloat::rmNearestTiesToEven);
2178  if (unsignedMax < rhs) { // umax < 13123.0
2179  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2180  pred == CmpIPredicate::ule)
2181  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2182  /*width=*/1);
2183  else
2184  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2185  /*width=*/1);
2186  return success();
2187  }
2188  }
2189 
2190  if (!isUnsigned) {
2191  // See if the rhs value is < SignedMin.
2192  APFloat signedMin(rhs.getSemantics());
2193  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2194  APFloat::rmNearestTiesToEven);
2195  if (signedMin > rhs) { // smin > 12312.0
2196  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2197  pred == CmpIPredicate::sge)
2198  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2199  /*width=*/1);
2200  else
2201  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2202  /*width=*/1);
2203  return success();
2204  }
2205  } else {
2206  // See if the rhs value is < UnsignedMin.
2207  APFloat unsignedMin(rhs.getSemantics());
2208  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2209  APFloat::rmNearestTiesToEven);
2210  if (unsignedMin > rhs) { // umin > 12312.0
2211  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2212  pred == CmpIPredicate::uge)
2213  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2214  /*width=*/1);
2215  else
2216  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2217  /*width=*/1);
2218  return success();
2219  }
2220  }
2221 
2222  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2223  // [0, UMAX], but it may still be fractional. See if it is fractional by
2224  // casting the FP value to the integer value and back, checking for
2225  // equality. Don't do this for zero, because -0.0 is not fractional.
2226  bool ignored;
2227  APSInt rhsInt(intWidth, isUnsigned);
2228  if (APFloat::opInvalidOp ==
2229  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2230  // Undefined behavior invoked - the destination type can't represent
2231  // the input constant.
2232  return failure();
2233  }
2234 
2235  if (!rhs.isZero()) {
2236  APFloat apf(floatTy.getFloatSemantics(),
2237  APInt::getZero(floatTy.getWidth()));
2238  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2239 
2240  bool equal = apf == rhs;
2241  if (!equal) {
2242  // If we had a comparison against a fractional value, we have to adjust
2243  // the compare predicate and sometimes the value. rhsInt is rounded
2244  // towards zero at this point.
2245  switch (pred) {
2246  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2247  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2248  /*width=*/1);
2249  return success();
2250  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2251  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2252  /*width=*/1);
2253  return success();
2254  case CmpIPredicate::ule:
2255  // (float)int <= 4.4 --> int <= 4
2256  // (float)int <= -4.4 --> false
2257  if (rhs.isNegative()) {
2258  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2259  /*width=*/1);
2260  return success();
2261  }
2262  break;
2263  case CmpIPredicate::sle:
2264  // (float)int <= 4.4 --> int <= 4
2265  // (float)int <= -4.4 --> int < -4
2266  if (rhs.isNegative())
2267  pred = CmpIPredicate::slt;
2268  break;
2269  case CmpIPredicate::ult:
2270  // (float)int < -4.4 --> false
2271  // (float)int < 4.4 --> int <= 4
2272  if (rhs.isNegative()) {
2273  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2274  /*width=*/1);
2275  return success();
2276  }
2277  pred = CmpIPredicate::ule;
2278  break;
2279  case CmpIPredicate::slt:
2280  // (float)int < -4.4 --> int < -4
2281  // (float)int < 4.4 --> int <= 4
2282  if (!rhs.isNegative())
2283  pred = CmpIPredicate::sle;
2284  break;
2285  case CmpIPredicate::ugt:
2286  // (float)int > 4.4 --> int > 4
2287  // (float)int > -4.4 --> true
2288  if (rhs.isNegative()) {
2289  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2290  /*width=*/1);
2291  return success();
2292  }
2293  break;
2294  case CmpIPredicate::sgt:
2295  // (float)int > 4.4 --> int > 4
2296  // (float)int > -4.4 --> int >= -4
2297  if (rhs.isNegative())
2298  pred = CmpIPredicate::sge;
2299  break;
2300  case CmpIPredicate::uge:
2301  // (float)int >= -4.4 --> true
2302  // (float)int >= 4.4 --> int > 4
2303  if (rhs.isNegative()) {
2304  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2305  /*width=*/1);
2306  return success();
2307  }
2308  pred = CmpIPredicate::ugt;
2309  break;
2310  case CmpIPredicate::sge:
2311  // (float)int >= -4.4 --> int >= -4
2312  // (float)int >= 4.4 --> int > 4
2313  if (!rhs.isNegative())
2314  pred = CmpIPredicate::sgt;
2315  break;
2316  }
2317  }
2318  }
2319 
2320  // Lower this FP comparison into an appropriate integer version of the
2321  // comparison.
2322  rewriter.replaceOpWithNewOp<CmpIOp>(
2323  op, pred, intVal,
2324  rewriter.create<ConstantOp>(
2325  op.getLoc(), intVal.getType(),
2326  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2327  return success();
2328  }
2329 };
2330 
2331 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2332  MLIRContext *context) {
2333  patterns.insert<CmpFIntToFPConst>(context);
2334 }
2335 
2336 //===----------------------------------------------------------------------===//
2337 // SelectOp
2338 //===----------------------------------------------------------------------===//
2339 
2340 // select %arg, %c1, %c0 => extui %arg
2341 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2343 
2344  LogicalResult matchAndRewrite(arith::SelectOp op,
2345  PatternRewriter &rewriter) const override {
2346  // Cannot extui i1 to i1, or i1 to f32
2347  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2348  return failure();
2349 
2350  // select %x, c1, %c0 => extui %arg
2351  if (matchPattern(op.getTrueValue(), m_One()) &&
2352  matchPattern(op.getFalseValue(), m_Zero())) {
2353  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2354  op.getCondition());
2355  return success();
2356  }
2357 
2358  // select %x, c0, %c1 => extui (xor %arg, true)
2359  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2360  matchPattern(op.getFalseValue(), m_One())) {
2361  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2362  op, op.getType(),
2363  rewriter.create<arith::XOrIOp>(
2364  op.getLoc(), op.getCondition(),
2365  rewriter.create<arith::ConstantIntOp>(
2366  op.getLoc(), 1, op.getCondition().getType())));
2367  return success();
2368  }
2369 
2370  return failure();
2371  }
2372 };
2373 
2374 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2375  MLIRContext *context) {
2376  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2377  SelectI1ToNot, SelectToExtUI>(context);
2378 }
2379 
2380 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2381  Value trueVal = getTrueValue();
2382  Value falseVal = getFalseValue();
2383  if (trueVal == falseVal)
2384  return trueVal;
2385 
2386  Value condition = getCondition();
2387 
2388  // select true, %0, %1 => %0
2389  if (matchPattern(adaptor.getCondition(), m_One()))
2390  return trueVal;
2391 
2392  // select false, %0, %1 => %1
2393  if (matchPattern(adaptor.getCondition(), m_Zero()))
2394  return falseVal;
2395 
2396  // If either operand is fully poisoned, return the other.
2397  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2398  return falseVal;
2399 
2400  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2401  return trueVal;
2402 
2403  // select %x, true, false => %x
2404  if (getType().isSignlessInteger(1) &&
2405  matchPattern(adaptor.getTrueValue(), m_One()) &&
2406  matchPattern(adaptor.getFalseValue(), m_Zero()))
2407  return condition;
2408 
2409  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2410  auto pred = cmp.getPredicate();
2411  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2412  auto cmpLhs = cmp.getLhs();
2413  auto cmpRhs = cmp.getRhs();
2414 
2415  // %0 = arith.cmpi eq, %arg0, %arg1
2416  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2417 
2418  // %0 = arith.cmpi ne, %arg0, %arg1
2419  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2420 
2421  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2422  (cmpRhs == trueVal && cmpLhs == falseVal))
2423  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2424  }
2425  }
2426 
2427  // Constant-fold constant operands over non-splat constant condition.
2428  // select %cst_vec, %cst0, %cst1 => %cst2
2429  if (auto cond =
2430  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2431  if (auto lhs =
2432  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2433  if (auto rhs =
2434  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2435  SmallVector<Attribute> results;
2436  results.reserve(static_cast<size_t>(cond.getNumElements()));
2437  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2438  cond.value_end<BoolAttr>());
2439  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2440  lhs.value_end<Attribute>());
2441  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2442  rhs.value_end<Attribute>());
2443 
2444  for (auto [condVal, lhsVal, rhsVal] :
2445  llvm::zip_equal(condVals, lhsVals, rhsVals))
2446  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2447 
2448  return DenseElementsAttr::get(lhs.getType(), results);
2449  }
2450  }
2451  }
2452 
2453  return nullptr;
2454 }
2455 
2456 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2457  Type conditionType, resultType;
2459  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2460  parser.parseOptionalAttrDict(result.attributes) ||
2461  parser.parseColonType(resultType))
2462  return failure();
2463 
2464  // Check for the explicit condition type if this is a masked tensor or vector.
2465  if (succeeded(parser.parseOptionalComma())) {
2466  conditionType = resultType;
2467  if (parser.parseType(resultType))
2468  return failure();
2469  } else {
2470  conditionType = parser.getBuilder().getI1Type();
2471  }
2472 
2473  result.addTypes(resultType);
2474  return parser.resolveOperands(operands,
2475  {conditionType, resultType, resultType},
2476  parser.getNameLoc(), result.operands);
2477 }
2478 
2480  p << " " << getOperands();
2481  p.printOptionalAttrDict((*this)->getAttrs());
2482  p << " : ";
2483  if (ShapedType condType =
2484  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2485  p << condType << ", ";
2486  p << getType();
2487 }
2488 
2489 LogicalResult arith::SelectOp::verify() {
2490  Type conditionType = getCondition().getType();
2491  if (conditionType.isSignlessInteger(1))
2492  return success();
2493 
2494  // If the result type is a vector or tensor, the type can be a mask with the
2495  // same elements.
2496  Type resultType = getType();
2497  if (!llvm::isa<TensorType, VectorType>(resultType))
2498  return emitOpError() << "expected condition to be a signless i1, but got "
2499  << conditionType;
2500  Type shapedConditionType = getI1SameShape(resultType);
2501  if (conditionType != shapedConditionType) {
2502  return emitOpError() << "expected condition type to have the same shape "
2503  "as the result type, expected "
2504  << shapedConditionType << ", but got "
2505  << conditionType;
2506  }
2507  return success();
2508 }
2509 //===----------------------------------------------------------------------===//
2510 // ShLIOp
2511 //===----------------------------------------------------------------------===//
2512 
2513 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2514  // shli(x, 0) -> x
2515  if (matchPattern(adaptor.getRhs(), m_Zero()))
2516  return getLhs();
2517  // Don't fold if shifting more or equal than the bit width.
2518  bool bounded = false;
2519  auto result = constFoldBinaryOp<IntegerAttr>(
2520  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2521  bounded = b.ult(b.getBitWidth());
2522  return a.shl(b);
2523  });
2524  return bounded ? result : Attribute();
2525 }
2526 
2527 //===----------------------------------------------------------------------===//
2528 // ShRUIOp
2529 //===----------------------------------------------------------------------===//
2530 
2531 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2532  // shrui(x, 0) -> x
2533  if (matchPattern(adaptor.getRhs(), m_Zero()))
2534  return getLhs();
2535  // Don't fold if shifting more or equal than the bit width.
2536  bool bounded = false;
2537  auto result = constFoldBinaryOp<IntegerAttr>(
2538  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2539  bounded = b.ult(b.getBitWidth());
2540  return a.lshr(b);
2541  });
2542  return bounded ? result : Attribute();
2543 }
2544 
2545 //===----------------------------------------------------------------------===//
2546 // ShRSIOp
2547 //===----------------------------------------------------------------------===//
2548 
2549 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2550  // shrsi(x, 0) -> x
2551  if (matchPattern(adaptor.getRhs(), m_Zero()))
2552  return getLhs();
2553  // Don't fold if shifting more or equal than the bit width.
2554  bool bounded = false;
2555  auto result = constFoldBinaryOp<IntegerAttr>(
2556  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2557  bounded = b.ult(b.getBitWidth());
2558  return a.ashr(b);
2559  });
2560  return bounded ? result : Attribute();
2561 }
2562 
2563 //===----------------------------------------------------------------------===//
2564 // Atomic Enum
2565 //===----------------------------------------------------------------------===//
2566 
2567 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2568 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2569  OpBuilder &builder, Location loc,
2570  bool useOnlyFiniteValue) {
2571  switch (kind) {
2572  case AtomicRMWKind::maximumf: {
2573  const llvm::fltSemantics &semantic =
2574  llvm::cast<FloatType>(resultType).getFloatSemantics();
2575  APFloat identity = useOnlyFiniteValue
2576  ? APFloat::getLargest(semantic, /*Negative=*/true)
2577  : APFloat::getInf(semantic, /*Negative=*/true);
2578  return builder.getFloatAttr(resultType, identity);
2579  }
2580  case AtomicRMWKind::maxnumf: {
2581  const llvm::fltSemantics &semantic =
2582  llvm::cast<FloatType>(resultType).getFloatSemantics();
2583  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2584  return builder.getFloatAttr(resultType, identity);
2585  }
2586  case AtomicRMWKind::addf:
2587  case AtomicRMWKind::addi:
2588  case AtomicRMWKind::maxu:
2589  case AtomicRMWKind::ori:
2590  return builder.getZeroAttr(resultType);
2591  case AtomicRMWKind::andi:
2592  return builder.getIntegerAttr(
2593  resultType,
2594  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2595  case AtomicRMWKind::maxs:
2596  return builder.getIntegerAttr(
2597  resultType, APInt::getSignedMinValue(
2598  llvm::cast<IntegerType>(resultType).getWidth()));
2599  case AtomicRMWKind::minimumf: {
2600  const llvm::fltSemantics &semantic =
2601  llvm::cast<FloatType>(resultType).getFloatSemantics();
2602  APFloat identity = useOnlyFiniteValue
2603  ? APFloat::getLargest(semantic, /*Negative=*/false)
2604  : APFloat::getInf(semantic, /*Negative=*/false);
2605 
2606  return builder.getFloatAttr(resultType, identity);
2607  }
2608  case AtomicRMWKind::minnumf: {
2609  const llvm::fltSemantics &semantic =
2610  llvm::cast<FloatType>(resultType).getFloatSemantics();
2611  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2612  return builder.getFloatAttr(resultType, identity);
2613  }
2614  case AtomicRMWKind::mins:
2615  return builder.getIntegerAttr(
2616  resultType, APInt::getSignedMaxValue(
2617  llvm::cast<IntegerType>(resultType).getWidth()));
2618  case AtomicRMWKind::minu:
2619  return builder.getIntegerAttr(
2620  resultType,
2621  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2622  case AtomicRMWKind::muli:
2623  return builder.getIntegerAttr(resultType, 1);
2624  case AtomicRMWKind::mulf:
2625  return builder.getFloatAttr(resultType, 1);
2626  // TODO: Add remaining reduction operations.
2627  default:
2628  (void)emitOptionalError(loc, "Reduction operation type not supported");
2629  break;
2630  }
2631  return nullptr;
2632 }
2633 
2634 /// Return the identity numeric value associated to the give op.
2635 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2636  std::optional<AtomicRMWKind> maybeKind =
2638  // Floating-point operations.
2639  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2640  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2641  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2642  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2643  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2644  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2645  // Integer operations.
2646  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2647  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2648  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2649  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2650  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2651  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2652  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2653  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2654  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2655  .Default([](Operation *op) { return std::nullopt; });
2656  if (!maybeKind) {
2657  return std::nullopt;
2658  }
2659 
2660  bool useOnlyFiniteValue = false;
2661  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2662  if (fmfOpInterface) {
2663  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2664  useOnlyFiniteValue =
2665  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2666  }
2667 
2668  // Builder only used as helper for attribute creation.
2669  OpBuilder b(op->getContext());
2670  Type resultType = op->getResult(0).getType();
2671 
2672  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2673  useOnlyFiniteValue);
2674 }
2675 
2676 /// Returns the identity value associated with an AtomicRMWKind op.
2677 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2678  OpBuilder &builder, Location loc,
2679  bool useOnlyFiniteValue) {
2680  auto attr =
2681  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2682  return builder.create<arith::ConstantOp>(loc, attr);
2683 }
2684 
2685 /// Return the value obtained by applying the reduction operation kind
2686 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2687 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2688  Location loc, Value lhs, Value rhs) {
2689  switch (op) {
2690  case AtomicRMWKind::addf:
2691  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2692  case AtomicRMWKind::addi:
2693  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2694  case AtomicRMWKind::mulf:
2695  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2696  case AtomicRMWKind::muli:
2697  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2698  case AtomicRMWKind::maximumf:
2699  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2700  case AtomicRMWKind::minimumf:
2701  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2702  case AtomicRMWKind::maxnumf:
2703  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2704  case AtomicRMWKind::minnumf:
2705  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2706  case AtomicRMWKind::maxs:
2707  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2708  case AtomicRMWKind::mins:
2709  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2710  case AtomicRMWKind::maxu:
2711  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2712  case AtomicRMWKind::minu:
2713  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2714  case AtomicRMWKind::ori:
2715  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2716  case AtomicRMWKind::andi:
2717  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2718  // TODO: Add remaining reduction operations.
2719  default:
2720  (void)emitOptionalError(loc, "Reduction operation type not supported");
2721  break;
2722  }
2723  return nullptr;
2724 }
2725 
2726 //===----------------------------------------------------------------------===//
2727 // TableGen'd op method definitions
2728 //===----------------------------------------------------------------------===//
2729 
2730 #define GET_OP_CLASSES
2731 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2732 
2733 //===----------------------------------------------------------------------===//
2734 // TableGen'd enum attribute definitions
2735 //===----------------------------------------------------------------------===//
2736 
2737 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:618
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1326
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:62
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:69
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:109
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1262
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1867
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition: ArithOps.cpp:579
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1276
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1248
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:678
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:660
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:142
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:52
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:150
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1341
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1719
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1617
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1298
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:57
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:130
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:859
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1269
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:171
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1885
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1284
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1242
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:43
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:340
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1312
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2076
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:2049
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:252
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:322
IntegerType getI1Type()
Definition: Builders.cpp:55
IndexType getIndexType()
Definition: Builders.cpp:53
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:828
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:273
static bool classof(Operation *op)
Definition: ArithOps.cpp:279
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:285
static bool classof(Operation *op)
Definition: ArithOps.cpp:291
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:252
static bool classof(Operation *op)
Definition: ArithOps.cpp:267
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2635
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1839
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2568
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2687
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2677
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition: Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:462
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:427
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition: Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2344
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes