MLIR  20.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // The value's type must match the return type.
211  if (getValue().getType() != type) {
212  return emitOpError() << "value type " << getValue().getType()
213  << " must match return type: " << type;
214  }
215  // Integer values must be signless.
216  if (llvm::isa<IntegerType>(type) &&
217  !llvm::cast<IntegerType>(type).isSignless())
218  return emitOpError("integer return type must be signless");
219  // Any float or elements attribute are acceptable.
220  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
221  return emitOpError(
222  "value must be an integer, float, or elements attribute");
223  }
224 
225  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
226  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
227  // However, this would most likely require updating the lowerings to LLVM.
228  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
229  return emitOpError(
230  "intializing scalable vectors with elements attribute is not supported"
231  " unless it's a vector splat");
232  return success();
233 }
234 
235 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
236  // The value's type must be the same as the provided type.
237  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
238  if (!typedAttr || typedAttr.getType() != type)
239  return false;
240  // Integer values must be signless.
241  if (llvm::isa<IntegerType>(type) &&
242  !llvm::cast<IntegerType>(type).isSignless())
243  return false;
244  // Integer, float, and element attributes are buildable.
245  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
246 }
247 
248 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
249  Type type, Location loc) {
250  if (isBuildableWith(value, type))
251  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
252  return nullptr;
253 }
254 
255 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
256 
258  int64_t value, unsigned width) {
259  auto type = builder.getIntegerType(width);
260  arith::ConstantOp::build(builder, result, type,
261  builder.getIntegerAttr(type, value));
262 }
263 
265  int64_t value, Type type) {
266  assert(type.isSignlessInteger() &&
267  "ConstantIntOp can only have signless integer type values");
268  arith::ConstantOp::build(builder, result, type,
269  builder.getIntegerAttr(type, value));
270 }
271 
273  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
274  return constOp.getType().isSignlessInteger();
275  return false;
276 }
277 
279  const APFloat &value, FloatType type) {
280  arith::ConstantOp::build(builder, result, type,
281  builder.getFloatAttr(type, value));
282 }
283 
285  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
286  return llvm::isa<FloatType>(constOp.getType());
287  return false;
288 }
289 
291  int64_t value) {
292  arith::ConstantOp::build(builder, result, builder.getIndexType(),
293  builder.getIndexAttr(value));
294 }
295 
297  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
298  return constOp.getType().isIndex();
299  return false;
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // AddIOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
307  // addi(x, 0) -> x
308  if (matchPattern(adaptor.getRhs(), m_Zero()))
309  return getLhs();
310 
311  // addi(subi(a, b), b) -> a
312  if (auto sub = getLhs().getDefiningOp<SubIOp>())
313  if (getRhs() == sub.getRhs())
314  return sub.getLhs();
315 
316  // addi(b, subi(a, b)) -> a
317  if (auto sub = getRhs().getDefiningOp<SubIOp>())
318  if (getLhs() == sub.getRhs())
319  return sub.getLhs();
320 
321  return constFoldBinaryOp<IntegerAttr>(
322  adaptor.getOperands(),
323  [](APInt a, const APInt &b) { return std::move(a) + b; });
324 }
325 
326 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
327  MLIRContext *context) {
328  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
329  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // AddUIExtendedOp
334 //===----------------------------------------------------------------------===//
335 
336 std::optional<SmallVector<int64_t, 4>>
337 arith::AddUIExtendedOp::getShapeForUnroll() {
338  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
339  return llvm::to_vector<4>(vt.getShape());
340  return std::nullopt;
341 }
342 
343 // Returns the overflow bit, assuming that `sum` is the result of unsigned
344 // addition of `operand` and another number.
345 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
346  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
347 }
348 
349 LogicalResult
350 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
352  Type overflowTy = getOverflow().getType();
353  // addui_extended(x, 0) -> x, false
354  if (matchPattern(getRhs(), m_Zero())) {
355  Builder builder(getContext());
356  auto falseValue = builder.getZeroAttr(overflowTy);
357 
358  results.push_back(getLhs());
359  results.push_back(falseValue);
360  return success();
361  }
362 
363  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
364  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
365  // operands. If that succeeds, calculate the overflow bit based on the sum
366  // and the first (constant) operand, `lhs`.
367  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
368  adaptor.getOperands(),
369  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
370  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
371  ArrayRef({sumAttr, adaptor.getLhs()}),
372  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
374  if (!overflowAttr)
375  return failure();
376 
377  results.push_back(sumAttr);
378  results.push_back(overflowAttr);
379  return success();
380  }
381 
382  return failure();
383 }
384 
385 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
386  RewritePatternSet &patterns, MLIRContext *context) {
387  patterns.add<AddUIExtendedToAddI>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // SubIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
395  // subi(x,x) -> 0
396  if (getOperand(0) == getOperand(1)) {
397  auto shapedType = dyn_cast<ShapedType>(getType());
398  // We can't generate a constant with a dynamic shaped tensor.
399  if (!shapedType || shapedType.hasStaticShape())
400  return Builder(getContext()).getZeroAttr(getType());
401  }
402  // subi(x,0) -> x
403  if (matchPattern(adaptor.getRhs(), m_Zero()))
404  return getLhs();
405 
406  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
407  // subi(addi(a, b), b) -> a
408  if (getRhs() == add.getRhs())
409  return add.getLhs();
410  // subi(addi(a, b), a) -> b
411  if (getRhs() == add.getLhs())
412  return add.getRhs();
413  }
414 
415  return constFoldBinaryOp<IntegerAttr>(
416  adaptor.getOperands(),
417  [](APInt a, const APInt &b) { return std::move(a) - b; });
418 }
419 
420 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
421  MLIRContext *context) {
422  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
423  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
424  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // MulIOp
429 //===----------------------------------------------------------------------===//
430 
431 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
432  // muli(x, 0) -> 0
433  if (matchPattern(adaptor.getRhs(), m_Zero()))
434  return getRhs();
435  // muli(x, 1) -> x
436  if (matchPattern(adaptor.getRhs(), m_One()))
437  return getLhs();
438  // TODO: Handle the overflow case.
439 
440  // default folder
441  return constFoldBinaryOp<IntegerAttr>(
442  adaptor.getOperands(),
443  [](const APInt &a, const APInt &b) { return a * b; });
444 }
445 
446 void arith::MulIOp::getAsmResultNames(
447  function_ref<void(Value, StringRef)> setNameFn) {
448  if (!isa<IndexType>(getType()))
449  return;
450 
451  // Match vector.vscale by name to avoid depending on the vector dialect (which
452  // is a circular dependency).
453  auto isVscale = [](Operation *op) {
454  return op && op->getName().getStringRef() == "vector.vscale";
455  };
456 
457  IntegerAttr baseValue;
458  auto isVscaleExpr = [&](Value a, Value b) {
459  return matchPattern(a, m_Constant(&baseValue)) &&
460  isVscale(b.getDefiningOp());
461  };
462 
463  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
464  return;
465 
466  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
467  SmallString<32> specialNameBuffer;
468  llvm::raw_svector_ostream specialName(specialNameBuffer);
469  specialName << 'c' << baseValue.getInt() << "_vscale";
470  setNameFn(getResult(), specialName.str());
471 }
472 
473 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
474  MLIRContext *context) {
475  patterns.add<MulIMulIConstant>(context);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // MulSIExtendedOp
480 //===----------------------------------------------------------------------===//
481 
482 std::optional<SmallVector<int64_t, 4>>
483 arith::MulSIExtendedOp::getShapeForUnroll() {
484  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
485  return llvm::to_vector<4>(vt.getShape());
486  return std::nullopt;
487 }
488 
489 LogicalResult
490 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
492  // mulsi_extended(x, 0) -> 0, 0
493  if (matchPattern(adaptor.getRhs(), m_Zero())) {
494  Attribute zero = adaptor.getRhs();
495  results.push_back(zero);
496  results.push_back(zero);
497  return success();
498  }
499 
500  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
501  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
502  adaptor.getOperands(),
503  [](const APInt &a, const APInt &b) { return a * b; })) {
504  // Invoke the constant fold helper again to calculate the 'high' result.
505  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
506  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
507  return llvm::APIntOps::mulhs(a, b);
508  });
509  assert(highAttr && "Unexpected constant-folding failure");
510 
511  results.push_back(lowAttr);
512  results.push_back(highAttr);
513  return success();
514  }
515 
516  return failure();
517 }
518 
519 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
520  RewritePatternSet &patterns, MLIRContext *context) {
521  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // MulUIExtendedOp
526 //===----------------------------------------------------------------------===//
527 
528 std::optional<SmallVector<int64_t, 4>>
529 arith::MulUIExtendedOp::getShapeForUnroll() {
530  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
531  return llvm::to_vector<4>(vt.getShape());
532  return std::nullopt;
533 }
534 
535 LogicalResult
536 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
538  // mului_extended(x, 0) -> 0, 0
539  if (matchPattern(adaptor.getRhs(), m_Zero())) {
540  Attribute zero = adaptor.getRhs();
541  results.push_back(zero);
542  results.push_back(zero);
543  return success();
544  }
545 
546  // mului_extended(x, 1) -> x, 0
547  if (matchPattern(adaptor.getRhs(), m_One())) {
548  Builder builder(getContext());
549  Attribute zero = builder.getZeroAttr(getLhs().getType());
550  results.push_back(getLhs());
551  results.push_back(zero);
552  return success();
553  }
554 
555  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
556  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
557  adaptor.getOperands(),
558  [](const APInt &a, const APInt &b) { return a * b; })) {
559  // Invoke the constant fold helper again to calculate the 'high' result.
560  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
561  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
562  return llvm::APIntOps::mulhu(a, b);
563  });
564  assert(highAttr && "Unexpected constant-folding failure");
565 
566  results.push_back(lowAttr);
567  results.push_back(highAttr);
568  return success();
569  }
570 
571  return failure();
572 }
573 
574 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
575  RewritePatternSet &patterns, MLIRContext *context) {
576  patterns.add<MulUIExtendedToMulI>(context);
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // DivUIOp
581 //===----------------------------------------------------------------------===//
582 
583 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
584  // divui (x, 1) -> x.
585  if (matchPattern(adaptor.getRhs(), m_One()))
586  return getLhs();
587 
588  // Don't fold if it would require a division by zero.
589  bool div0 = false;
590  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
591  [&](APInt a, const APInt &b) {
592  if (div0 || !b) {
593  div0 = true;
594  return a;
595  }
596  return a.udiv(b);
597  });
598 
599  return div0 ? Attribute() : result;
600 }
601 
602 /// Returns whether an unsigned division by `divisor` is speculatable.
604  // X / 0 => UB
605  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
607 
609 }
610 
611 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
612  return getDivUISpeculatability(getRhs());
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // DivSIOp
617 //===----------------------------------------------------------------------===//
618 
619 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
620  // divsi (x, 1) -> x.
621  if (matchPattern(adaptor.getRhs(), m_One()))
622  return getLhs();
623 
624  // Don't fold if it would overflow or if it requires a division by zero.
625  bool overflowOrDiv0 = false;
626  auto result = constFoldBinaryOp<IntegerAttr>(
627  adaptor.getOperands(), [&](APInt a, const APInt &b) {
628  if (overflowOrDiv0 || !b) {
629  overflowOrDiv0 = true;
630  return a;
631  }
632  return a.sdiv_ov(b, overflowOrDiv0);
633  });
634 
635  return overflowOrDiv0 ? Attribute() : result;
636 }
637 
638 /// Returns whether a signed division by `divisor` is speculatable. This
639 /// function conservatively assumes that all signed division by -1 are not
640 /// speculatable.
642  // X / 0 => UB
643  // INT_MIN / -1 => UB
644  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
647 
649 }
650 
651 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
652  return getDivSISpeculatability(getRhs());
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // Ceil and floor division folding helpers
657 //===----------------------------------------------------------------------===//
658 
659 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
660  bool &overflow) {
661  // Returns (a-1)/b + 1
662  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
663  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
664  return val.sadd_ov(one, overflow);
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // CeilDivUIOp
669 //===----------------------------------------------------------------------===//
670 
671 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
672  // ceildivui (x, 1) -> x.
673  if (matchPattern(adaptor.getRhs(), m_One()))
674  return getLhs();
675 
676  bool overflowOrDiv0 = false;
677  auto result = constFoldBinaryOp<IntegerAttr>(
678  adaptor.getOperands(), [&](APInt a, const APInt &b) {
679  if (overflowOrDiv0 || !b) {
680  overflowOrDiv0 = true;
681  return a;
682  }
683  APInt quotient = a.udiv(b);
684  if (!a.urem(b))
685  return quotient;
686  APInt one(a.getBitWidth(), 1, true);
687  return quotient.uadd_ov(one, overflowOrDiv0);
688  });
689 
690  return overflowOrDiv0 ? Attribute() : result;
691 }
692 
693 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
694  return getDivUISpeculatability(getRhs());
695 }
696 
697 //===----------------------------------------------------------------------===//
698 // CeilDivSIOp
699 //===----------------------------------------------------------------------===//
700 
701 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
702  // ceildivsi (x, 1) -> x.
703  if (matchPattern(adaptor.getRhs(), m_One()))
704  return getLhs();
705 
706  // Don't fold if it would overflow or if it requires a division by zero.
707  // TODO: This hook won't fold operations where a = MININT, because
708  // negating MININT overflows. This can be improved.
709  bool overflowOrDiv0 = false;
710  auto result = constFoldBinaryOp<IntegerAttr>(
711  adaptor.getOperands(), [&](APInt a, const APInt &b) {
712  if (overflowOrDiv0 || !b) {
713  overflowOrDiv0 = true;
714  return a;
715  }
716  if (!a)
717  return a;
718  // After this point we know that neither a or b are zero.
719  unsigned bits = a.getBitWidth();
720  APInt zero = APInt::getZero(bits);
721  bool aGtZero = a.sgt(zero);
722  bool bGtZero = b.sgt(zero);
723  if (aGtZero && bGtZero) {
724  // Both positive, return ceil(a, b).
725  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
726  }
727 
728  // No folding happens if any of the intermediate arithmetic operations
729  // overflows.
730  bool overflowNegA = false;
731  bool overflowNegB = false;
732  bool overflowDiv = false;
733  bool overflowNegRes = false;
734  if (!aGtZero && !bGtZero) {
735  // Both negative, return ceil(-a, -b).
736  APInt posA = zero.ssub_ov(a, overflowNegA);
737  APInt posB = zero.ssub_ov(b, overflowNegB);
738  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
739  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
740  return res;
741  }
742  if (!aGtZero && bGtZero) {
743  // A is negative, b is positive, return - ( -a / b).
744  APInt posA = zero.ssub_ov(a, overflowNegA);
745  APInt div = posA.sdiv_ov(b, overflowDiv);
746  APInt res = zero.ssub_ov(div, overflowNegRes);
747  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
748  return res;
749  }
750  // A is positive, b is negative, return - (a / -b).
751  APInt posB = zero.ssub_ov(b, overflowNegB);
752  APInt div = a.sdiv_ov(posB, overflowDiv);
753  APInt res = zero.ssub_ov(div, overflowNegRes);
754 
755  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
756  return res;
757  });
758 
759  return overflowOrDiv0 ? Attribute() : result;
760 }
761 
762 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
763  return getDivSISpeculatability(getRhs());
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // FloorDivSIOp
768 //===----------------------------------------------------------------------===//
769 
770 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
771  // floordivsi (x, 1) -> x.
772  if (matchPattern(adaptor.getRhs(), m_One()))
773  return getLhs();
774 
775  // Don't fold if it would overflow or if it requires a division by zero.
776  bool overflowOrDiv = false;
777  auto result = constFoldBinaryOp<IntegerAttr>(
778  adaptor.getOperands(), [&](APInt a, const APInt &b) {
779  if (b.isZero()) {
780  overflowOrDiv = true;
781  return a;
782  }
783  return a.sfloordiv_ov(b, overflowOrDiv);
784  });
785 
786  return overflowOrDiv ? Attribute() : result;
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // RemUIOp
791 //===----------------------------------------------------------------------===//
792 
793 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
794  // remui (x, 1) -> 0.
795  if (matchPattern(adaptor.getRhs(), m_One()))
796  return Builder(getContext()).getZeroAttr(getType());
797 
798  // Don't fold if it would require a division by zero.
799  bool div0 = false;
800  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
801  [&](APInt a, const APInt &b) {
802  if (div0 || b.isZero()) {
803  div0 = true;
804  return a;
805  }
806  return a.urem(b);
807  });
808 
809  return div0 ? Attribute() : result;
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // RemSIOp
814 //===----------------------------------------------------------------------===//
815 
816 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
817  // remsi (x, 1) -> 0.
818  if (matchPattern(adaptor.getRhs(), m_One()))
819  return Builder(getContext()).getZeroAttr(getType());
820 
821  // Don't fold if it would require a division by zero.
822  bool div0 = false;
823  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
824  [&](APInt a, const APInt &b) {
825  if (div0 || b.isZero()) {
826  div0 = true;
827  return a;
828  }
829  return a.srem(b);
830  });
831 
832  return div0 ? Attribute() : result;
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // AndIOp
837 //===----------------------------------------------------------------------===//
838 
839 /// Fold `and(a, and(a, b))` to `and(a, b)`
840 static Value foldAndIofAndI(arith::AndIOp op) {
841  for (bool reversePrev : {false, true}) {
842  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
843  .getDefiningOp<arith::AndIOp>();
844  if (!prev)
845  continue;
846 
847  Value other = (reversePrev ? op.getLhs() : op.getRhs());
848  if (other != prev.getLhs() && other != prev.getRhs())
849  continue;
850 
851  return prev.getResult();
852  }
853  return {};
854 }
855 
856 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
857  /// and(x, 0) -> 0
858  if (matchPattern(adaptor.getRhs(), m_Zero()))
859  return getRhs();
860  /// and(x, allOnes) -> x
861  APInt intValue;
862  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
863  intValue.isAllOnes())
864  return getLhs();
865  /// and(x, not(x)) -> 0
866  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
867  m_ConstantInt(&intValue))) &&
868  intValue.isAllOnes())
869  return Builder(getContext()).getZeroAttr(getType());
870  /// and(not(x), x) -> 0
871  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
872  m_ConstantInt(&intValue))) &&
873  intValue.isAllOnes())
874  return Builder(getContext()).getZeroAttr(getType());
875 
876  /// and(a, and(a, b)) -> and(a, b)
877  if (Value result = foldAndIofAndI(*this))
878  return result;
879 
880  return constFoldBinaryOp<IntegerAttr>(
881  adaptor.getOperands(),
882  [](APInt a, const APInt &b) { return std::move(a) & b; });
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // OrIOp
887 //===----------------------------------------------------------------------===//
888 
889 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
890  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
891  /// or(x, 0) -> x
892  if (rhsVal.isZero())
893  return getLhs();
894  /// or(x, <all ones>) -> <all ones>
895  if (rhsVal.isAllOnes())
896  return adaptor.getRhs();
897  }
898 
899  APInt intValue;
900  /// or(x, xor(x, 1)) -> 1
901  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
902  m_ConstantInt(&intValue))) &&
903  intValue.isAllOnes())
904  return getRhs().getDefiningOp<XOrIOp>().getRhs();
905  /// or(xor(x, 1), x) -> 1
906  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
907  m_ConstantInt(&intValue))) &&
908  intValue.isAllOnes())
909  return getLhs().getDefiningOp<XOrIOp>().getRhs();
910 
911  return constFoldBinaryOp<IntegerAttr>(
912  adaptor.getOperands(),
913  [](APInt a, const APInt &b) { return std::move(a) | b; });
914 }
915 
916 //===----------------------------------------------------------------------===//
917 // XOrIOp
918 //===----------------------------------------------------------------------===//
919 
920 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
921  /// xor(x, 0) -> x
922  if (matchPattern(adaptor.getRhs(), m_Zero()))
923  return getLhs();
924  /// xor(x, x) -> 0
925  if (getLhs() == getRhs())
926  return Builder(getContext()).getZeroAttr(getType());
927  /// xor(xor(x, a), a) -> x
928  /// xor(xor(a, x), a) -> x
929  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
930  if (prev.getRhs() == getRhs())
931  return prev.getLhs();
932  if (prev.getLhs() == getRhs())
933  return prev.getRhs();
934  }
935  /// xor(a, xor(x, a)) -> x
936  /// xor(a, xor(a, x)) -> x
937  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
938  if (prev.getRhs() == getLhs())
939  return prev.getLhs();
940  if (prev.getLhs() == getLhs())
941  return prev.getRhs();
942  }
943 
944  return constFoldBinaryOp<IntegerAttr>(
945  adaptor.getOperands(),
946  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
947 }
948 
949 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
950  MLIRContext *context) {
951  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // NegFOp
956 //===----------------------------------------------------------------------===//
957 
958 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
959  // negf(negf(x)) -> x
960  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
961  return op.getOperand();
962  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
963  [](const APFloat &a) { return -a; });
964 }
965 
966 //===----------------------------------------------------------------------===//
967 // AddFOp
968 //===----------------------------------------------------------------------===//
969 
970 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
971  // addf(x, -0) -> x
972  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
973  return getLhs();
974 
975  return constFoldBinaryOp<FloatAttr>(
976  adaptor.getOperands(),
977  [](const APFloat &a, const APFloat &b) { return a + b; });
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // SubFOp
982 //===----------------------------------------------------------------------===//
983 
984 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
985  // subf(x, +0) -> x
986  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
987  return getLhs();
988 
989  // Simplifies subf(x, rhs) to x if the following conditions are met:
990  // 1. `rhs` is a denormal floating-point value.
991  // 2. The denormal mode for the operation is set to positive zero.
992  bool isPositiveZeroMode =
993  getDenormalModeAttr().getValue() == DenormalMode::positive_zero;
994  if (isPositiveZeroMode && matchPattern(adaptor.getRhs(), m_isDenormalFloat()))
995  return getLhs();
996 
997  return constFoldBinaryOp<FloatAttr>(
998  adaptor.getOperands(),
999  [](const APFloat &a, const APFloat &b) { return a - b; });
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // MaximumFOp
1004 //===----------------------------------------------------------------------===//
1005 
1006 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1007  // maximumf(x,x) -> x
1008  if (getLhs() == getRhs())
1009  return getRhs();
1010 
1011  // maximumf(x, -inf) -> x
1012  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1013  return getLhs();
1014 
1015  return constFoldBinaryOp<FloatAttr>(
1016  adaptor.getOperands(),
1017  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // MaxNumFOp
1022 //===----------------------------------------------------------------------===//
1023 
1024 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1025  // maxnumf(x,x) -> x
1026  if (getLhs() == getRhs())
1027  return getRhs();
1028 
1029  // maxnumf(x, NaN) -> x
1030  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1031  return getLhs();
1032 
1033  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // MaxSIOp
1038 //===----------------------------------------------------------------------===//
1039 
1040 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1041  // maxsi(x,x) -> x
1042  if (getLhs() == getRhs())
1043  return getRhs();
1044 
1045  if (APInt intValue;
1046  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1047  // maxsi(x,MAX_INT) -> MAX_INT
1048  if (intValue.isMaxSignedValue())
1049  return getRhs();
1050  // maxsi(x, MIN_INT) -> x
1051  if (intValue.isMinSignedValue())
1052  return getLhs();
1053  }
1054 
1055  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1056  [](const APInt &a, const APInt &b) {
1057  return llvm::APIntOps::smax(a, b);
1058  });
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // MaxUIOp
1063 //===----------------------------------------------------------------------===//
1064 
1065 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1066  // maxui(x,x) -> x
1067  if (getLhs() == getRhs())
1068  return getRhs();
1069 
1070  if (APInt intValue;
1071  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1072  // maxui(x,MAX_INT) -> MAX_INT
1073  if (intValue.isMaxValue())
1074  return getRhs();
1075  // maxui(x, MIN_INT) -> x
1076  if (intValue.isMinValue())
1077  return getLhs();
1078  }
1079 
1080  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1081  [](const APInt &a, const APInt &b) {
1082  return llvm::APIntOps::umax(a, b);
1083  });
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // MinimumFOp
1088 //===----------------------------------------------------------------------===//
1089 
1090 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1091  // minimumf(x,x) -> x
1092  if (getLhs() == getRhs())
1093  return getRhs();
1094 
1095  // minimumf(x, +inf) -> x
1096  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1097  return getLhs();
1098 
1099  return constFoldBinaryOp<FloatAttr>(
1100  adaptor.getOperands(),
1101  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1102 }
1103 
1104 //===----------------------------------------------------------------------===//
1105 // MinNumFOp
1106 //===----------------------------------------------------------------------===//
1107 
1108 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1109  // minnumf(x,x) -> x
1110  if (getLhs() == getRhs())
1111  return getRhs();
1112 
1113  // minnumf(x, NaN) -> x
1114  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1115  return getLhs();
1116 
1117  return constFoldBinaryOp<FloatAttr>(
1118  adaptor.getOperands(),
1119  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1120 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // MinSIOp
1124 //===----------------------------------------------------------------------===//
1125 
1126 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1127  // minsi(x,x) -> x
1128  if (getLhs() == getRhs())
1129  return getRhs();
1130 
1131  if (APInt intValue;
1132  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1133  // minsi(x,MIN_INT) -> MIN_INT
1134  if (intValue.isMinSignedValue())
1135  return getRhs();
1136  // minsi(x, MAX_INT) -> x
1137  if (intValue.isMaxSignedValue())
1138  return getLhs();
1139  }
1140 
1141  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1142  [](const APInt &a, const APInt &b) {
1143  return llvm::APIntOps::smin(a, b);
1144  });
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // MinUIOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1152  // minui(x,x) -> x
1153  if (getLhs() == getRhs())
1154  return getRhs();
1155 
1156  if (APInt intValue;
1157  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1158  // minui(x,MIN_INT) -> MIN_INT
1159  if (intValue.isMinValue())
1160  return getRhs();
1161  // minui(x, MAX_INT) -> x
1162  if (intValue.isMaxValue())
1163  return getLhs();
1164  }
1165 
1166  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1167  [](const APInt &a, const APInt &b) {
1168  return llvm::APIntOps::umin(a, b);
1169  });
1170 }
1171 
1172 //===----------------------------------------------------------------------===//
1173 // MulFOp
1174 //===----------------------------------------------------------------------===//
1175 
1176 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1177  // mulf(x, 1) -> x
1178  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1179  return getLhs();
1180 
1181  return constFoldBinaryOp<FloatAttr>(
1182  adaptor.getOperands(),
1183  [](const APFloat &a, const APFloat &b) { return a * b; });
1184 }
1185 
1186 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1187  MLIRContext *context) {
1188  patterns.add<MulFOfNegF>(context);
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // DivFOp
1193 //===----------------------------------------------------------------------===//
1194 
1195 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1196  // divf(x, 1) -> x
1197  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1198  return getLhs();
1199 
1200  return constFoldBinaryOp<FloatAttr>(
1201  adaptor.getOperands(),
1202  [](const APFloat &a, const APFloat &b) { return a / b; });
1203 }
1204 
1205 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1206  MLIRContext *context) {
1207  patterns.add<DivFOfNegF>(context);
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // RemFOp
1212 //===----------------------------------------------------------------------===//
1213 
1214 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1215  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1216  [](const APFloat &a, const APFloat &b) {
1217  APFloat result(a);
1218  // APFloat::mod() offers the remainder
1219  // behavior we want, i.e. the result has
1220  // the sign of LHS operand.
1221  (void)result.mod(b);
1222  return result;
1223  });
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // Utility functions for verifying cast ops
1228 //===----------------------------------------------------------------------===//
1229 
/// Tag alias that carries a pack of types (`Types...`) as an ordinary function
/// argument type, so a function template can deduce more than one independent
/// type pack from its arguments. Only the type is meaningful; NOTE(review): the
/// call sites are elided in this view, but the pointer value is presumably
/// never dereferenced.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1232 
1233 /// Returns a non-null type only if the provided type is one of the allowed
1234 /// types or one of the allowed shaped types of the allowed types. Returns the
1235 /// element type if a valid shaped type is provided.
1236 template <typename... ShapedTypes, typename... ElementTypes>
1239  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1240  return {};
1241 
1242  auto underlyingType = getElementTypeOrSelf(type);
1243  if (!llvm::isa<ElementTypes...>(underlyingType))
1244  return {};
1245 
1246  return underlyingType;
1247 }
1248 
1249 /// Get allowed underlying types for vectors and tensors.
1250 template <typename... ElementTypes>
1251 static Type getTypeIfLike(Type type) {
1254 }
1255 
1256 /// Get allowed underlying types for vectors, tensors, and memrefs.
1257 template <typename... ElementTypes>
1259  return getUnderlyingType(type,
1262 }
1263 
1264 /// Return false if both types are ranked tensor with mismatching encoding.
1265 static bool hasSameEncoding(Type typeA, Type typeB) {
1266  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1267  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1268  if (!rankedTensorA || !rankedTensorB)
1269  return true;
1270  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1271 }
1272 
1273 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1274  if (inputs.size() != 1 || outputs.size() != 1)
1275  return false;
1276  if (!hasSameEncoding(inputs.front(), outputs.front()))
1277  return false;
1278  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1279 }
1280 
1281 //===----------------------------------------------------------------------===//
1282 // Verifiers for integer and floating point extension/truncation ops
1283 //===----------------------------------------------------------------------===//
1284 
1285 // Extend ops can only extend to a wider type.
1286 template <typename ValType, typename Op>
1287 static LogicalResult verifyExtOp(Op op) {
1288  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1289  Type dstType = getElementTypeOrSelf(op.getType());
1290 
1291  if (llvm::cast<ValType>(srcType).getWidth() >=
1292  llvm::cast<ValType>(dstType).getWidth())
1293  return op.emitError("result type ")
1294  << dstType << " must be wider than operand type " << srcType;
1295 
1296  return success();
1297 }
1298 
1299 // Truncate ops can only truncate to a shorter type.
1300 template <typename ValType, typename Op>
1301 static LogicalResult verifyTruncateOp(Op op) {
1302  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1303  Type dstType = getElementTypeOrSelf(op.getType());
1304 
1305  if (llvm::cast<ValType>(srcType).getWidth() <=
1306  llvm::cast<ValType>(dstType).getWidth())
1307  return op.emitError("result type ")
1308  << dstType << " must be shorter than operand type " << srcType;
1309 
1310  return success();
1311 }
1312 
1313 /// Validate a cast that changes the width of a type.
1314 template <template <typename> class WidthComparator, typename... ElementTypes>
1315 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1316  if (!areValidCastInputsAndOutputs(inputs, outputs))
1317  return false;
1318 
1319  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1320  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1321  if (!srcType || !dstType)
1322  return false;
1323 
1324  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1325  srcType.getIntOrFloatBitWidth());
1326 }
1327 
1328 /// Attempts to convert `sourceValue` to an APFloat value with
1329 /// `targetSemantics` and `roundingMode`, without any information loss.
1330 static FailureOr<APFloat> convertFloatValue(
1331  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1332  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1333  bool losesInfo = false;
1334  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1335  if (losesInfo || status != APFloat::opOK)
1336  return failure();
1337 
1338  return sourceValue;
1339 }
1340 
1341 //===----------------------------------------------------------------------===//
1342 // ExtUIOp
1343 //===----------------------------------------------------------------------===//
1344 
1345 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1346  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1347  getInMutable().assign(lhs.getIn());
1348  return getResult();
1349  }
1350 
1351  Type resType = getElementTypeOrSelf(getType());
1352  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1353  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1354  adaptor.getOperands(), getType(),
1355  [bitWidth](const APInt &a, bool &castStatus) {
1356  return a.zext(bitWidth);
1357  });
1358 }
1359 
1360 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1361  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1362 }
1363 
1364 LogicalResult arith::ExtUIOp::verify() {
1365  return verifyExtOp<IntegerType>(*this);
1366 }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // ExtSIOp
1370 //===----------------------------------------------------------------------===//
1371 
1372 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1373  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1374  getInMutable().assign(lhs.getIn());
1375  return getResult();
1376  }
1377 
1378  Type resType = getElementTypeOrSelf(getType());
1379  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1380  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1381  adaptor.getOperands(), getType(),
1382  [bitWidth](const APInt &a, bool &castStatus) {
1383  return a.sext(bitWidth);
1384  });
1385 }
1386 
1387 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1388  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1389 }
1390 
1391 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1392  MLIRContext *context) {
1393  patterns.add<ExtSIOfExtUI>(context);
1394 }
1395 
1396 LogicalResult arith::ExtSIOp::verify() {
1397  return verifyExtOp<IntegerType>(*this);
1398 }
1399 
1400 //===----------------------------------------------------------------------===//
1401 // ExtFOp
1402 //===----------------------------------------------------------------------===//
1403 
1404 /// Fold extension of float constants when there is no information loss due the
1405 /// difference in fp semantics.
1406 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1407  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1408  if (truncFOp.getOperand().getType() == getType()) {
1409  arith::FastMathFlags truncFMF =
1410  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1411  bool isTruncContract =
1412  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1413  arith::FastMathFlags extFMF =
1414  getFastmath().value_or(arith::FastMathFlags::none);
1415  bool isExtContract =
1416  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1417  if (isTruncContract && isExtContract) {
1418  return truncFOp.getOperand();
1419  }
1420  }
1421  }
1422 
1423  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1424  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1425  return constFoldCastOp<FloatAttr, FloatAttr>(
1426  adaptor.getOperands(), getType(),
1427  [&targetSemantics](const APFloat &a, bool &castStatus) {
1428  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1429  if (failed(result)) {
1430  castStatus = false;
1431  return a;
1432  }
1433  return *result;
1434  });
1435 }
1436 
1437 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1438  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1439 }
1440 
1441 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1442 
1443 //===----------------------------------------------------------------------===//
1444 // TruncIOp
1445 //===----------------------------------------------------------------------===//
1446 
/// Folds `arith.trunci`:
///   * trunci(extui(a)) / trunci(extsi(a)): if `a` is still wider than the
///     result, this op is rewritten in place to truncate `a` directly; if `a`
///     already has the result element type, `a` itself is returned.
///   * trunci(trunci(a)): this op is rewritten in place to truncate `a`.
///   * Integer constants are truncated to the result width.
/// In-place rewrites return this op's own result (MLIR's in-place fold
/// convention).
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    // The value that was extended (operand 0 of the extui/extsi op).
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
    // A source narrower than the result falls through to the checks below.
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}

/// trunci requires a strictly narrower integer (element) result type.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}

/// Registers the `arith.trunci` canonicalization patterns.
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
               TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
      context);
}

LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}
1496 
1497 //===----------------------------------------------------------------------===//
1498 // TruncFOp
1499 //===----------------------------------------------------------------------===//
1500 
1501 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1502 /// can be represented without precision loss.
1503 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1504  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1505  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1506  return constFoldCastOp<FloatAttr, FloatAttr>(
1507  adaptor.getOperands(), getType(),
1508  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1509  RoundingMode roundingMode =
1510  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1511  llvm::RoundingMode llvmRoundingMode =
1512  convertArithRoundingModeToLLVMIR(roundingMode);
1513  FailureOr<APFloat> result =
1514  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1515  if (failed(result)) {
1516  castStatus = false;
1517  return a;
1518  }
1519  return *result;
1520  });
1521 }
1522 
1523 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1524  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1525 }
1526 
1527 LogicalResult arith::TruncFOp::verify() {
1528  return verifyTruncateOp<FloatType>(*this);
1529 }
1530 
1531 //===----------------------------------------------------------------------===//
1532 // AndIOp
1533 //===----------------------------------------------------------------------===//
1534 
1535 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1536  MLIRContext *context) {
1537  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1538 }
1539 
1540 //===----------------------------------------------------------------------===//
1541 // OrIOp
1542 //===----------------------------------------------------------------------===//
1543 
1544 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1545  MLIRContext *context) {
1546  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1547 }
1548 
1549 //===----------------------------------------------------------------------===//
1550 // Verifiers for casts between integers and floats.
1551 //===----------------------------------------------------------------------===//
1552 
1553 template <typename From, typename To>
1554 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1555  if (!areValidCastInputsAndOutputs(inputs, outputs))
1556  return false;
1557 
1558  auto srcType = getTypeIfLike<From>(inputs.front());
1559  auto dstType = getTypeIfLike<To>(outputs.back());
1560 
1561  return srcType && dstType;
1562 }
1563 
1564 //===----------------------------------------------------------------------===//
1565 // UIToFPOp
1566 //===----------------------------------------------------------------------===//
1567 
1568 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1569  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1570 }
1571 
1572 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1573  Type resEleType = getElementTypeOrSelf(getType());
1574  return constFoldCastOp<IntegerAttr, FloatAttr>(
1575  adaptor.getOperands(), getType(),
1576  [&resEleType](const APInt &a, bool &castStatus) {
1577  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1578  APFloat apf(floatTy.getFloatSemantics(),
1579  APInt::getZero(floatTy.getWidth()));
1580  apf.convertFromAPInt(a, /*IsSigned=*/false,
1581  APFloat::rmNearestTiesToEven);
1582  return apf;
1583  });
1584 }
1585 
1586 //===----------------------------------------------------------------------===//
1587 // SIToFPOp
1588 //===----------------------------------------------------------------------===//
1589 
1590 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1591  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1592 }
1593 
1594 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1595  Type resEleType = getElementTypeOrSelf(getType());
1596  return constFoldCastOp<IntegerAttr, FloatAttr>(
1597  adaptor.getOperands(), getType(),
1598  [&resEleType](const APInt &a, bool &castStatus) {
1599  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1600  APFloat apf(floatTy.getFloatSemantics(),
1601  APInt::getZero(floatTy.getWidth()));
1602  apf.convertFromAPInt(a, /*IsSigned=*/true,
1603  APFloat::rmNearestTiesToEven);
1604  return apf;
1605  });
1606 }
1607 
1608 //===----------------------------------------------------------------------===//
1609 // FPToUIOp
1610 //===----------------------------------------------------------------------===//
1611 
1612 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1613  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1614 }
1615 
1616 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1617  Type resType = getElementTypeOrSelf(getType());
1618  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1619  return constFoldCastOp<FloatAttr, IntegerAttr>(
1620  adaptor.getOperands(), getType(),
1621  [&bitWidth](const APFloat &a, bool &castStatus) {
1622  bool ignored;
1623  APSInt api(bitWidth, /*isUnsigned=*/true);
1624  castStatus = APFloat::opInvalidOp !=
1625  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1626  return api;
1627  });
1628 }
1629 
1630 //===----------------------------------------------------------------------===//
1631 // FPToSIOp
1632 //===----------------------------------------------------------------------===//
1633 
1634 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1635  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1636 }
1637 
1638 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1639  Type resType = getElementTypeOrSelf(getType());
1640  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1641  return constFoldCastOp<FloatAttr, IntegerAttr>(
1642  adaptor.getOperands(), getType(),
1643  [&bitWidth](const APFloat &a, bool &castStatus) {
1644  bool ignored;
1645  APSInt api(bitWidth, /*isUnsigned=*/false);
1646  castStatus = APFloat::opInvalidOp !=
1647  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1648  return api;
1649  });
1650 }
1651 
1652 //===----------------------------------------------------------------------===//
1653 // IndexCastOp
1654 //===----------------------------------------------------------------------===//
1655 
1656 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1657  if (!areValidCastInputsAndOutputs(inputs, outputs))
1658  return false;
1659 
1660  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1661  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1662  if (!srcType || !dstType)
1663  return false;
1664 
1665  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1666  (srcType.isSignlessInteger() && dstType.isIndex());
1667 }
1668 
1669 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1670  TypeRange outputs) {
1671  return areIndexCastCompatible(inputs, outputs);
1672 }
1673 
1674 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1675  // index_cast(constant) -> constant
1676  unsigned resultBitwidth = 64; // Default for index integer attributes.
1677  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1678  resultBitwidth = intTy.getWidth();
1679 
1680  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1681  adaptor.getOperands(), getType(),
1682  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1683  return a.sextOrTrunc(resultBitwidth);
1684  });
1685 }
1686 
1687 void arith::IndexCastOp::getCanonicalizationPatterns(
1688  RewritePatternSet &patterns, MLIRContext *context) {
1689  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1690 }
1691 
1692 //===----------------------------------------------------------------------===//
1693 // IndexCastUIOp
1694 //===----------------------------------------------------------------------===//
1695 
1696 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1697  TypeRange outputs) {
1698  return areIndexCastCompatible(inputs, outputs);
1699 }
1700 
1701 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1702  // index_castui(constant) -> constant
1703  unsigned resultBitwidth = 64; // Default for index integer attributes.
1704  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1705  resultBitwidth = intTy.getWidth();
1706 
1707  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1708  adaptor.getOperands(), getType(),
1709  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1710  return a.zextOrTrunc(resultBitwidth);
1711  });
1712 }
1713 
1714 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1715  RewritePatternSet &patterns, MLIRContext *context) {
1716  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1717 }
1718 
1719 //===----------------------------------------------------------------------===//
1720 // BitcastOp
1721 //===----------------------------------------------------------------------===//
1722 
1723 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1724  if (!areValidCastInputsAndOutputs(inputs, outputs))
1725  return false;
1726 
1727  auto srcType =
1728  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1729  auto dstType =
1730  getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1731  if (!srcType || !dstType)
1732  return false;
1733 
1734  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1735 }
1736 
1737 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1738  auto resType = getType();
1739  auto operand = adaptor.getIn();
1740  if (!operand)
1741  return {};
1742 
1743  /// Bitcast dense elements.
1744  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1745  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1746  /// Other shaped types unhandled.
1747  if (llvm::isa<ShapedType>(resType))
1748  return {};
1749 
1750  /// Bitcast integer or float to integer or float.
1751  APInt bits = llvm::isa<FloatAttr>(operand)
1752  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1753  : llvm::cast<IntegerAttr>(operand).getValue();
1754  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1755  "trying to fold on broken IR: operands have incompatible types");
1756 
1757  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1758  return FloatAttr::get(resType,
1759  APFloat(resFloatType.getFloatSemantics(), bits));
1760  return IntegerAttr::get(resType, bits);
1761 }
1762 
1763 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1764  MLIRContext *context) {
1765  patterns.add<BitcastOfBitcast>(context);
1766 }
1767 
1768 //===----------------------------------------------------------------------===//
1769 // CmpIOp
1770 //===----------------------------------------------------------------------===//
1771 
1772 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1773 /// comparison predicates.
1774 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1775  const APInt &lhs, const APInt &rhs) {
1776  switch (predicate) {
1777  case arith::CmpIPredicate::eq:
1778  return lhs.eq(rhs);
1779  case arith::CmpIPredicate::ne:
1780  return lhs.ne(rhs);
1781  case arith::CmpIPredicate::slt:
1782  return lhs.slt(rhs);
1783  case arith::CmpIPredicate::sle:
1784  return lhs.sle(rhs);
1785  case arith::CmpIPredicate::sgt:
1786  return lhs.sgt(rhs);
1787  case arith::CmpIPredicate::sge:
1788  return lhs.sge(rhs);
1789  case arith::CmpIPredicate::ult:
1790  return lhs.ult(rhs);
1791  case arith::CmpIPredicate::ule:
1792  return lhs.ule(rhs);
1793  case arith::CmpIPredicate::ugt:
1794  return lhs.ugt(rhs);
1795  case arith::CmpIPredicate::uge:
1796  return lhs.uge(rhs);
1797  }
1798  llvm_unreachable("unknown cmpi predicate kind");
1799 }
1800 
1801 /// Returns true if the predicate is true for two equal operands.
1802 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1803  switch (predicate) {
1804  case arith::CmpIPredicate::eq:
1805  case arith::CmpIPredicate::sle:
1806  case arith::CmpIPredicate::sge:
1807  case arith::CmpIPredicate::ule:
1808  case arith::CmpIPredicate::uge:
1809  return true;
1810  case arith::CmpIPredicate::ne:
1811  case arith::CmpIPredicate::slt:
1812  case arith::CmpIPredicate::sgt:
1813  case arith::CmpIPredicate::ult:
1814  case arith::CmpIPredicate::ugt:
1815  return false;
1816  }
1817  llvm_unreachable("unknown cmpi predicate kind");
1818 }
1819 
1820 static std::optional<int64_t> getIntegerWidth(Type t) {
1821  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1822  return intType.getWidth();
1823  }
1824  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1825  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1826  }
1827  return std::nullopt;
1828 }
1829 
1830 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1831  // cmpi(pred, x, x)
1832  if (getLhs() == getRhs()) {
1833  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1834  return getBoolAttribute(getType(), val);
1835  }
1836 
1837  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1838  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1839  // extsi(%x : i1 -> iN) != 0 -> %x
1840  std::optional<int64_t> integerWidth =
1841  getIntegerWidth(extOp.getOperand().getType());
1842  if (integerWidth && integerWidth.value() == 1 &&
1843  getPredicate() == arith::CmpIPredicate::ne)
1844  return extOp.getOperand();
1845  }
1846  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1847  // extui(%x : i1 -> iN) != 0 -> %x
1848  std::optional<int64_t> integerWidth =
1849  getIntegerWidth(extOp.getOperand().getType());
1850  if (integerWidth && integerWidth.value() == 1 &&
1851  getPredicate() == arith::CmpIPredicate::ne)
1852  return extOp.getOperand();
1853  }
1854  }
1855 
1856  // Move constant to the right side.
1857  if (adaptor.getLhs() && !adaptor.getRhs()) {
1858  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1859  using Pred = CmpIPredicate;
1860  const std::pair<Pred, Pred> invPreds[] = {
1861  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1862  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1863  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1864  {Pred::ne, Pred::ne},
1865  };
1866  Pred origPred = getPredicate();
1867  for (auto pred : invPreds) {
1868  if (origPred == pred.first) {
1869  setPredicate(pred.second);
1870  Value lhs = getLhs();
1871  Value rhs = getRhs();
1872  getLhsMutable().assign(rhs);
1873  getRhsMutable().assign(lhs);
1874  return getResult();
1875  }
1876  }
1877  llvm_unreachable("unknown cmpi predicate kind");
1878  }
1879 
1880  // We are moving constants to the right side; So if lhs is constant rhs is
1881  // guaranteed to be a constant.
1882  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1883  return constFoldBinaryOp<IntegerAttr>(
1884  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1885  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1886  return APInt(1,
1887  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1888  });
1889  }
1890 
1891  return {};
1892 }
1893 
1894 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1895  MLIRContext *context) {
1896  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1897 }
1898 
1899 //===----------------------------------------------------------------------===//
1900 // CmpFOp
1901 //===----------------------------------------------------------------------===//
1902 
1903 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1904 /// comparison predicates.
1905 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1906  const APFloat &lhs, const APFloat &rhs) {
1907  auto cmpResult = lhs.compare(rhs);
1908  switch (predicate) {
1909  case arith::CmpFPredicate::AlwaysFalse:
1910  return false;
1911  case arith::CmpFPredicate::OEQ:
1912  return cmpResult == APFloat::cmpEqual;
1913  case arith::CmpFPredicate::OGT:
1914  return cmpResult == APFloat::cmpGreaterThan;
1915  case arith::CmpFPredicate::OGE:
1916  return cmpResult == APFloat::cmpGreaterThan ||
1917  cmpResult == APFloat::cmpEqual;
1918  case arith::CmpFPredicate::OLT:
1919  return cmpResult == APFloat::cmpLessThan;
1920  case arith::CmpFPredicate::OLE:
1921  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1922  case arith::CmpFPredicate::ONE:
1923  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1924  case arith::CmpFPredicate::ORD:
1925  return cmpResult != APFloat::cmpUnordered;
1926  case arith::CmpFPredicate::UEQ:
1927  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1928  case arith::CmpFPredicate::UGT:
1929  return cmpResult == APFloat::cmpUnordered ||
1930  cmpResult == APFloat::cmpGreaterThan;
1931  case arith::CmpFPredicate::UGE:
1932  return cmpResult == APFloat::cmpUnordered ||
1933  cmpResult == APFloat::cmpGreaterThan ||
1934  cmpResult == APFloat::cmpEqual;
1935  case arith::CmpFPredicate::ULT:
1936  return cmpResult == APFloat::cmpUnordered ||
1937  cmpResult == APFloat::cmpLessThan;
1938  case arith::CmpFPredicate::ULE:
1939  return cmpResult == APFloat::cmpUnordered ||
1940  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1941  case arith::CmpFPredicate::UNE:
1942  return cmpResult != APFloat::cmpEqual;
1943  case arith::CmpFPredicate::UNO:
1944  return cmpResult == APFloat::cmpUnordered;
1945  case arith::CmpFPredicate::AlwaysTrue:
1946  return true;
1947  }
1948  llvm_unreachable("unknown cmpf predicate kind");
1949 }
1950 
1951 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1952  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1953  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1954 
1955  // If one operand is NaN, making them both NaN does not change the result.
1956  if (lhs && lhs.getValue().isNaN())
1957  rhs = lhs;
1958  if (rhs && rhs.getValue().isNaN())
1959  lhs = rhs;
1960 
1961  if (!lhs || !rhs)
1962  return {};
1963 
1964  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1965  return BoolAttr::get(getContext(), val);
1966 }
1967 
1968 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1969 public:
1971 
1972  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1973  bool isUnsigned) {
1974  using namespace arith;
1975  switch (pred) {
1976  case CmpFPredicate::UEQ:
1977  case CmpFPredicate::OEQ:
1978  return CmpIPredicate::eq;
1979  case CmpFPredicate::UGT:
1980  case CmpFPredicate::OGT:
1981  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1982  case CmpFPredicate::UGE:
1983  case CmpFPredicate::OGE:
1984  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1985  case CmpFPredicate::ULT:
1986  case CmpFPredicate::OLT:
1987  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1988  case CmpFPredicate::ULE:
1989  case CmpFPredicate::OLE:
1990  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1991  case CmpFPredicate::UNE:
1992  case CmpFPredicate::ONE:
1993  return CmpIPredicate::ne;
1994  default:
1995  llvm_unreachable("Unexpected predicate!");
1996  }
1997  }
1998 
1999  LogicalResult matchAndRewrite(CmpFOp op,
2000  PatternRewriter &rewriter) const override {
2001  FloatAttr flt;
2002  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2003  return failure();
2004 
2005  const APFloat &rhs = flt.getValue();
2006 
2007  // Don't attempt to fold a nan.
2008  if (rhs.isNaN())
2009  return failure();
2010 
2011  // Get the width of the mantissa. We don't want to hack on conversions that
2012  // might lose information from the integer, e.g. "i64 -> float"
2013  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2014  int mantissaWidth = floatTy.getFPMantissaWidth();
2015  if (mantissaWidth <= 0)
2016  return failure();
2017 
2018  bool isUnsigned;
2019  Value intVal;
2020 
2021  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2022  isUnsigned = false;
2023  intVal = si.getIn();
2024  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2025  isUnsigned = true;
2026  intVal = ui.getIn();
2027  } else {
2028  return failure();
2029  }
2030 
2031  // Check to see that the input is converted from an integer type that is
2032  // small enough that preserves all bits.
2033  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2034  auto intWidth = intTy.getWidth();
2035 
2036  // Number of bits representing values, as opposed to the sign
2037  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2038 
2039  // Following test does NOT adjust intWidth downwards for signed inputs,
2040  // because the most negative value still requires all the mantissa bits
2041  // to distinguish it from one less than that value.
2042  if ((int)intWidth > mantissaWidth) {
2043  // Conversion would lose accuracy. Check if loss can impact comparison.
2044  int exponent = ilogb(rhs);
2045  if (exponent == APFloat::IEK_Inf) {
2046  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2047  if (maxExponent < (int)valueBits) {
2048  // Conversion could create infinity.
2049  return failure();
2050  }
2051  } else {
2052  // Note that if rhs is zero or NaN, then Exp is negative
2053  // and first condition is trivially false.
2054  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2055  // Conversion could affect comparison.
2056  return failure();
2057  }
2058  }
2059  }
2060 
2061  // Convert to equivalent cmpi predicate
2062  CmpIPredicate pred;
2063  switch (op.getPredicate()) {
2064  case CmpFPredicate::ORD:
2065  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2066  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2067  /*width=*/1);
2068  return success();
2069  case CmpFPredicate::UNO:
2070  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2071  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2072  /*width=*/1);
2073  return success();
2074  default:
2075  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2076  break;
2077  }
2078 
2079  if (!isUnsigned) {
2080  // If the rhs value is > SignedMax, fold the comparison. This handles
2081  // +INF and large values.
2082  APFloat signedMax(rhs.getSemantics());
2083  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2084  APFloat::rmNearestTiesToEven);
2085  if (signedMax < rhs) { // smax < 13123.0
2086  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2087  pred == CmpIPredicate::sle)
2088  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2089  /*width=*/1);
2090  else
2091  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2092  /*width=*/1);
2093  return success();
2094  }
2095  } else {
2096  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2097  // +INF and large values.
2098  APFloat unsignedMax(rhs.getSemantics());
2099  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2100  APFloat::rmNearestTiesToEven);
2101  if (unsignedMax < rhs) { // umax < 13123.0
2102  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2103  pred == CmpIPredicate::ule)
2104  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2105  /*width=*/1);
2106  else
2107  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2108  /*width=*/1);
2109  return success();
2110  }
2111  }
2112 
2113  if (!isUnsigned) {
2114  // See if the rhs value is < SignedMin.
2115  APFloat signedMin(rhs.getSemantics());
2116  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2117  APFloat::rmNearestTiesToEven);
2118  if (signedMin > rhs) { // smin > 12312.0
2119  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2120  pred == CmpIPredicate::sge)
2121  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2122  /*width=*/1);
2123  else
2124  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2125  /*width=*/1);
2126  return success();
2127  }
2128  } else {
2129  // See if the rhs value is < UnsignedMin.
2130  APFloat unsignedMin(rhs.getSemantics());
2131  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2132  APFloat::rmNearestTiesToEven);
2133  if (unsignedMin > rhs) { // umin > 12312.0
2134  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2135  pred == CmpIPredicate::uge)
2136  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2137  /*width=*/1);
2138  else
2139  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2140  /*width=*/1);
2141  return success();
2142  }
2143  }
2144 
2145  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2146  // [0, UMAX], but it may still be fractional. See if it is fractional by
2147  // casting the FP value to the integer value and back, checking for
2148  // equality. Don't do this for zero, because -0.0 is not fractional.
2149  bool ignored;
2150  APSInt rhsInt(intWidth, isUnsigned);
2151  if (APFloat::opInvalidOp ==
2152  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2153  // Undefined behavior invoked - the destination type can't represent
2154  // the input constant.
2155  return failure();
2156  }
2157 
2158  if (!rhs.isZero()) {
2159  APFloat apf(floatTy.getFloatSemantics(),
2160  APInt::getZero(floatTy.getWidth()));
2161  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2162 
2163  bool equal = apf == rhs;
2164  if (!equal) {
2165  // If we had a comparison against a fractional value, we have to adjust
2166  // the compare predicate and sometimes the value. rhsInt is rounded
2167  // towards zero at this point.
2168  switch (pred) {
2169  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2170  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2171  /*width=*/1);
2172  return success();
2173  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2174  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2175  /*width=*/1);
2176  return success();
2177  case CmpIPredicate::ule:
2178  // (float)int <= 4.4 --> int <= 4
2179  // (float)int <= -4.4 --> false
2180  if (rhs.isNegative()) {
2181  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2182  /*width=*/1);
2183  return success();
2184  }
2185  break;
2186  case CmpIPredicate::sle:
2187  // (float)int <= 4.4 --> int <= 4
2188  // (float)int <= -4.4 --> int < -4
2189  if (rhs.isNegative())
2190  pred = CmpIPredicate::slt;
2191  break;
2192  case CmpIPredicate::ult:
2193  // (float)int < -4.4 --> false
2194  // (float)int < 4.4 --> int <= 4
2195  if (rhs.isNegative()) {
2196  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2197  /*width=*/1);
2198  return success();
2199  }
2200  pred = CmpIPredicate::ule;
2201  break;
2202  case CmpIPredicate::slt:
2203  // (float)int < -4.4 --> int < -4
2204  // (float)int < 4.4 --> int <= 4
2205  if (!rhs.isNegative())
2206  pred = CmpIPredicate::sle;
2207  break;
2208  case CmpIPredicate::ugt:
2209  // (float)int > 4.4 --> int > 4
2210  // (float)int > -4.4 --> true
2211  if (rhs.isNegative()) {
2212  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2213  /*width=*/1);
2214  return success();
2215  }
2216  break;
2217  case CmpIPredicate::sgt:
2218  // (float)int > 4.4 --> int > 4
2219  // (float)int > -4.4 --> int >= -4
2220  if (rhs.isNegative())
2221  pred = CmpIPredicate::sge;
2222  break;
2223  case CmpIPredicate::uge:
2224  // (float)int >= -4.4 --> true
2225  // (float)int >= 4.4 --> int > 4
2226  if (rhs.isNegative()) {
2227  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2228  /*width=*/1);
2229  return success();
2230  }
2231  pred = CmpIPredicate::ugt;
2232  break;
2233  case CmpIPredicate::sge:
2234  // (float)int >= -4.4 --> int >= -4
2235  // (float)int >= 4.4 --> int > 4
2236  if (!rhs.isNegative())
2237  pred = CmpIPredicate::sgt;
2238  break;
2239  }
2240  }
2241  }
2242 
2243  // Lower this FP comparison into an appropriate integer version of the
2244  // comparison.
2245  rewriter.replaceOpWithNewOp<CmpIOp>(
2246  op, pred, intVal,
2247  rewriter.create<ConstantOp>(
2248  op.getLoc(), intVal.getType(),
2249  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2250  return success();
2251  }
2252 };
2253 
2254 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2255  MLIRContext *context) {
2256  patterns.insert<CmpFIntToFPConst>(context);
2257 }
2258 
2259 //===----------------------------------------------------------------------===//
2260 // SelectOp
2261 //===----------------------------------------------------------------------===//
2262 
2263 // select %arg, %c1, %c0 => extui %arg
2264 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2266 
2267  LogicalResult matchAndRewrite(arith::SelectOp op,
2268  PatternRewriter &rewriter) const override {
2269  // Cannot extui i1 to i1, or i1 to f32
2270  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2271  return failure();
2272 
2273  // select %x, c1, %c0 => extui %arg
2274  if (matchPattern(op.getTrueValue(), m_One()) &&
2275  matchPattern(op.getFalseValue(), m_Zero())) {
2276  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2277  op.getCondition());
2278  return success();
2279  }
2280 
2281  // select %x, c0, %c1 => extui (xor %arg, true)
2282  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2283  matchPattern(op.getFalseValue(), m_One())) {
2284  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2285  op, op.getType(),
2286  rewriter.create<arith::XOrIOp>(
2287  op.getLoc(), op.getCondition(),
2288  rewriter.create<arith::ConstantIntOp>(
2289  op.getLoc(), 1, op.getCondition().getType())));
2290  return success();
2291  }
2292 
2293  return failure();
2294  }
2295 };
2296 
2297 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2298  MLIRContext *context) {
2299  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2300  SelectI1ToNot, SelectToExtUI>(context);
2301 }
2302 
2303 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2304  Value trueVal = getTrueValue();
2305  Value falseVal = getFalseValue();
2306  if (trueVal == falseVal)
2307  return trueVal;
2308 
2309  Value condition = getCondition();
2310 
2311  // select true, %0, %1 => %0
2312  if (matchPattern(adaptor.getCondition(), m_One()))
2313  return trueVal;
2314 
2315  // select false, %0, %1 => %1
2316  if (matchPattern(adaptor.getCondition(), m_Zero()))
2317  return falseVal;
2318 
2319  // If either operand is fully poisoned, return the other.
2320  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2321  return falseVal;
2322 
2323  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2324  return trueVal;
2325 
2326  // select %x, true, false => %x
2327  if (getType().isSignlessInteger(1) &&
2328  matchPattern(adaptor.getTrueValue(), m_One()) &&
2329  matchPattern(adaptor.getFalseValue(), m_Zero()))
2330  return condition;
2331 
2332  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2333  auto pred = cmp.getPredicate();
2334  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2335  auto cmpLhs = cmp.getLhs();
2336  auto cmpRhs = cmp.getRhs();
2337 
2338  // %0 = arith.cmpi eq, %arg0, %arg1
2339  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2340 
2341  // %0 = arith.cmpi ne, %arg0, %arg1
2342  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2343 
2344  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2345  (cmpRhs == trueVal && cmpLhs == falseVal))
2346  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2347  }
2348  }
2349 
2350  // Constant-fold constant operands over non-splat constant condition.
2351  // select %cst_vec, %cst0, %cst1 => %cst2
2352  if (auto cond =
2353  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2354  if (auto lhs =
2355  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2356  if (auto rhs =
2357  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2358  SmallVector<Attribute> results;
2359  results.reserve(static_cast<size_t>(cond.getNumElements()));
2360  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2361  cond.value_end<BoolAttr>());
2362  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2363  lhs.value_end<Attribute>());
2364  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2365  rhs.value_end<Attribute>());
2366 
2367  for (auto [condVal, lhsVal, rhsVal] :
2368  llvm::zip_equal(condVals, lhsVals, rhsVals))
2369  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2370 
2371  return DenseElementsAttr::get(lhs.getType(), results);
2372  }
2373  }
2374  }
2375 
2376  return nullptr;
2377 }
2378 
2379 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2380  Type conditionType, resultType;
2382  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2383  parser.parseOptionalAttrDict(result.attributes) ||
2384  parser.parseColonType(resultType))
2385  return failure();
2386 
2387  // Check for the explicit condition type if this is a masked tensor or vector.
2388  if (succeeded(parser.parseOptionalComma())) {
2389  conditionType = resultType;
2390  if (parser.parseType(resultType))
2391  return failure();
2392  } else {
2393  conditionType = parser.getBuilder().getI1Type();
2394  }
2395 
2396  result.addTypes(resultType);
2397  return parser.resolveOperands(operands,
2398  {conditionType, resultType, resultType},
2399  parser.getNameLoc(), result.operands);
2400 }
2401 
2403  p << " " << getOperands();
2404  p.printOptionalAttrDict((*this)->getAttrs());
2405  p << " : ";
2406  if (ShapedType condType =
2407  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2408  p << condType << ", ";
2409  p << getType();
2410 }
2411 
2412 LogicalResult arith::SelectOp::verify() {
2413  Type conditionType = getCondition().getType();
2414  if (conditionType.isSignlessInteger(1))
2415  return success();
2416 
2417  // If the result type is a vector or tensor, the type can be a mask with the
2418  // same elements.
2419  Type resultType = getType();
2420  if (!llvm::isa<TensorType, VectorType>(resultType))
2421  return emitOpError() << "expected condition to be a signless i1, but got "
2422  << conditionType;
2423  Type shapedConditionType = getI1SameShape(resultType);
2424  if (conditionType != shapedConditionType) {
2425  return emitOpError() << "expected condition type to have the same shape "
2426  "as the result type, expected "
2427  << shapedConditionType << ", but got "
2428  << conditionType;
2429  }
2430  return success();
2431 }
2432 //===----------------------------------------------------------------------===//
2433 // ShLIOp
2434 //===----------------------------------------------------------------------===//
2435 
2436 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2437  // shli(x, 0) -> x
2438  if (matchPattern(adaptor.getRhs(), m_Zero()))
2439  return getLhs();
2440  // Don't fold if shifting more or equal than the bit width.
2441  bool bounded = false;
2442  auto result = constFoldBinaryOp<IntegerAttr>(
2443  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2444  bounded = b.ult(b.getBitWidth());
2445  return a.shl(b);
2446  });
2447  return bounded ? result : Attribute();
2448 }
2449 
2450 //===----------------------------------------------------------------------===//
2451 // ShRUIOp
2452 //===----------------------------------------------------------------------===//
2453 
2454 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2455  // shrui(x, 0) -> x
2456  if (matchPattern(adaptor.getRhs(), m_Zero()))
2457  return getLhs();
2458  // Don't fold if shifting more or equal than the bit width.
2459  bool bounded = false;
2460  auto result = constFoldBinaryOp<IntegerAttr>(
2461  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2462  bounded = b.ult(b.getBitWidth());
2463  return a.lshr(b);
2464  });
2465  return bounded ? result : Attribute();
2466 }
2467 
2468 //===----------------------------------------------------------------------===//
2469 // ShRSIOp
2470 //===----------------------------------------------------------------------===//
2471 
2472 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2473  // shrsi(x, 0) -> x
2474  if (matchPattern(adaptor.getRhs(), m_Zero()))
2475  return getLhs();
2476  // Don't fold if shifting more or equal than the bit width.
2477  bool bounded = false;
2478  auto result = constFoldBinaryOp<IntegerAttr>(
2479  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2480  bounded = b.ult(b.getBitWidth());
2481  return a.ashr(b);
2482  });
2483  return bounded ? result : Attribute();
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // Atomic Enum
2488 //===----------------------------------------------------------------------===//
2489 
2490 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2491 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2492  OpBuilder &builder, Location loc,
2493  bool useOnlyFiniteValue) {
2494  switch (kind) {
2495  case AtomicRMWKind::maximumf: {
2496  const llvm::fltSemantics &semantic =
2497  llvm::cast<FloatType>(resultType).getFloatSemantics();
2498  APFloat identity = useOnlyFiniteValue
2499  ? APFloat::getLargest(semantic, /*Negative=*/true)
2500  : APFloat::getInf(semantic, /*Negative=*/true);
2501  return builder.getFloatAttr(resultType, identity);
2502  }
2503  case AtomicRMWKind::maxnumf: {
2504  const llvm::fltSemantics &semantic =
2505  llvm::cast<FloatType>(resultType).getFloatSemantics();
2506  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2507  return builder.getFloatAttr(resultType, identity);
2508  }
2509  case AtomicRMWKind::addf:
2510  case AtomicRMWKind::addi:
2511  case AtomicRMWKind::maxu:
2512  case AtomicRMWKind::ori:
2513  return builder.getZeroAttr(resultType);
2514  case AtomicRMWKind::andi:
2515  return builder.getIntegerAttr(
2516  resultType,
2517  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2518  case AtomicRMWKind::maxs:
2519  return builder.getIntegerAttr(
2520  resultType, APInt::getSignedMinValue(
2521  llvm::cast<IntegerType>(resultType).getWidth()));
2522  case AtomicRMWKind::minimumf: {
2523  const llvm::fltSemantics &semantic =
2524  llvm::cast<FloatType>(resultType).getFloatSemantics();
2525  APFloat identity = useOnlyFiniteValue
2526  ? APFloat::getLargest(semantic, /*Negative=*/false)
2527  : APFloat::getInf(semantic, /*Negative=*/false);
2528 
2529  return builder.getFloatAttr(resultType, identity);
2530  }
2531  case AtomicRMWKind::minnumf: {
2532  const llvm::fltSemantics &semantic =
2533  llvm::cast<FloatType>(resultType).getFloatSemantics();
2534  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2535  return builder.getFloatAttr(resultType, identity);
2536  }
2537  case AtomicRMWKind::mins:
2538  return builder.getIntegerAttr(
2539  resultType, APInt::getSignedMaxValue(
2540  llvm::cast<IntegerType>(resultType).getWidth()));
2541  case AtomicRMWKind::minu:
2542  return builder.getIntegerAttr(
2543  resultType,
2544  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2545  case AtomicRMWKind::muli:
2546  return builder.getIntegerAttr(resultType, 1);
2547  case AtomicRMWKind::mulf:
2548  return builder.getFloatAttr(resultType, 1);
2549  // TODO: Add remaining reduction operations.
2550  default:
2551  (void)emitOptionalError(loc, "Reduction operation type not supported");
2552  break;
2553  }
2554  return nullptr;
2555 }
2556 
2557 /// Return the identity numeric value associated to the give op.
2558 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2559  std::optional<AtomicRMWKind> maybeKind =
2561  // Floating-point operations.
2562  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2563  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2564  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2565  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2566  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2567  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2568  // Integer operations.
2569  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2570  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2571  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2572  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2573  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2574  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2575  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2576  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2577  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2578  .Default([](Operation *op) { return std::nullopt; });
2579  if (!maybeKind) {
2580  return std::nullopt;
2581  }
2582 
2583  bool useOnlyFiniteValue = false;
2584  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2585  if (fmfOpInterface) {
2586  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2587  useOnlyFiniteValue =
2588  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2589  }
2590 
2591  // Builder only used as helper for attribute creation.
2592  OpBuilder b(op->getContext());
2593  Type resultType = op->getResult(0).getType();
2594 
2595  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2596  useOnlyFiniteValue);
2597 }
2598 
2599 /// Returns the identity value associated with an AtomicRMWKind op.
2600 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2601  OpBuilder &builder, Location loc,
2602  bool useOnlyFiniteValue) {
2603  auto attr =
2604  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2605  return builder.create<arith::ConstantOp>(loc, attr);
2606 }
2607 
2608 /// Return the value obtained by applying the reduction operation kind
2609 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2610 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2611  Location loc, Value lhs, Value rhs) {
2612  switch (op) {
2613  case AtomicRMWKind::addf:
2614  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2615  case AtomicRMWKind::addi:
2616  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2617  case AtomicRMWKind::mulf:
2618  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2619  case AtomicRMWKind::muli:
2620  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2621  case AtomicRMWKind::maximumf:
2622  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2623  case AtomicRMWKind::minimumf:
2624  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2625  case AtomicRMWKind::maxnumf:
2626  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2627  case AtomicRMWKind::minnumf:
2628  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2629  case AtomicRMWKind::maxs:
2630  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2631  case AtomicRMWKind::mins:
2632  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2633  case AtomicRMWKind::maxu:
2634  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2635  case AtomicRMWKind::minu:
2636  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2637  case AtomicRMWKind::ori:
2638  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2639  case AtomicRMWKind::andi:
2640  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2641  // TODO: Add remaining reduction operations.
2642  default:
2643  (void)emitOptionalError(loc, "Reduction operation type not supported");
2644  break;
2645  }
2646  return nullptr;
2647 }
2648 
2649 //===----------------------------------------------------------------------===//
2650 // TableGen'd op method definitions
2651 //===----------------------------------------------------------------------===//
2652 
2653 #define GET_OP_CLASSES
2654 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2655 
2656 //===----------------------------------------------------------------------===//
2657 // TableGen'd enum attribute definitions
2658 //===----------------------------------------------------------------------===//
2659 
2660 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:603
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1315
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:62
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:69
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:109
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1251
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1802
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensors with mismatching encodings.
Definition: ArithOps.cpp:1265
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1237
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:659
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:641
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:142
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:52
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:150
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1330
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1656
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1554
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1287
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:57
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:130
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:840
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1258
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:171
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1820
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1273
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1231
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:43
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:345
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1301
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:1999
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:1972
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getFPMantissaWidth()
Return the width of the mantissa of this type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:937
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:75
bool isIndex() const
Definition: Types.cpp:64
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:278
static bool classof(Operation *op)
Definition: ArithOps.cpp:284
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:290
static bool classof(Operation *op)
Definition: ArithOps.cpp:296
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:257
static bool classof(Operation *op)
Definition: ArithOps.cpp:272
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2558
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1774
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2491
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2610
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2600
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
auto m_Val(Value v)
Definition: Matchers.h:545
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:496
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:533
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition: Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:477
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:448
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:484
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:468
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:427
detail::constant_float_predicate_matcher m_isDenormalFloat()
Matches a constant scalar / vector splat / tensor splat with denormal values.
Definition: Matchers.h:443
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition: Matchers.h:461
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2267
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes