MLIR  22.0.0git
ArithOps.cpp
Go to the documentation of this file.
1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 using namespace mlir;
35 using namespace mlir::arith;
36 
37 //===----------------------------------------------------------------------===//
38 // Pattern helpers
39 //===----------------------------------------------------------------------===//
40 
41 static IntegerAttr
43  Attribute rhs,
44  function_ref<APInt(const APInt &, const APInt &)> binFn) {
45  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47  APInt value = binFn(lhsVal, rhsVal);
48  return IntegerAttr::get(res.getType(), value);
49 }
50 
51 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
52  Attribute lhs, Attribute rhs) {
53  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54 }
55 
56 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
57  Attribute lhs, Attribute rhs) {
58  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59 }
60 
61 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
62  Attribute lhs, Attribute rhs) {
63  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64 }
65 
66 // Merge overflow flags from 2 ops, selecting the most conservative combination.
67 static IntegerOverflowFlagsAttr
68 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69  IntegerOverflowFlagsAttr val2) {
70  return IntegerOverflowFlagsAttr::get(val1.getContext(),
71  val1.getValue() & val2.getValue());
72 }
73 
74 /// Invert an integer comparison predicate.
75 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76  switch (pred) {
77  case arith::CmpIPredicate::eq:
78  return arith::CmpIPredicate::ne;
79  case arith::CmpIPredicate::ne:
80  return arith::CmpIPredicate::eq;
81  case arith::CmpIPredicate::slt:
82  return arith::CmpIPredicate::sge;
83  case arith::CmpIPredicate::sle:
84  return arith::CmpIPredicate::sgt;
85  case arith::CmpIPredicate::sgt:
86  return arith::CmpIPredicate::sle;
87  case arith::CmpIPredicate::sge:
88  return arith::CmpIPredicate::slt;
89  case arith::CmpIPredicate::ult:
90  return arith::CmpIPredicate::uge;
91  case arith::CmpIPredicate::ule:
92  return arith::CmpIPredicate::ugt;
93  case arith::CmpIPredicate::ugt:
94  return arith::CmpIPredicate::ule;
95  case arith::CmpIPredicate::uge:
96  return arith::CmpIPredicate::ult;
97  }
98  llvm_unreachable("unknown cmpi predicate kind");
99 }
100 
101 /// Equivalent to
102 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103 ///
104 /// Not possible to implement as chain of calls as this would introduce a
105 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106 /// on the LLVM dialect and on translation to LLVM.
107 static llvm::RoundingMode
108 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109  switch (roundingMode) {
110  case RoundingMode::downward:
111  return llvm::RoundingMode::TowardNegative;
112  case RoundingMode::to_nearest_away:
113  return llvm::RoundingMode::NearestTiesToAway;
114  case RoundingMode::to_nearest_even:
115  return llvm::RoundingMode::NearestTiesToEven;
116  case RoundingMode::toward_zero:
117  return llvm::RoundingMode::TowardZero;
118  case RoundingMode::upward:
119  return llvm::RoundingMode::TowardPositive;
120  }
121  llvm_unreachable("Unhandled rounding mode");
122 }
123 
124 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125  return arith::CmpIPredicateAttr::get(pred.getContext(),
126  invertPredicate(pred.getValue()));
127 }
128 
129 static int64_t getScalarOrElementWidth(Type type) {
130  Type elemTy = getElementTypeOrSelf(type);
131  if (elemTy.isIntOrFloat())
132  return elemTy.getIntOrFloatBitWidth();
133 
134  return -1;
135 }
136 
137 static int64_t getScalarOrElementWidth(Value value) {
138  return getScalarOrElementWidth(value.getType());
139 }
140 
141 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142  APInt value;
143  if (matchPattern(attr, m_ConstantInt(&value)))
144  return value;
145 
146  return failure();
147 }
148 
149 static Attribute getBoolAttribute(Type type, bool value) {
150  auto boolAttr = BoolAttr::get(type.getContext(), value);
151  ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152  if (!shapedType)
153  return boolAttr;
154  return DenseElementsAttr::get(shapedType, boolAttr);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TableGen'd canonicalization patterns
159 //===----------------------------------------------------------------------===//
160 
161 namespace {
162 #include "ArithCanonicalization.inc"
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // Common helpers
167 //===----------------------------------------------------------------------===//
168 
169 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
170 static Type getI1SameShape(Type type) {
171  auto i1Type = IntegerType::get(type.getContext(), 1);
172  if (auto shapedType = dyn_cast<ShapedType>(type))
173  return shapedType.cloneWith(std::nullopt, i1Type);
174  if (llvm::isa<UnrankedTensorType>(type))
175  return UnrankedTensorType::get(i1Type);
176  return i1Type;
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // ConstantOp
181 //===----------------------------------------------------------------------===//
182 
183 void arith::ConstantOp::getAsmResultNames(
184  function_ref<void(Value, StringRef)> setNameFn) {
185  auto type = getType();
186  if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
187  auto intType = dyn_cast<IntegerType>(type);
188 
189  // Sugar i1 constants with 'true' and 'false'.
190  if (intType && intType.getWidth() == 1)
191  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
192 
193  // Otherwise, build a complex name with the value and type.
194  SmallString<32> specialNameBuffer;
195  llvm::raw_svector_ostream specialName(specialNameBuffer);
196  specialName << 'c' << intCst.getValue();
197  if (intType)
198  specialName << '_' << type;
199  setNameFn(getResult(), specialName.str());
200  } else {
201  setNameFn(getResult(), "cst");
202  }
203 }
204 
205 /// TODO: disallow arith.constant to return anything other than signless integer
206 /// or float like.
207 LogicalResult arith::ConstantOp::verify() {
208  auto type = getType();
209  // Integer values must be signless.
210  if (llvm::isa<IntegerType>(type) &&
211  !llvm::cast<IntegerType>(type).isSignless())
212  return emitOpError("integer return type must be signless");
213  // Any float or elements attribute are acceptable.
214  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
215  return emitOpError(
216  "value must be an integer, float, or elements attribute");
217  }
218 
219  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
220  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
221  // However, this would most likely require updating the lowerings to LLVM.
222  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
223  return emitOpError(
224  "intializing scalable vectors with elements attribute is not supported"
225  " unless it's a vector splat");
226  return success();
227 }
228 
229 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
230  // The value's type must be the same as the provided type.
231  auto typedAttr = dyn_cast<TypedAttr>(value);
232  if (!typedAttr || typedAttr.getType() != type)
233  return false;
234  // Integer values must be signless.
235  if (llvm::isa<IntegerType>(type) &&
236  !llvm::cast<IntegerType>(type).isSignless())
237  return false;
238  // Integer, float, and element attributes are buildable.
239  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
240 }
241 
242 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
243  Type type, Location loc) {
244  if (isBuildableWith(value, type))
245  return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
246  return nullptr;
247 }
248 
249 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
250 
252  int64_t value, unsigned width) {
253  auto type = builder.getIntegerType(width);
254  arith::ConstantOp::build(builder, result, type,
255  builder.getIntegerAttr(type, value));
256 }
257 
259  Location location,
260  int64_t value,
261  unsigned width) {
262  mlir::OperationState state(location, getOperationName());
263  build(builder, state, value, width);
264  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
265  assert(result && "builder didn't return the right type");
266  return result;
267 }
268 
270  int64_t value,
271  unsigned width) {
272  return create(builder, builder.getLoc(), value, width);
273 }
274 
276  Type type, int64_t value) {
277  arith::ConstantOp::build(builder, result, type,
278  builder.getIntegerAttr(type, value));
279 }
280 
282  Location location, Type type,
283  int64_t value) {
284  mlir::OperationState state(location, getOperationName());
285  build(builder, state, type, value);
286  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
287  assert(result && "builder didn't return the right type");
288  return result;
289 }
290 
292  Type type, int64_t value) {
293  return create(builder, builder.getLoc(), type, value);
294 }
295 
297  Type type, const APInt &value) {
298  arith::ConstantOp::build(builder, result, type,
299  builder.getIntegerAttr(type, value));
300 }
301 
303  Location location, Type type,
304  const APInt &value) {
305  mlir::OperationState state(location, getOperationName());
306  build(builder, state, type, value);
307  auto result = dyn_cast<ConstantIntOp>(builder.create(state));
308  assert(result && "builder didn't return the right type");
309  return result;
310 }
311 
313  Type type,
314  const APInt &value) {
315  return create(builder, builder.getLoc(), type, value);
316 }
317 
319  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
320  return constOp.getType().isSignlessInteger();
321  return false;
322 }
323 
325  FloatType type, const APFloat &value) {
326  arith::ConstantOp::build(builder, result, type,
327  builder.getFloatAttr(type, value));
328 }
329 
331  Location location,
332  FloatType type,
333  const APFloat &value) {
334  mlir::OperationState state(location, getOperationName());
335  build(builder, state, type, value);
336  auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
337  assert(result && "builder didn't return the right type");
338  return result;
339 }
340 
343  const APFloat &value) {
344  return create(builder, builder.getLoc(), type, value);
345 }
346 
348  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
349  return llvm::isa<FloatType>(constOp.getType());
350  return false;
351 }
352 
354  int64_t value) {
355  arith::ConstantOp::build(builder, result, builder.getIndexType(),
356  builder.getIndexAttr(value));
357 }
358 
360  Location location,
361  int64_t value) {
362  mlir::OperationState state(location, getOperationName());
363  build(builder, state, value);
364  auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
365  assert(result && "builder didn't return the right type");
366  return result;
367 }
368 
371  return create(builder, builder.getLoc(), value);
372 }
373 
375  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
376  return constOp.getType().isIndex();
377  return false;
378 }
379 
381  Type type) {
382  // TODO: Incorporate this check to `FloatAttr::get*`.
383  assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
384  "type doesn't have a zero representation");
385  TypedAttr zeroAttr = builder.getZeroAttr(type);
386  assert(zeroAttr && "unsupported type for zero attribute");
387  return arith::ConstantOp::create(builder, loc, zeroAttr);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // AddIOp
392 //===----------------------------------------------------------------------===//
393 
394 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
395  // addi(x, 0) -> x
396  if (matchPattern(adaptor.getRhs(), m_Zero()))
397  return getLhs();
398 
399  // addi(subi(a, b), b) -> a
400  if (auto sub = getLhs().getDefiningOp<SubIOp>())
401  if (getRhs() == sub.getRhs())
402  return sub.getLhs();
403 
404  // addi(b, subi(a, b)) -> a
405  if (auto sub = getRhs().getDefiningOp<SubIOp>())
406  if (getLhs() == sub.getRhs())
407  return sub.getLhs();
408 
409  return constFoldBinaryOp<IntegerAttr>(
410  adaptor.getOperands(),
411  [](APInt a, const APInt &b) { return std::move(a) + b; });
412 }
413 
414 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
415  MLIRContext *context) {
416  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
417  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // AddUIExtendedOp
422 //===----------------------------------------------------------------------===//
423 
424 std::optional<SmallVector<int64_t, 4>>
425 arith::AddUIExtendedOp::getShapeForUnroll() {
426  if (auto vt = dyn_cast<VectorType>(getType(0)))
427  return llvm::to_vector<4>(vt.getShape());
428  return std::nullopt;
429 }
430 
431 // Returns the overflow bit, assuming that `sum` is the result of unsigned
432 // addition of `operand` and another number.
433 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
434  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
435 }
436 
437 LogicalResult
438 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
440  Type overflowTy = getOverflow().getType();
441  // addui_extended(x, 0) -> x, false
442  if (matchPattern(getRhs(), m_Zero())) {
443  Builder builder(getContext());
444  auto falseValue = builder.getZeroAttr(overflowTy);
445 
446  results.push_back(getLhs());
447  results.push_back(falseValue);
448  return success();
449  }
450 
451  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
452  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
453  // operands. If that succeeds, calculate the overflow bit based on the sum
454  // and the first (constant) operand, `lhs`.
455  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
456  adaptor.getOperands(),
457  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
458  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
459  ArrayRef({sumAttr, adaptor.getLhs()}),
460  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
462  if (!overflowAttr)
463  return failure();
464 
465  results.push_back(sumAttr);
466  results.push_back(overflowAttr);
467  return success();
468  }
469 
470  return failure();
471 }
472 
473 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
475  patterns.add<AddUIExtendedToAddI>(context);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // SubIOp
480 //===----------------------------------------------------------------------===//
481 
482 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
483  // subi(x,x) -> 0
484  if (getOperand(0) == getOperand(1)) {
485  auto shapedType = dyn_cast<ShapedType>(getType());
486  // We can't generate a constant with a dynamic shaped tensor.
487  if (!shapedType || shapedType.hasStaticShape())
488  return Builder(getContext()).getZeroAttr(getType());
489  }
490  // subi(x,0) -> x
491  if (matchPattern(adaptor.getRhs(), m_Zero()))
492  return getLhs();
493 
494  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
495  // subi(addi(a, b), b) -> a
496  if (getRhs() == add.getRhs())
497  return add.getLhs();
498  // subi(addi(a, b), a) -> b
499  if (getRhs() == add.getLhs())
500  return add.getRhs();
501  }
502 
503  return constFoldBinaryOp<IntegerAttr>(
504  adaptor.getOperands(),
505  [](APInt a, const APInt &b) { return std::move(a) - b; });
506 }
507 
508 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
509  MLIRContext *context) {
510  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
511  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
512  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // MulIOp
517 //===----------------------------------------------------------------------===//
518 
519 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
520  // muli(x, 0) -> 0
521  if (matchPattern(adaptor.getRhs(), m_Zero()))
522  return getRhs();
523  // muli(x, 1) -> x
524  if (matchPattern(adaptor.getRhs(), m_One()))
525  return getLhs();
526  // TODO: Handle the overflow case.
527 
528  // default folder
529  return constFoldBinaryOp<IntegerAttr>(
530  adaptor.getOperands(),
531  [](const APInt &a, const APInt &b) { return a * b; });
532 }
533 
534 void arith::MulIOp::getAsmResultNames(
535  function_ref<void(Value, StringRef)> setNameFn) {
536  if (!isa<IndexType>(getType()))
537  return;
538 
539  // Match vector.vscale by name to avoid depending on the vector dialect (which
540  // is a circular dependency).
541  auto isVscale = [](Operation *op) {
542  return op && op->getName().getStringRef() == "vector.vscale";
543  };
544 
545  IntegerAttr baseValue;
546  auto isVscaleExpr = [&](Value a, Value b) {
547  return matchPattern(a, m_Constant(&baseValue)) &&
548  isVscale(b.getDefiningOp());
549  };
550 
551  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
552  return;
553 
554  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
555  SmallString<32> specialNameBuffer;
556  llvm::raw_svector_ostream specialName(specialNameBuffer);
557  specialName << 'c' << baseValue.getInt() << "_vscale";
558  setNameFn(getResult(), specialName.str());
559 }
560 
561 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
562  MLIRContext *context) {
563  patterns.add<MulIMulIConstant>(context);
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // MulSIExtendedOp
568 //===----------------------------------------------------------------------===//
569 
570 std::optional<SmallVector<int64_t, 4>>
571 arith::MulSIExtendedOp::getShapeForUnroll() {
572  if (auto vt = dyn_cast<VectorType>(getType(0)))
573  return llvm::to_vector<4>(vt.getShape());
574  return std::nullopt;
575 }
576 
577 LogicalResult
578 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
580  // mulsi_extended(x, 0) -> 0, 0
581  if (matchPattern(adaptor.getRhs(), m_Zero())) {
582  Attribute zero = adaptor.getRhs();
583  results.push_back(zero);
584  results.push_back(zero);
585  return success();
586  }
587 
588  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
589  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
590  adaptor.getOperands(),
591  [](const APInt &a, const APInt &b) { return a * b; })) {
592  // Invoke the constant fold helper again to calculate the 'high' result.
593  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
594  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
595  return llvm::APIntOps::mulhs(a, b);
596  });
597  assert(highAttr && "Unexpected constant-folding failure");
598 
599  results.push_back(lowAttr);
600  results.push_back(highAttr);
601  return success();
602  }
603 
604  return failure();
605 }
606 
607 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
609  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // MulUIExtendedOp
614 //===----------------------------------------------------------------------===//
615 
616 std::optional<SmallVector<int64_t, 4>>
617 arith::MulUIExtendedOp::getShapeForUnroll() {
618  if (auto vt = dyn_cast<VectorType>(getType(0)))
619  return llvm::to_vector<4>(vt.getShape());
620  return std::nullopt;
621 }
622 
623 LogicalResult
624 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
626  // mului_extended(x, 0) -> 0, 0
627  if (matchPattern(adaptor.getRhs(), m_Zero())) {
628  Attribute zero = adaptor.getRhs();
629  results.push_back(zero);
630  results.push_back(zero);
631  return success();
632  }
633 
634  // mului_extended(x, 1) -> x, 0
635  if (matchPattern(adaptor.getRhs(), m_One())) {
636  Builder builder(getContext());
637  Attribute zero = builder.getZeroAttr(getLhs().getType());
638  results.push_back(getLhs());
639  results.push_back(zero);
640  return success();
641  }
642 
643  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
644  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
645  adaptor.getOperands(),
646  [](const APInt &a, const APInt &b) { return a * b; })) {
647  // Invoke the constant fold helper again to calculate the 'high' result.
648  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
649  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
650  return llvm::APIntOps::mulhu(a, b);
651  });
652  assert(highAttr && "Unexpected constant-folding failure");
653 
654  results.push_back(lowAttr);
655  results.push_back(highAttr);
656  return success();
657  }
658 
659  return failure();
660 }
661 
662 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
664  patterns.add<MulUIExtendedToMulI>(context);
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // DivUIOp
669 //===----------------------------------------------------------------------===//
670 
671 /// Fold `(a * b) / b -> a`
672 static Value foldDivMul(Value lhs, Value rhs,
673  arith::IntegerOverflowFlags ovfFlags) {
674  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
675  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
676  return {};
677 
678  if (mul.getLhs() == rhs)
679  return mul.getRhs();
680 
681  if (mul.getRhs() == rhs)
682  return mul.getLhs();
683 
684  return {};
685 }
686 
687 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
688  // divui (x, 1) -> x.
689  if (matchPattern(adaptor.getRhs(), m_One()))
690  return getLhs();
691 
692  // (a * b) / b -> a
693  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
694  return val;
695 
696  // Don't fold if it would require a division by zero.
697  bool div0 = false;
698  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
699  [&](APInt a, const APInt &b) {
700  if (div0 || !b) {
701  div0 = true;
702  return a;
703  }
704  return a.udiv(b);
705  });
706 
707  return div0 ? Attribute() : result;
708 }
709 
710 /// Returns whether an unsigned division by `divisor` is speculatable.
712  // X / 0 => UB
713  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
715 
717 }
718 
719 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
720  return getDivUISpeculatability(getRhs());
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // DivSIOp
725 //===----------------------------------------------------------------------===//
726 
727 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
728  // divsi (x, 1) -> x.
729  if (matchPattern(adaptor.getRhs(), m_One()))
730  return getLhs();
731 
732  // (a * b) / b -> a
733  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
734  return val;
735 
736  // Don't fold if it would overflow or if it requires a division by zero.
737  bool overflowOrDiv0 = false;
738  auto result = constFoldBinaryOp<IntegerAttr>(
739  adaptor.getOperands(), [&](APInt a, const APInt &b) {
740  if (overflowOrDiv0 || !b) {
741  overflowOrDiv0 = true;
742  return a;
743  }
744  return a.sdiv_ov(b, overflowOrDiv0);
745  });
746 
747  return overflowOrDiv0 ? Attribute() : result;
748 }
749 
750 /// Returns whether a signed division by `divisor` is speculatable. This
751 /// function conservatively assumes that all signed division by -1 are not
752 /// speculatable.
754  // X / 0 => UB
755  // INT_MIN / -1 => UB
756  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
759 
761 }
762 
763 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
764  return getDivSISpeculatability(getRhs());
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // Ceil and floor division folding helpers
769 //===----------------------------------------------------------------------===//
770 
771 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
772  bool &overflow) {
773  // Returns (a-1)/b + 1
774  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
775  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
776  return val.sadd_ov(one, overflow);
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // CeilDivUIOp
781 //===----------------------------------------------------------------------===//
782 
783 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
784  // ceildivui (x, 1) -> x.
785  if (matchPattern(adaptor.getRhs(), m_One()))
786  return getLhs();
787 
788  bool overflowOrDiv0 = false;
789  auto result = constFoldBinaryOp<IntegerAttr>(
790  adaptor.getOperands(), [&](APInt a, const APInt &b) {
791  if (overflowOrDiv0 || !b) {
792  overflowOrDiv0 = true;
793  return a;
794  }
795  APInt quotient = a.udiv(b);
796  if (!a.urem(b))
797  return quotient;
798  APInt one(a.getBitWidth(), 1, true);
799  return quotient.uadd_ov(one, overflowOrDiv0);
800  });
801 
802  return overflowOrDiv0 ? Attribute() : result;
803 }
804 
805 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
806  return getDivUISpeculatability(getRhs());
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // CeilDivSIOp
811 //===----------------------------------------------------------------------===//
812 
813 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
814  // ceildivsi (x, 1) -> x.
815  if (matchPattern(adaptor.getRhs(), m_One()))
816  return getLhs();
817 
818  // Don't fold if it would overflow or if it requires a division by zero.
819  // TODO: This hook won't fold operations where a = MININT, because
820  // negating MININT overflows. This can be improved.
821  bool overflowOrDiv0 = false;
822  auto result = constFoldBinaryOp<IntegerAttr>(
823  adaptor.getOperands(), [&](APInt a, const APInt &b) {
824  if (overflowOrDiv0 || !b) {
825  overflowOrDiv0 = true;
826  return a;
827  }
828  if (!a)
829  return a;
830  // After this point we know that neither a or b are zero.
831  unsigned bits = a.getBitWidth();
832  APInt zero = APInt::getZero(bits);
833  bool aGtZero = a.sgt(zero);
834  bool bGtZero = b.sgt(zero);
835  if (aGtZero && bGtZero) {
836  // Both positive, return ceil(a, b).
837  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
838  }
839 
840  // No folding happens if any of the intermediate arithmetic operations
841  // overflows.
842  bool overflowNegA = false;
843  bool overflowNegB = false;
844  bool overflowDiv = false;
845  bool overflowNegRes = false;
846  if (!aGtZero && !bGtZero) {
847  // Both negative, return ceil(-a, -b).
848  APInt posA = zero.ssub_ov(a, overflowNegA);
849  APInt posB = zero.ssub_ov(b, overflowNegB);
850  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
851  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
852  return res;
853  }
854  if (!aGtZero && bGtZero) {
855  // A is negative, b is positive, return - ( -a / b).
856  APInt posA = zero.ssub_ov(a, overflowNegA);
857  APInt div = posA.sdiv_ov(b, overflowDiv);
858  APInt res = zero.ssub_ov(div, overflowNegRes);
859  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
860  return res;
861  }
862  // A is positive, b is negative, return - (a / -b).
863  APInt posB = zero.ssub_ov(b, overflowNegB);
864  APInt div = a.sdiv_ov(posB, overflowDiv);
865  APInt res = zero.ssub_ov(div, overflowNegRes);
866 
867  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
868  return res;
869  });
870 
871  return overflowOrDiv0 ? Attribute() : result;
872 }
873 
874 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
875  return getDivSISpeculatability(getRhs());
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // FloorDivSIOp
880 //===----------------------------------------------------------------------===//
881 
882 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
883  // floordivsi (x, 1) -> x.
884  if (matchPattern(adaptor.getRhs(), m_One()))
885  return getLhs();
886 
887  // Don't fold if it would overflow or if it requires a division by zero.
888  bool overflowOrDiv = false;
889  auto result = constFoldBinaryOp<IntegerAttr>(
890  adaptor.getOperands(), [&](APInt a, const APInt &b) {
891  if (b.isZero()) {
892  overflowOrDiv = true;
893  return a;
894  }
895  return a.sfloordiv_ov(b, overflowOrDiv);
896  });
897 
898  return overflowOrDiv ? Attribute() : result;
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // RemUIOp
903 //===----------------------------------------------------------------------===//
904 
905 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
906  // remui (x, 1) -> 0.
907  if (matchPattern(adaptor.getRhs(), m_One()))
908  return Builder(getContext()).getZeroAttr(getType());
909 
910  // Don't fold if it would require a division by zero.
911  bool div0 = false;
912  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
913  [&](APInt a, const APInt &b) {
914  if (div0 || b.isZero()) {
915  div0 = true;
916  return a;
917  }
918  return a.urem(b);
919  });
920 
921  return div0 ? Attribute() : result;
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // RemSIOp
926 //===----------------------------------------------------------------------===//
927 
928 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
929  // remsi (x, 1) -> 0.
930  if (matchPattern(adaptor.getRhs(), m_One()))
931  return Builder(getContext()).getZeroAttr(getType());
932 
933  // Don't fold if it would require a division by zero.
934  bool div0 = false;
935  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
936  [&](APInt a, const APInt &b) {
937  if (div0 || b.isZero()) {
938  div0 = true;
939  return a;
940  }
941  return a.srem(b);
942  });
943 
944  return div0 ? Attribute() : result;
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // AndIOp
949 //===----------------------------------------------------------------------===//
950 
951 /// Fold `and(a, and(a, b))` to `and(a, b)`
952 static Value foldAndIofAndI(arith::AndIOp op) {
953  for (bool reversePrev : {false, true}) {
954  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
955  .getDefiningOp<arith::AndIOp>();
956  if (!prev)
957  continue;
958 
959  Value other = (reversePrev ? op.getLhs() : op.getRhs());
960  if (other != prev.getLhs() && other != prev.getRhs())
961  continue;
962 
963  return prev.getResult();
964  }
965  return {};
966 }
967 
968 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
969  /// and(x, 0) -> 0
970  if (matchPattern(adaptor.getRhs(), m_Zero()))
971  return getRhs();
972  /// and(x, allOnes) -> x
973  APInt intValue;
974  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
975  intValue.isAllOnes())
976  return getLhs();
977  /// and(x, not(x)) -> 0
978  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
979  m_ConstantInt(&intValue))) &&
980  intValue.isAllOnes())
981  return Builder(getContext()).getZeroAttr(getType());
982  /// and(not(x), x) -> 0
983  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
984  m_ConstantInt(&intValue))) &&
985  intValue.isAllOnes())
986  return Builder(getContext()).getZeroAttr(getType());
987 
988  /// and(a, and(a, b)) -> and(a, b)
989  if (Value result = foldAndIofAndI(*this))
990  return result;
991 
992  return constFoldBinaryOp<IntegerAttr>(
993  adaptor.getOperands(),
994  [](APInt a, const APInt &b) { return std::move(a) & b; });
995 }
996 
997 //===----------------------------------------------------------------------===//
998 // OrIOp
999 //===----------------------------------------------------------------------===//
1000 
1001 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1002  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1003  /// or(x, 0) -> x
1004  if (rhsVal.isZero())
1005  return getLhs();
1006  /// or(x, <all ones>) -> <all ones>
1007  if (rhsVal.isAllOnes())
1008  return adaptor.getRhs();
1009  }
1010 
1011  APInt intValue;
1012  /// or(x, xor(x, 1)) -> 1
1013  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1014  m_ConstantInt(&intValue))) &&
1015  intValue.isAllOnes())
1016  return getRhs().getDefiningOp<XOrIOp>().getRhs();
1017  /// or(xor(x, 1), x) -> 1
1018  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1019  m_ConstantInt(&intValue))) &&
1020  intValue.isAllOnes())
1021  return getLhs().getDefiningOp<XOrIOp>().getRhs();
1022 
1023  return constFoldBinaryOp<IntegerAttr>(
1024  adaptor.getOperands(),
1025  [](APInt a, const APInt &b) { return std::move(a) | b; });
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // XOrIOp
1030 //===----------------------------------------------------------------------===//
1031 
1032 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1033  /// xor(x, 0) -> x
1034  if (matchPattern(adaptor.getRhs(), m_Zero()))
1035  return getLhs();
1036  /// xor(x, x) -> 0
1037  if (getLhs() == getRhs())
1038  return Builder(getContext()).getZeroAttr(getType());
1039  /// xor(xor(x, a), a) -> x
1040  /// xor(xor(a, x), a) -> x
1041  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1042  if (prev.getRhs() == getRhs())
1043  return prev.getLhs();
1044  if (prev.getLhs() == getRhs())
1045  return prev.getRhs();
1046  }
1047  /// xor(a, xor(x, a)) -> x
1048  /// xor(a, xor(a, x)) -> x
1049  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1050  if (prev.getRhs() == getLhs())
1051  return prev.getLhs();
1052  if (prev.getLhs() == getLhs())
1053  return prev.getRhs();
1054  }
1055 
1056  return constFoldBinaryOp<IntegerAttr>(
1057  adaptor.getOperands(),
1058  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1059 }
1060 
1061 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1062  MLIRContext *context) {
1063  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1064 }
1065 
1066 //===----------------------------------------------------------------------===//
1067 // NegFOp
1068 //===----------------------------------------------------------------------===//
1069 
1070 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1071  /// negf(negf(x)) -> x
1072  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1073  return op.getOperand();
1074  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1075  [](const APFloat &a) { return -a; });
1076 }
1077 
1078 //===----------------------------------------------------------------------===//
1079 // AddFOp
1080 //===----------------------------------------------------------------------===//
1081 
1082 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1083  // addf(x, -0) -> x
1084  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1085  return getLhs();
1086 
1087  return constFoldBinaryOp<FloatAttr>(
1088  adaptor.getOperands(),
1089  [](const APFloat &a, const APFloat &b) { return a + b; });
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // SubFOp
1094 //===----------------------------------------------------------------------===//
1095 
1096 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1097  // subf(x, +0) -> x
1098  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1099  return getLhs();
1100 
1101  return constFoldBinaryOp<FloatAttr>(
1102  adaptor.getOperands(),
1103  [](const APFloat &a, const APFloat &b) { return a - b; });
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // MaximumFOp
1108 //===----------------------------------------------------------------------===//
1109 
1110 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1111  // maximumf(x,x) -> x
1112  if (getLhs() == getRhs())
1113  return getRhs();
1114 
1115  // maximumf(x, -inf) -> x
1116  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1117  return getLhs();
1118 
1119  return constFoldBinaryOp<FloatAttr>(
1120  adaptor.getOperands(),
1121  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1122 }
1123 
1124 //===----------------------------------------------------------------------===//
1125 // MaxNumFOp
1126 //===----------------------------------------------------------------------===//
1127 
1128 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1129  // maxnumf(x,x) -> x
1130  if (getLhs() == getRhs())
1131  return getRhs();
1132 
1133  // maxnumf(x, NaN) -> x
1134  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1135  return getLhs();
1136 
1137  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1138 }
1139 
1140 //===----------------------------------------------------------------------===//
1141 // MaxSIOp
1142 //===----------------------------------------------------------------------===//
1143 
1144 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1145  // maxsi(x,x) -> x
1146  if (getLhs() == getRhs())
1147  return getRhs();
1148 
1149  if (APInt intValue;
1150  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1151  // maxsi(x,MAX_INT) -> MAX_INT
1152  if (intValue.isMaxSignedValue())
1153  return getRhs();
1154  // maxsi(x, MIN_INT) -> x
1155  if (intValue.isMinSignedValue())
1156  return getLhs();
1157  }
1158 
1159  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1160  [](const APInt &a, const APInt &b) {
1161  return llvm::APIntOps::smax(a, b);
1162  });
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // MaxUIOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1170  // maxui(x,x) -> x
1171  if (getLhs() == getRhs())
1172  return getRhs();
1173 
1174  if (APInt intValue;
1175  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1176  // maxui(x,MAX_INT) -> MAX_INT
1177  if (intValue.isMaxValue())
1178  return getRhs();
1179  // maxui(x, MIN_INT) -> x
1180  if (intValue.isMinValue())
1181  return getLhs();
1182  }
1183 
1184  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1185  [](const APInt &a, const APInt &b) {
1186  return llvm::APIntOps::umax(a, b);
1187  });
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // MinimumFOp
1192 //===----------------------------------------------------------------------===//
1193 
1194 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1195  // minimumf(x,x) -> x
1196  if (getLhs() == getRhs())
1197  return getRhs();
1198 
1199  // minimumf(x, +inf) -> x
1200  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1201  return getLhs();
1202 
1203  return constFoldBinaryOp<FloatAttr>(
1204  adaptor.getOperands(),
1205  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1206 }
1207 
1208 //===----------------------------------------------------------------------===//
1209 // MinNumFOp
1210 //===----------------------------------------------------------------------===//
1211 
1212 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1213  // minnumf(x,x) -> x
1214  if (getLhs() == getRhs())
1215  return getRhs();
1216 
1217  // minnumf(x, NaN) -> x
1218  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1219  return getLhs();
1220 
1221  return constFoldBinaryOp<FloatAttr>(
1222  adaptor.getOperands(),
1223  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // MinSIOp
1228 //===----------------------------------------------------------------------===//
1229 
1230 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1231  // minsi(x,x) -> x
1232  if (getLhs() == getRhs())
1233  return getRhs();
1234 
1235  if (APInt intValue;
1236  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1237  // minsi(x,MIN_INT) -> MIN_INT
1238  if (intValue.isMinSignedValue())
1239  return getRhs();
1240  // minsi(x, MAX_INT) -> x
1241  if (intValue.isMaxSignedValue())
1242  return getLhs();
1243  }
1244 
1245  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1246  [](const APInt &a, const APInt &b) {
1247  return llvm::APIntOps::smin(a, b);
1248  });
1249 }
1250 
1251 //===----------------------------------------------------------------------===//
1252 // MinUIOp
1253 //===----------------------------------------------------------------------===//
1254 
1255 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1256  // minui(x,x) -> x
1257  if (getLhs() == getRhs())
1258  return getRhs();
1259 
1260  if (APInt intValue;
1261  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1262  // minui(x,MIN_INT) -> MIN_INT
1263  if (intValue.isMinValue())
1264  return getRhs();
1265  // minui(x, MAX_INT) -> x
1266  if (intValue.isMaxValue())
1267  return getLhs();
1268  }
1269 
1270  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1271  [](const APInt &a, const APInt &b) {
1272  return llvm::APIntOps::umin(a, b);
1273  });
1274 }
1275 
1276 //===----------------------------------------------------------------------===//
1277 // MulFOp
1278 //===----------------------------------------------------------------------===//
1279 
1280 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1281  // mulf(x, 1) -> x
1282  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1283  return getLhs();
1284 
1285  if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1286  arith::FastMathFlags::nsz)) {
1287  // mulf(x, 0) -> 0
1288  if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1289  return getRhs();
1290  }
1291 
1292  return constFoldBinaryOp<FloatAttr>(
1293  adaptor.getOperands(),
1294  [](const APFloat &a, const APFloat &b) { return a * b; });
1295 }
1296 
1297 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1298  MLIRContext *context) {
1299  patterns.add<MulFOfNegF>(context);
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // DivFOp
1304 //===----------------------------------------------------------------------===//
1305 
1306 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1307  // divf(x, 1) -> x
1308  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1309  return getLhs();
1310 
1311  return constFoldBinaryOp<FloatAttr>(
1312  adaptor.getOperands(),
1313  [](const APFloat &a, const APFloat &b) { return a / b; });
1314 }
1315 
1316 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1317  MLIRContext *context) {
1318  patterns.add<DivFOfNegF>(context);
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // RemFOp
1323 //===----------------------------------------------------------------------===//
1324 
1325 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1326  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1327  [](const APFloat &a, const APFloat &b) {
1328  APFloat result(a);
1329  // APFloat::mod() offers the remainder
1330  // behavior we want, i.e. the result has
1331  // the sign of LHS operand.
1332  (void)result.mod(b);
1333  return result;
1334  });
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // Utility functions for verifying cast ops
1339 //===----------------------------------------------------------------------===//
1340 
/// Pointer-to-tuple alias that bundles the pack `Ts...` into one type. A value
/// of this type can serve as a tag argument from which a function template
/// deduces an entire parameter pack.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
1343 
1344 /// Returns a non-null type only if the provided type is one of the allowed
1345 /// types or one of the allowed shaped types of the allowed types. Returns the
1346 /// element type if a valid shaped type is provided.
1347 template <typename... ShapedTypes, typename... ElementTypes>
1350  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1351  return {};
1352 
1353  auto underlyingType = getElementTypeOrSelf(type);
1354  if (!llvm::isa<ElementTypes...>(underlyingType))
1355  return {};
1356 
1357  return underlyingType;
1358 }
1359 
1360 /// Get allowed underlying types for vectors and tensors.
1361 template <typename... ElementTypes>
1362 static Type getTypeIfLike(Type type) {
1365 }
1366 
1367 /// Get allowed underlying types for vectors, tensors, and memrefs.
1368 template <typename... ElementTypes>
1370  return getUnderlyingType(type,
1373 }
1374 
1375 /// Return false if both types are ranked tensor with mismatching encoding.
1376 static bool hasSameEncoding(Type typeA, Type typeB) {
1377  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1378  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1379  if (!rankedTensorA || !rankedTensorB)
1380  return true;
1381  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1382 }
1383 
1384 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1385  if (inputs.size() != 1 || outputs.size() != 1)
1386  return false;
1387  if (!hasSameEncoding(inputs.front(), outputs.front()))
1388  return false;
1389  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1390 }
1391 
1392 //===----------------------------------------------------------------------===//
1393 // Verifiers for integer and floating point extension/truncation ops
1394 //===----------------------------------------------------------------------===//
1395 
1396 // Extend ops can only extend to a wider type.
1397 template <typename ValType, typename Op>
1398 static LogicalResult verifyExtOp(Op op) {
1399  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1400  Type dstType = getElementTypeOrSelf(op.getType());
1401 
1402  if (llvm::cast<ValType>(srcType).getWidth() >=
1403  llvm::cast<ValType>(dstType).getWidth())
1404  return op.emitError("result type ")
1405  << dstType << " must be wider than operand type " << srcType;
1406 
1407  return success();
1408 }
1409 
1410 // Truncate ops can only truncate to a shorter type.
1411 template <typename ValType, typename Op>
1412 static LogicalResult verifyTruncateOp(Op op) {
1413  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1414  Type dstType = getElementTypeOrSelf(op.getType());
1415 
1416  if (llvm::cast<ValType>(srcType).getWidth() <=
1417  llvm::cast<ValType>(dstType).getWidth())
1418  return op.emitError("result type ")
1419  << dstType << " must be shorter than operand type " << srcType;
1420 
1421  return success();
1422 }
1423 
1424 /// Validate a cast that changes the width of a type.
1425 template <template <typename> class WidthComparator, typename... ElementTypes>
1426 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1427  if (!areValidCastInputsAndOutputs(inputs, outputs))
1428  return false;
1429 
1430  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1431  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1432  if (!srcType || !dstType)
1433  return false;
1434 
1435  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1436  srcType.getIntOrFloatBitWidth());
1437 }
1438 
1439 /// Attempts to convert `sourceValue` to an APFloat value with
1440 /// `targetSemantics` and `roundingMode`, without any information loss.
1441 static FailureOr<APFloat> convertFloatValue(
1442  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1443  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1444  bool losesInfo = false;
1445  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1446  if (losesInfo || status != APFloat::opOK)
1447  return failure();
1448 
1449  return sourceValue;
1450 }
1451 
1452 //===----------------------------------------------------------------------===//
1453 // ExtUIOp
1454 //===----------------------------------------------------------------------===//
1455 
1456 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1457  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1458  getInMutable().assign(lhs.getIn());
1459  return getResult();
1460  }
1461 
1462  Type resType = getElementTypeOrSelf(getType());
1463  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1464  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1465  adaptor.getOperands(), getType(),
1466  [bitWidth](const APInt &a, bool &castStatus) {
1467  return a.zext(bitWidth);
1468  });
1469 }
1470 
1471 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1472  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1473 }
1474 
1475 LogicalResult arith::ExtUIOp::verify() {
1476  return verifyExtOp<IntegerType>(*this);
1477 }
1478 
1479 //===----------------------------------------------------------------------===//
1480 // ExtSIOp
1481 //===----------------------------------------------------------------------===//
1482 
1483 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1484  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1485  getInMutable().assign(lhs.getIn());
1486  return getResult();
1487  }
1488 
1489  Type resType = getElementTypeOrSelf(getType());
1490  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1491  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1492  adaptor.getOperands(), getType(),
1493  [bitWidth](const APInt &a, bool &castStatus) {
1494  return a.sext(bitWidth);
1495  });
1496 }
1497 
1498 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1499  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1500 }
1501 
1502 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1503  MLIRContext *context) {
1504  patterns.add<ExtSIOfExtUI>(context);
1505 }
1506 
1507 LogicalResult arith::ExtSIOp::verify() {
1508  return verifyExtOp<IntegerType>(*this);
1509 }
1510 
1511 //===----------------------------------------------------------------------===//
1512 // ExtFOp
1513 //===----------------------------------------------------------------------===//
1514 
1515 /// Fold extension of float constants when there is no information loss due the
1516 /// difference in fp semantics.
1517 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1518  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1519  if (truncFOp.getOperand().getType() == getType()) {
1520  arith::FastMathFlags truncFMF =
1521  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1522  bool isTruncContract =
1523  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1524  arith::FastMathFlags extFMF =
1525  getFastmath().value_or(arith::FastMathFlags::none);
1526  bool isExtContract =
1527  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1528  if (isTruncContract && isExtContract) {
1529  return truncFOp.getOperand();
1530  }
1531  }
1532  }
1533 
1534  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1535  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1536  return constFoldCastOp<FloatAttr, FloatAttr>(
1537  adaptor.getOperands(), getType(),
1538  [&targetSemantics](const APFloat &a, bool &castStatus) {
1539  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1540  if (failed(result)) {
1541  castStatus = false;
1542  return a;
1543  }
1544  return *result;
1545  });
1546 }
1547 
1548 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1549  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1550 }
1551 
1552 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // ScalingExtFOp
1556 //===----------------------------------------------------------------------===//
1557 
1558 bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1559  TypeRange outputs) {
1560  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1561 }
1562 
1563 LogicalResult arith::ScalingExtFOp::verify() {
1564  return verifyExtOp<FloatType>(*this);
1565 }
1566 
1567 //===----------------------------------------------------------------------===//
1568 // TruncIOp
1569 //===----------------------------------------------------------------------===//
1570 
1571 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1572  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1573  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1574  Value src = getOperand().getDefiningOp()->getOperand(0);
1575  Type srcType = getElementTypeOrSelf(src.getType());
1576  Type dstType = getElementTypeOrSelf(getType());
1577  // trunci(zexti(a)) -> trunci(a)
1578  // trunci(sexti(a)) -> trunci(a)
1579  if (llvm::cast<IntegerType>(srcType).getWidth() >
1580  llvm::cast<IntegerType>(dstType).getWidth()) {
1581  setOperand(src);
1582  return getResult();
1583  }
1584 
1585  // trunci(zexti(a)) -> a
1586  // trunci(sexti(a)) -> a
1587  if (srcType == dstType)
1588  return src;
1589  }
1590 
1591  // trunci(trunci(a)) -> trunci(a))
1592  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1593  setOperand(getOperand().getDefiningOp()->getOperand(0));
1594  return getResult();
1595  }
1596 
1597  Type resType = getElementTypeOrSelf(getType());
1598  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1599  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1600  adaptor.getOperands(), getType(),
1601  [bitWidth](const APInt &a, bool &castStatus) {
1602  return a.trunc(bitWidth);
1603  });
1604 }
1605 
1606 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1607  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1608 }
1609 
1610 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1611  MLIRContext *context) {
1612  patterns
1613  .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1614  context);
1615 }
1616 
1617 LogicalResult arith::TruncIOp::verify() {
1618  return verifyTruncateOp<IntegerType>(*this);
1619 }
1620 
1621 //===----------------------------------------------------------------------===//
1622 // TruncFOp
1623 //===----------------------------------------------------------------------===//
1624 
1625 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1626 /// can be represented without precision loss.
1627 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1628  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1629  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1630  Value src = extOp.getIn();
1631  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1632  auto intermediateType =
1633  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1634  // Check if the srcType is representable in the intermediateType.
1635  if (llvm::APFloatBase::isRepresentableBy(
1636  srcType.getFloatSemantics(),
1637  intermediateType.getFloatSemantics())) {
1638  // truncf(extf(a)) -> truncf(a)
1639  if (srcType.getWidth() > resElemType.getWidth()) {
1640  setOperand(src);
1641  return getResult();
1642  }
1643 
1644  // truncf(extf(a)) -> a
1645  if (srcType == resElemType)
1646  return src;
1647  }
1648  }
1649 
1650  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1651  return constFoldCastOp<FloatAttr, FloatAttr>(
1652  adaptor.getOperands(), getType(),
1653  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1654  RoundingMode roundingMode =
1655  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1656  llvm::RoundingMode llvmRoundingMode =
1657  convertArithRoundingModeToLLVMIR(roundingMode);
1658  FailureOr<APFloat> result =
1659  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1660  if (failed(result)) {
1661  castStatus = false;
1662  return a;
1663  }
1664  return *result;
1665  });
1666 }
1667 
1668 void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1669  MLIRContext *context) {
1670  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1671 }
1672 
1673 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1674  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1675 }
1676 
1677 LogicalResult arith::TruncFOp::verify() {
1678  return verifyTruncateOp<FloatType>(*this);
1679 }
1680 
1681 //===----------------------------------------------------------------------===//
1682 // ScalingTruncFOp
1683 //===----------------------------------------------------------------------===//
1684 
1685 bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1686  TypeRange outputs) {
1687  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1688 }
1689 
1690 LogicalResult arith::ScalingTruncFOp::verify() {
1691  return verifyTruncateOp<FloatType>(*this);
1692 }
1693 
1694 //===----------------------------------------------------------------------===//
1695 // AndIOp
1696 //===----------------------------------------------------------------------===//
1697 
1698 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1699  MLIRContext *context) {
1700  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1701 }
1702 
1703 //===----------------------------------------------------------------------===//
1704 // OrIOp
1705 //===----------------------------------------------------------------------===//
1706 
1707 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1708  MLIRContext *context) {
1709  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1710 }
1711 
1712 //===----------------------------------------------------------------------===//
1713 // Verifiers for casts between integers and floats.
1714 //===----------------------------------------------------------------------===//
1715 
1716 template <typename From, typename To>
1717 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1718  if (!areValidCastInputsAndOutputs(inputs, outputs))
1719  return false;
1720 
1721  auto srcType = getTypeIfLike<From>(inputs.front());
1722  auto dstType = getTypeIfLike<To>(outputs.back());
1723 
1724  return srcType && dstType;
1725 }
1726 
1727 //===----------------------------------------------------------------------===//
1728 // UIToFPOp
1729 //===----------------------------------------------------------------------===//
1730 
1731 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1732  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1733 }
1734 
1735 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1736  Type resEleType = getElementTypeOrSelf(getType());
1737  return constFoldCastOp<IntegerAttr, FloatAttr>(
1738  adaptor.getOperands(), getType(),
1739  [&resEleType](const APInt &a, bool &castStatus) {
1740  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1741  APFloat apf(floatTy.getFloatSemantics(),
1742  APInt::getZero(floatTy.getWidth()));
1743  apf.convertFromAPInt(a, /*IsSigned=*/false,
1744  APFloat::rmNearestTiesToEven);
1745  return apf;
1746  });
1747 }
1748 
1749 //===----------------------------------------------------------------------===//
1750 // SIToFPOp
1751 //===----------------------------------------------------------------------===//
1752 
1753 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1754  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1755 }
1756 
1757 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1758  Type resEleType = getElementTypeOrSelf(getType());
1759  return constFoldCastOp<IntegerAttr, FloatAttr>(
1760  adaptor.getOperands(), getType(),
1761  [&resEleType](const APInt &a, bool &castStatus) {
1762  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1763  APFloat apf(floatTy.getFloatSemantics(),
1764  APInt::getZero(floatTy.getWidth()));
1765  apf.convertFromAPInt(a, /*IsSigned=*/true,
1766  APFloat::rmNearestTiesToEven);
1767  return apf;
1768  });
1769 }
1770 
1771 //===----------------------------------------------------------------------===//
1772 // FPToUIOp
1773 //===----------------------------------------------------------------------===//
1774 
1775 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1776  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1777 }
1778 
1779 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1780  Type resType = getElementTypeOrSelf(getType());
1781  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1782  return constFoldCastOp<FloatAttr, IntegerAttr>(
1783  adaptor.getOperands(), getType(),
1784  [&bitWidth](const APFloat &a, bool &castStatus) {
1785  bool ignored;
1786  APSInt api(bitWidth, /*isUnsigned=*/true);
1787  castStatus = APFloat::opInvalidOp !=
1788  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1789  return api;
1790  });
1791 }
1792 
1793 //===----------------------------------------------------------------------===//
1794 // FPToSIOp
1795 //===----------------------------------------------------------------------===//
1796 
1797 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1798  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1799 }
1800 
1801 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1802  Type resType = getElementTypeOrSelf(getType());
1803  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1804  return constFoldCastOp<FloatAttr, IntegerAttr>(
1805  adaptor.getOperands(), getType(),
1806  [&bitWidth](const APFloat &a, bool &castStatus) {
1807  bool ignored;
1808  APSInt api(bitWidth, /*isUnsigned=*/false);
1809  castStatus = APFloat::opInvalidOp !=
1810  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1811  return api;
1812  });
1813 }
1814 
1815 //===----------------------------------------------------------------------===//
1816 // IndexCastOp
1817 //===----------------------------------------------------------------------===//
1818 
1819 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1820  if (!areValidCastInputsAndOutputs(inputs, outputs))
1821  return false;
1822 
1823  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1824  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1825  if (!srcType || !dstType)
1826  return false;
1827 
1828  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1829  (srcType.isSignlessInteger() && dstType.isIndex());
1830 }
1831 
1832 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1833  TypeRange outputs) {
1834  return areIndexCastCompatible(inputs, outputs);
1835 }
1836 
1837 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1838  // index_cast(constant) -> constant
1839  unsigned resultBitwidth = 64; // Default for index integer attributes.
1840  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1841  resultBitwidth = intTy.getWidth();
1842 
1843  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1844  adaptor.getOperands(), getType(),
1845  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1846  return a.sextOrTrunc(resultBitwidth);
1847  });
1848 }
1849 
1850 void arith::IndexCastOp::getCanonicalizationPatterns(
1851  RewritePatternSet &patterns, MLIRContext *context) {
1852  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1853 }
1854 
1855 //===----------------------------------------------------------------------===//
1856 // IndexCastUIOp
1857 //===----------------------------------------------------------------------===//
1858 
1859 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1860  TypeRange outputs) {
1861  return areIndexCastCompatible(inputs, outputs);
1862 }
1863 
1864 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1865  // index_castui(constant) -> constant
1866  unsigned resultBitwidth = 64; // Default for index integer attributes.
1867  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1868  resultBitwidth = intTy.getWidth();
1869 
1870  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1871  adaptor.getOperands(), getType(),
1872  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1873  return a.zextOrTrunc(resultBitwidth);
1874  });
1875 }
1876 
1877 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1878  RewritePatternSet &patterns, MLIRContext *context) {
1879  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1880 }
1881 
1882 //===----------------------------------------------------------------------===//
1883 // BitcastOp
1884 //===----------------------------------------------------------------------===//
1885 
1886 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1887  if (!areValidCastInputsAndOutputs(inputs, outputs))
1888  return false;
1889 
1890  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1891  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1892  if (!srcType || !dstType)
1893  return false;
1894 
1895  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1896 }
1897 
1898 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1899  auto resType = getType();
1900  auto operand = adaptor.getIn();
1901  if (!operand)
1902  return {};
1903 
1904  /// Bitcast dense elements.
1905  if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1906  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1907  /// Other shaped types unhandled.
1908  if (llvm::isa<ShapedType>(resType))
1909  return {};
1910 
1911  /// Bitcast poison.
1912  if (llvm::isa<ub::PoisonAttr>(operand))
1913  return ub::PoisonAttr::get(getContext());
1914 
1915  /// Bitcast integer or float to integer or float.
1916  APInt bits = llvm::isa<FloatAttr>(operand)
1917  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1918  : llvm::cast<IntegerAttr>(operand).getValue();
1919  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1920  "trying to fold on broken IR: operands have incompatible types");
1921 
1922  if (auto resFloatType = dyn_cast<FloatType>(resType))
1923  return FloatAttr::get(resType,
1924  APFloat(resFloatType.getFloatSemantics(), bits));
1925  return IntegerAttr::get(resType, bits);
1926 }
1927 
1928 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1929  MLIRContext *context) {
1930  patterns.add<BitcastOfBitcast>(context);
1931 }
1932 
1933 //===----------------------------------------------------------------------===//
1934 // CmpIOp
1935 //===----------------------------------------------------------------------===//
1936 
1937 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1938 /// comparison predicates.
1939 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1940  const APInt &lhs, const APInt &rhs) {
1941  switch (predicate) {
1942  case arith::CmpIPredicate::eq:
1943  return lhs.eq(rhs);
1944  case arith::CmpIPredicate::ne:
1945  return lhs.ne(rhs);
1946  case arith::CmpIPredicate::slt:
1947  return lhs.slt(rhs);
1948  case arith::CmpIPredicate::sle:
1949  return lhs.sle(rhs);
1950  case arith::CmpIPredicate::sgt:
1951  return lhs.sgt(rhs);
1952  case arith::CmpIPredicate::sge:
1953  return lhs.sge(rhs);
1954  case arith::CmpIPredicate::ult:
1955  return lhs.ult(rhs);
1956  case arith::CmpIPredicate::ule:
1957  return lhs.ule(rhs);
1958  case arith::CmpIPredicate::ugt:
1959  return lhs.ugt(rhs);
1960  case arith::CmpIPredicate::uge:
1961  return lhs.uge(rhs);
1962  }
1963  llvm_unreachable("unknown cmpi predicate kind");
1964 }
1965 
1966 /// Returns true if the predicate is true for two equal operands.
1967 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1968  switch (predicate) {
1969  case arith::CmpIPredicate::eq:
1970  case arith::CmpIPredicate::sle:
1971  case arith::CmpIPredicate::sge:
1972  case arith::CmpIPredicate::ule:
1973  case arith::CmpIPredicate::uge:
1974  return true;
1975  case arith::CmpIPredicate::ne:
1976  case arith::CmpIPredicate::slt:
1977  case arith::CmpIPredicate::sgt:
1978  case arith::CmpIPredicate::ult:
1979  case arith::CmpIPredicate::ugt:
1980  return false;
1981  }
1982  llvm_unreachable("unknown cmpi predicate kind");
1983 }
1984 
1985 static std::optional<int64_t> getIntegerWidth(Type t) {
1986  if (auto intType = dyn_cast<IntegerType>(t)) {
1987  return intType.getWidth();
1988  }
1989  if (auto vectorIntType = dyn_cast<VectorType>(t)) {
1990  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1991  }
1992  return std::nullopt;
1993 }
1994 
1995 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1996  // cmpi(pred, x, x)
1997  if (getLhs() == getRhs()) {
1998  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1999  return getBoolAttribute(getType(), val);
2000  }
2001 
2002  if (matchPattern(adaptor.getRhs(), m_Zero())) {
2003  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2004  // extsi(%x : i1 -> iN) != 0 -> %x
2005  std::optional<int64_t> integerWidth =
2006  getIntegerWidth(extOp.getOperand().getType());
2007  if (integerWidth && integerWidth.value() == 1 &&
2008  getPredicate() == arith::CmpIPredicate::ne)
2009  return extOp.getOperand();
2010  }
2011  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2012  // extui(%x : i1 -> iN) != 0 -> %x
2013  std::optional<int64_t> integerWidth =
2014  getIntegerWidth(extOp.getOperand().getType());
2015  if (integerWidth && integerWidth.value() == 1 &&
2016  getPredicate() == arith::CmpIPredicate::ne)
2017  return extOp.getOperand();
2018  }
2019 
2020  // arith.cmpi ne, %val, %zero : i1 -> %val
2021  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2022  getPredicate() == arith::CmpIPredicate::ne)
2023  return getLhs();
2024  }
2025 
2026  if (matchPattern(adaptor.getRhs(), m_One())) {
2027  // arith.cmpi eq, %val, %one : i1 -> %val
2028  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2029  getPredicate() == arith::CmpIPredicate::eq)
2030  return getLhs();
2031  }
2032 
2033  // Move constant to the right side.
2034  if (adaptor.getLhs() && !adaptor.getRhs()) {
2035  // Do not use invertPredicate, as it will change eq to ne and vice versa.
2036  using Pred = CmpIPredicate;
2037  const std::pair<Pred, Pred> invPreds[] = {
2038  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2039  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2040  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2041  {Pred::ne, Pred::ne},
2042  };
2043  Pred origPred = getPredicate();
2044  for (auto pred : invPreds) {
2045  if (origPred == pred.first) {
2046  setPredicate(pred.second);
2047  Value lhs = getLhs();
2048  Value rhs = getRhs();
2049  getLhsMutable().assign(rhs);
2050  getRhsMutable().assign(lhs);
2051  return getResult();
2052  }
2053  }
2054  llvm_unreachable("unknown cmpi predicate kind");
2055  }
2056 
2057  // We are moving constants to the right side; So if lhs is constant rhs is
2058  // guaranteed to be a constant.
2059  if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2060  return constFoldBinaryOp<IntegerAttr>(
2061  adaptor.getOperands(), getI1SameShape(lhs.getType()),
2062  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2063  return APInt(1,
2064  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2065  });
2066  }
2067 
2068  return {};
2069 }
2070 
2071 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2072  MLIRContext *context) {
2073  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // CmpFOp
2078 //===----------------------------------------------------------------------===//
2079 
2080 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2081 /// comparison predicates.
2082 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2083  const APFloat &lhs, const APFloat &rhs) {
2084  auto cmpResult = lhs.compare(rhs);
2085  switch (predicate) {
2086  case arith::CmpFPredicate::AlwaysFalse:
2087  return false;
2088  case arith::CmpFPredicate::OEQ:
2089  return cmpResult == APFloat::cmpEqual;
2090  case arith::CmpFPredicate::OGT:
2091  return cmpResult == APFloat::cmpGreaterThan;
2092  case arith::CmpFPredicate::OGE:
2093  return cmpResult == APFloat::cmpGreaterThan ||
2094  cmpResult == APFloat::cmpEqual;
2095  case arith::CmpFPredicate::OLT:
2096  return cmpResult == APFloat::cmpLessThan;
2097  case arith::CmpFPredicate::OLE:
2098  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2099  case arith::CmpFPredicate::ONE:
2100  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2101  case arith::CmpFPredicate::ORD:
2102  return cmpResult != APFloat::cmpUnordered;
2103  case arith::CmpFPredicate::UEQ:
2104  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2105  case arith::CmpFPredicate::UGT:
2106  return cmpResult == APFloat::cmpUnordered ||
2107  cmpResult == APFloat::cmpGreaterThan;
2108  case arith::CmpFPredicate::UGE:
2109  return cmpResult == APFloat::cmpUnordered ||
2110  cmpResult == APFloat::cmpGreaterThan ||
2111  cmpResult == APFloat::cmpEqual;
2112  case arith::CmpFPredicate::ULT:
2113  return cmpResult == APFloat::cmpUnordered ||
2114  cmpResult == APFloat::cmpLessThan;
2115  case arith::CmpFPredicate::ULE:
2116  return cmpResult == APFloat::cmpUnordered ||
2117  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2118  case arith::CmpFPredicate::UNE:
2119  return cmpResult != APFloat::cmpEqual;
2120  case arith::CmpFPredicate::UNO:
2121  return cmpResult == APFloat::cmpUnordered;
2122  case arith::CmpFPredicate::AlwaysTrue:
2123  return true;
2124  }
2125  llvm_unreachable("unknown cmpf predicate kind");
2126 }
2127 
2128 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2129  auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2130  auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2131 
2132  // If one operand is NaN, making them both NaN does not change the result.
2133  if (lhs && lhs.getValue().isNaN())
2134  rhs = lhs;
2135  if (rhs && rhs.getValue().isNaN())
2136  lhs = rhs;
2137 
2138  if (!lhs || !rhs)
2139  return {};
2140 
2141  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2142  return BoolAttr::get(getContext(), val);
2143 }
2144 
2145 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2146 public:
2147  using Base::Base;
2148 
2149  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2150  bool isUnsigned) {
2151  using namespace arith;
2152  switch (pred) {
2153  case CmpFPredicate::UEQ:
2154  case CmpFPredicate::OEQ:
2155  return CmpIPredicate::eq;
2156  case CmpFPredicate::UGT:
2157  case CmpFPredicate::OGT:
2158  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2159  case CmpFPredicate::UGE:
2160  case CmpFPredicate::OGE:
2161  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2162  case CmpFPredicate::ULT:
2163  case CmpFPredicate::OLT:
2164  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2165  case CmpFPredicate::ULE:
2166  case CmpFPredicate::OLE:
2167  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2168  case CmpFPredicate::UNE:
2169  case CmpFPredicate::ONE:
2170  return CmpIPredicate::ne;
2171  default:
2172  llvm_unreachable("Unexpected predicate!");
2173  }
2174  }
2175 
2176  LogicalResult matchAndRewrite(CmpFOp op,
2177  PatternRewriter &rewriter) const override {
2178  FloatAttr flt;
2179  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2180  return failure();
2181 
2182  const APFloat &rhs = flt.getValue();
2183 
2184  // Don't attempt to fold a nan.
2185  if (rhs.isNaN())
2186  return failure();
2187 
2188  // Get the width of the mantissa. We don't want to hack on conversions that
2189  // might lose information from the integer, e.g. "i64 -> float"
2190  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2191  int mantissaWidth = floatTy.getFPMantissaWidth();
2192  if (mantissaWidth <= 0)
2193  return failure();
2194 
2195  bool isUnsigned;
2196  Value intVal;
2197 
2198  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2199  isUnsigned = false;
2200  intVal = si.getIn();
2201  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2202  isUnsigned = true;
2203  intVal = ui.getIn();
2204  } else {
2205  return failure();
2206  }
2207 
2208  // Check to see that the input is converted from an integer type that is
2209  // small enough that preserves all bits.
2210  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2211  auto intWidth = intTy.getWidth();
2212 
2213  // Number of bits representing values, as opposed to the sign
2214  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2215 
2216  // Following test does NOT adjust intWidth downwards for signed inputs,
2217  // because the most negative value still requires all the mantissa bits
2218  // to distinguish it from one less than that value.
2219  if ((int)intWidth > mantissaWidth) {
2220  // Conversion would lose accuracy. Check if loss can impact comparison.
2221  int exponent = ilogb(rhs);
2222  if (exponent == APFloat::IEK_Inf) {
2223  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2224  if (maxExponent < (int)valueBits) {
2225  // Conversion could create infinity.
2226  return failure();
2227  }
2228  } else {
2229  // Note that if rhs is zero or NaN, then Exp is negative
2230  // and first condition is trivially false.
2231  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2232  // Conversion could affect comparison.
2233  return failure();
2234  }
2235  }
2236  }
2237 
2238  // Convert to equivalent cmpi predicate
2239  CmpIPredicate pred;
2240  switch (op.getPredicate()) {
2241  case CmpFPredicate::ORD:
2242  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2243  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2244  /*width=*/1);
2245  return success();
2246  case CmpFPredicate::UNO:
2247  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2248  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2249  /*width=*/1);
2250  return success();
2251  default:
2252  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2253  break;
2254  }
2255 
2256  if (!isUnsigned) {
2257  // If the rhs value is > SignedMax, fold the comparison. This handles
2258  // +INF and large values.
2259  APFloat signedMax(rhs.getSemantics());
2260  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2261  APFloat::rmNearestTiesToEven);
2262  if (signedMax < rhs) { // smax < 13123.0
2263  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2264  pred == CmpIPredicate::sle)
2265  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2266  /*width=*/1);
2267  else
2268  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2269  /*width=*/1);
2270  return success();
2271  }
2272  } else {
2273  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2274  // +INF and large values.
2275  APFloat unsignedMax(rhs.getSemantics());
2276  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2277  APFloat::rmNearestTiesToEven);
2278  if (unsignedMax < rhs) { // umax < 13123.0
2279  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2280  pred == CmpIPredicate::ule)
2281  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2282  /*width=*/1);
2283  else
2284  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2285  /*width=*/1);
2286  return success();
2287  }
2288  }
2289 
2290  if (!isUnsigned) {
2291  // See if the rhs value is < SignedMin.
2292  APFloat signedMin(rhs.getSemantics());
2293  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2294  APFloat::rmNearestTiesToEven);
2295  if (signedMin > rhs) { // smin > 12312.0
2296  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2297  pred == CmpIPredicate::sge)
2298  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2299  /*width=*/1);
2300  else
2301  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2302  /*width=*/1);
2303  return success();
2304  }
2305  } else {
2306  // See if the rhs value is < UnsignedMin.
2307  APFloat unsignedMin(rhs.getSemantics());
2308  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2309  APFloat::rmNearestTiesToEven);
2310  if (unsignedMin > rhs) { // umin > 12312.0
2311  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2312  pred == CmpIPredicate::uge)
2313  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2314  /*width=*/1);
2315  else
2316  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2317  /*width=*/1);
2318  return success();
2319  }
2320  }
2321 
2322  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2323  // [0, UMAX], but it may still be fractional. See if it is fractional by
2324  // casting the FP value to the integer value and back, checking for
2325  // equality. Don't do this for zero, because -0.0 is not fractional.
2326  bool ignored;
2327  APSInt rhsInt(intWidth, isUnsigned);
2328  if (APFloat::opInvalidOp ==
2329  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2330  // Undefined behavior invoked - the destination type can't represent
2331  // the input constant.
2332  return failure();
2333  }
2334 
2335  if (!rhs.isZero()) {
2336  APFloat apf(floatTy.getFloatSemantics(),
2337  APInt::getZero(floatTy.getWidth()));
2338  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2339 
2340  bool equal = apf == rhs;
2341  if (!equal) {
2342  // If we had a comparison against a fractional value, we have to adjust
2343  // the compare predicate and sometimes the value. rhsInt is rounded
2344  // towards zero at this point.
2345  switch (pred) {
2346  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2347  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2348  /*width=*/1);
2349  return success();
2350  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2351  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2352  /*width=*/1);
2353  return success();
2354  case CmpIPredicate::ule:
2355  // (float)int <= 4.4 --> int <= 4
2356  // (float)int <= -4.4 --> false
2357  if (rhs.isNegative()) {
2358  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2359  /*width=*/1);
2360  return success();
2361  }
2362  break;
2363  case CmpIPredicate::sle:
2364  // (float)int <= 4.4 --> int <= 4
2365  // (float)int <= -4.4 --> int < -4
2366  if (rhs.isNegative())
2367  pred = CmpIPredicate::slt;
2368  break;
2369  case CmpIPredicate::ult:
2370  // (float)int < -4.4 --> false
2371  // (float)int < 4.4 --> int <= 4
2372  if (rhs.isNegative()) {
2373  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2374  /*width=*/1);
2375  return success();
2376  }
2377  pred = CmpIPredicate::ule;
2378  break;
2379  case CmpIPredicate::slt:
2380  // (float)int < -4.4 --> int < -4
2381  // (float)int < 4.4 --> int <= 4
2382  if (!rhs.isNegative())
2383  pred = CmpIPredicate::sle;
2384  break;
2385  case CmpIPredicate::ugt:
2386  // (float)int > 4.4 --> int > 4
2387  // (float)int > -4.4 --> true
2388  if (rhs.isNegative()) {
2389  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2390  /*width=*/1);
2391  return success();
2392  }
2393  break;
2394  case CmpIPredicate::sgt:
2395  // (float)int > 4.4 --> int > 4
2396  // (float)int > -4.4 --> int >= -4
2397  if (rhs.isNegative())
2398  pred = CmpIPredicate::sge;
2399  break;
2400  case CmpIPredicate::uge:
2401  // (float)int >= -4.4 --> true
2402  // (float)int >= 4.4 --> int > 4
2403  if (rhs.isNegative()) {
2404  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2405  /*width=*/1);
2406  return success();
2407  }
2408  pred = CmpIPredicate::ugt;
2409  break;
2410  case CmpIPredicate::sge:
2411  // (float)int >= -4.4 --> int >= -4
2412  // (float)int >= 4.4 --> int > 4
2413  if (!rhs.isNegative())
2414  pred = CmpIPredicate::sgt;
2415  break;
2416  }
2417  }
2418  }
2419 
2420  // Lower this FP comparison into an appropriate integer version of the
2421  // comparison.
2422  rewriter.replaceOpWithNewOp<CmpIOp>(
2423  op, pred, intVal,
2424  ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2425  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2426  return success();
2427  }
2428 };
2429 
2430 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2431  MLIRContext *context) {
2432  patterns.insert<CmpFIntToFPConst>(context);
2433 }
2434 
2435 //===----------------------------------------------------------------------===//
2436 // SelectOp
2437 //===----------------------------------------------------------------------===//
2438 
2439 // select %arg, %c1, %c0 => extui %arg
2440 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2441  using Base::Base;
2442 
2443  LogicalResult matchAndRewrite(arith::SelectOp op,
2444  PatternRewriter &rewriter) const override {
2445  // Cannot extui i1 to i1, or i1 to f32
2446  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2447  return failure();
2448 
2449  // select %x, c1, %c0 => extui %arg
2450  if (matchPattern(op.getTrueValue(), m_One()) &&
2451  matchPattern(op.getFalseValue(), m_Zero())) {
2452  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2453  op.getCondition());
2454  return success();
2455  }
2456 
2457  // select %x, c0, %c1 => extui (xor %arg, true)
2458  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2459  matchPattern(op.getFalseValue(), m_One())) {
2460  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2461  op, op.getType(),
2462  arith::XOrIOp::create(
2463  rewriter, op.getLoc(), op.getCondition(),
2464  arith::ConstantIntOp::create(rewriter, op.getLoc(),
2465  op.getCondition().getType(), 1)));
2466  return success();
2467  }
2468 
2469  return failure();
2470  }
2471 };
2472 
2473 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2474  MLIRContext *context) {
2475  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2476  SelectI1ToNot, SelectToExtUI>(context);
2477 }
2478 
2479 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2480  Value trueVal = getTrueValue();
2481  Value falseVal = getFalseValue();
2482  if (trueVal == falseVal)
2483  return trueVal;
2484 
2485  Value condition = getCondition();
2486 
2487  // select true, %0, %1 => %0
2488  if (matchPattern(adaptor.getCondition(), m_One()))
2489  return trueVal;
2490 
2491  // select false, %0, %1 => %1
2492  if (matchPattern(adaptor.getCondition(), m_Zero()))
2493  return falseVal;
2494 
2495  // If either operand is fully poisoned, return the other.
2496  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2497  return falseVal;
2498 
2499  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2500  return trueVal;
2501 
2502  // select %x, true, false => %x
2503  if (getType().isSignlessInteger(1) &&
2504  matchPattern(adaptor.getTrueValue(), m_One()) &&
2505  matchPattern(adaptor.getFalseValue(), m_Zero()))
2506  return condition;
2507 
2508  if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2509  auto pred = cmp.getPredicate();
2510  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2511  auto cmpLhs = cmp.getLhs();
2512  auto cmpRhs = cmp.getRhs();
2513 
2514  // %0 = arith.cmpi eq, %arg0, %arg1
2515  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2516 
2517  // %0 = arith.cmpi ne, %arg0, %arg1
2518  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2519 
2520  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2521  (cmpRhs == trueVal && cmpLhs == falseVal))
2522  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2523  }
2524  }
2525 
2526  // Constant-fold constant operands over non-splat constant condition.
2527  // select %cst_vec, %cst0, %cst1 => %cst2
2528  if (auto cond =
2529  dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2530  if (auto lhs =
2531  dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2532  if (auto rhs =
2533  dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2534  SmallVector<Attribute> results;
2535  results.reserve(static_cast<size_t>(cond.getNumElements()));
2536  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2537  cond.value_end<BoolAttr>());
2538  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2539  lhs.value_end<Attribute>());
2540  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2541  rhs.value_end<Attribute>());
2542 
2543  for (auto [condVal, lhsVal, rhsVal] :
2544  llvm::zip_equal(condVals, lhsVals, rhsVals))
2545  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2546 
2547  return DenseElementsAttr::get(lhs.getType(), results);
2548  }
2549  }
2550  }
2551 
2552  return nullptr;
2553 }
2554 
2555 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2556  Type conditionType, resultType;
2558  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2559  parser.parseOptionalAttrDict(result.attributes) ||
2560  parser.parseColonType(resultType))
2561  return failure();
2562 
2563  // Check for the explicit condition type if this is a masked tensor or vector.
2564  if (succeeded(parser.parseOptionalComma())) {
2565  conditionType = resultType;
2566  if (parser.parseType(resultType))
2567  return failure();
2568  } else {
2569  conditionType = parser.getBuilder().getI1Type();
2570  }
2571 
2572  result.addTypes(resultType);
2573  return parser.resolveOperands(operands,
2574  {conditionType, resultType, resultType},
2575  parser.getNameLoc(), result.operands);
2576 }
2577 
2579  p << " " << getOperands();
2580  p.printOptionalAttrDict((*this)->getAttrs());
2581  p << " : ";
2582  if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2583  p << condType << ", ";
2584  p << getType();
2585 }
2586 
2587 LogicalResult arith::SelectOp::verify() {
2588  Type conditionType = getCondition().getType();
2589  if (conditionType.isSignlessInteger(1))
2590  return success();
2591 
2592  // If the result type is a vector or tensor, the type can be a mask with the
2593  // same elements.
2594  Type resultType = getType();
2595  if (!llvm::isa<TensorType, VectorType>(resultType))
2596  return emitOpError() << "expected condition to be a signless i1, but got "
2597  << conditionType;
2598  Type shapedConditionType = getI1SameShape(resultType);
2599  if (conditionType != shapedConditionType) {
2600  return emitOpError() << "expected condition type to have the same shape "
2601  "as the result type, expected "
2602  << shapedConditionType << ", but got "
2603  << conditionType;
2604  }
2605  return success();
2606 }
2607 //===----------------------------------------------------------------------===//
2608 // ShLIOp
2609 //===----------------------------------------------------------------------===//
2610 
2611 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2612  // shli(x, 0) -> x
2613  if (matchPattern(adaptor.getRhs(), m_Zero()))
2614  return getLhs();
2615  // Don't fold if shifting more or equal than the bit width.
2616  bool bounded = false;
2617  auto result = constFoldBinaryOp<IntegerAttr>(
2618  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2619  bounded = b.ult(b.getBitWidth());
2620  return a.shl(b);
2621  });
2622  return bounded ? result : Attribute();
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // ShRUIOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2630  // shrui(x, 0) -> x
2631  if (matchPattern(adaptor.getRhs(), m_Zero()))
2632  return getLhs();
2633  // Don't fold if shifting more or equal than the bit width.
2634  bool bounded = false;
2635  auto result = constFoldBinaryOp<IntegerAttr>(
2636  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2637  bounded = b.ult(b.getBitWidth());
2638  return a.lshr(b);
2639  });
2640  return bounded ? result : Attribute();
2641 }
2642 
2643 //===----------------------------------------------------------------------===//
2644 // ShRSIOp
2645 //===----------------------------------------------------------------------===//
2646 
2647 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2648  // shrsi(x, 0) -> x
2649  if (matchPattern(adaptor.getRhs(), m_Zero()))
2650  return getLhs();
2651  // Don't fold if shifting more or equal than the bit width.
2652  bool bounded = false;
2653  auto result = constFoldBinaryOp<IntegerAttr>(
2654  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2655  bounded = b.ult(b.getBitWidth());
2656  return a.ashr(b);
2657  });
2658  return bounded ? result : Attribute();
2659 }
2660 
2661 //===----------------------------------------------------------------------===//
2662 // Atomic Enum
2663 //===----------------------------------------------------------------------===//
2664 
2665 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2666 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2667  OpBuilder &builder, Location loc,
2668  bool useOnlyFiniteValue) {
2669  switch (kind) {
2670  case AtomicRMWKind::maximumf: {
2671  const llvm::fltSemantics &semantic =
2672  llvm::cast<FloatType>(resultType).getFloatSemantics();
2673  APFloat identity = useOnlyFiniteValue
2674  ? APFloat::getLargest(semantic, /*Negative=*/true)
2675  : APFloat::getInf(semantic, /*Negative=*/true);
2676  return builder.getFloatAttr(resultType, identity);
2677  }
2678  case AtomicRMWKind::maxnumf: {
2679  const llvm::fltSemantics &semantic =
2680  llvm::cast<FloatType>(resultType).getFloatSemantics();
2681  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2682  return builder.getFloatAttr(resultType, identity);
2683  }
2684  case AtomicRMWKind::addf:
2685  case AtomicRMWKind::addi:
2686  case AtomicRMWKind::maxu:
2687  case AtomicRMWKind::ori:
2688  case AtomicRMWKind::xori:
2689  return builder.getZeroAttr(resultType);
2690  case AtomicRMWKind::andi:
2691  return builder.getIntegerAttr(
2692  resultType,
2693  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2694  case AtomicRMWKind::maxs:
2695  return builder.getIntegerAttr(
2696  resultType, APInt::getSignedMinValue(
2697  llvm::cast<IntegerType>(resultType).getWidth()));
2698  case AtomicRMWKind::minimumf: {
2699  const llvm::fltSemantics &semantic =
2700  llvm::cast<FloatType>(resultType).getFloatSemantics();
2701  APFloat identity = useOnlyFiniteValue
2702  ? APFloat::getLargest(semantic, /*Negative=*/false)
2703  : APFloat::getInf(semantic, /*Negative=*/false);
2704 
2705  return builder.getFloatAttr(resultType, identity);
2706  }
2707  case AtomicRMWKind::minnumf: {
2708  const llvm::fltSemantics &semantic =
2709  llvm::cast<FloatType>(resultType).getFloatSemantics();
2710  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2711  return builder.getFloatAttr(resultType, identity);
2712  }
2713  case AtomicRMWKind::mins:
2714  return builder.getIntegerAttr(
2715  resultType, APInt::getSignedMaxValue(
2716  llvm::cast<IntegerType>(resultType).getWidth()));
2717  case AtomicRMWKind::minu:
2718  return builder.getIntegerAttr(
2719  resultType,
2720  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2721  case AtomicRMWKind::muli:
2722  return builder.getIntegerAttr(resultType, 1);
2723  case AtomicRMWKind::mulf:
2724  return builder.getFloatAttr(resultType, 1);
2725  // TODO: Add remaining reduction operations.
2726  default:
2727  (void)emitOptionalError(loc, "Reduction operation type not supported");
2728  break;
2729  }
2730  return nullptr;
2731 }
2732 
2733 /// Returns the identity numeric value of the given op.
2734 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2735  std::optional<AtomicRMWKind> maybeKind =
2737  // Floating-point operations.
2738  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2739  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2740  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2741  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2742  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2743  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2744  // Integer operations.
2745  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2746  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2747  .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2748  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2749  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2750  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2751  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2752  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2753  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2754  .Default([](Operation *op) { return std::nullopt; });
2755  if (!maybeKind) {
2756  return std::nullopt;
2757  }
2758 
2759  bool useOnlyFiniteValue = false;
2760  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2761  if (fmfOpInterface) {
2762  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2763  useOnlyFiniteValue =
2764  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2765  }
2766 
2767  // Builder only used as helper for attribute creation.
2768  OpBuilder b(op->getContext());
2769  Type resultType = op->getResult(0).getType();
2770 
2771  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2772  useOnlyFiniteValue);
2773 }
2774 
2775 /// Returns the identity value associated with an AtomicRMWKind op.
2776 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2777  OpBuilder &builder, Location loc,
2778  bool useOnlyFiniteValue) {
2779  auto attr =
2780  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2781  return arith::ConstantOp::create(builder, loc, attr);
2782 }
2783 
2784 /// Return the value obtained by applying the reduction operation kind
2785 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2786 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2787  Location loc, Value lhs, Value rhs) {
2788  switch (op) {
2789  case AtomicRMWKind::addf:
2790  return arith::AddFOp::create(builder, loc, lhs, rhs);
2791  case AtomicRMWKind::addi:
2792  return arith::AddIOp::create(builder, loc, lhs, rhs);
2793  case AtomicRMWKind::mulf:
2794  return arith::MulFOp::create(builder, loc, lhs, rhs);
2795  case AtomicRMWKind::muli:
2796  return arith::MulIOp::create(builder, loc, lhs, rhs);
2797  case AtomicRMWKind::maximumf:
2798  return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2799  case AtomicRMWKind::minimumf:
2800  return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2801  case AtomicRMWKind::maxnumf:
2802  return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2803  case AtomicRMWKind::minnumf:
2804  return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2805  case AtomicRMWKind::maxs:
2806  return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2807  case AtomicRMWKind::mins:
2808  return arith::MinSIOp::create(builder, loc, lhs, rhs);
2809  case AtomicRMWKind::maxu:
2810  return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2811  case AtomicRMWKind::minu:
2812  return arith::MinUIOp::create(builder, loc, lhs, rhs);
2813  case AtomicRMWKind::ori:
2814  return arith::OrIOp::create(builder, loc, lhs, rhs);
2815  case AtomicRMWKind::andi:
2816  return arith::AndIOp::create(builder, loc, lhs, rhs);
2817  case AtomicRMWKind::xori:
2818  return arith::XOrIOp::create(builder, loc, lhs, rhs);
2819  // TODO: Add remaining reduction operations.
2820  default:
2821  (void)emitOptionalError(loc, "Reduction operation type not supported");
2822  break;
2823  }
2824  return nullptr;
2825 }
2826 
2827 //===----------------------------------------------------------------------===//
2828 // TableGen'd op method definitions
2829 //===----------------------------------------------------------------------===//
2830 
2831 #define GET_OP_CLASSES
2832 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2833 
2834 //===----------------------------------------------------------------------===//
2835 // TableGen'd enum attribute definitions
2836 //===----------------------------------------------------------------------===//
2837 
2838 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:711
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1426
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1362
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1967
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition: ArithOps.cpp:672
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensors with mismatching encodings.
Definition: ArithOps.cpp:1376
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1348
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:771
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:753
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:141
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:149
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1441
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1819
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1717
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1398
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:129
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:952
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1369
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:170
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1985
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1384
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1342
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:433
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1412
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1252::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2176
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:2149
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:629
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:662
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:831
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition: ArithOps.cpp:330
static bool classof(Operation *op)
Definition: ArithOps.cpp:347
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:324
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:353
static bool classof(Operation *op)
Definition: ArithOps.cpp:374
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:251
static bool classof(Operation *op)
Definition: ArithOps.cpp:318
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2734
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1939
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2666
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2786
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2776
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition: ArithOps.cpp:380
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:533
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition: Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:462
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:427
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition: Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2443
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
Definition: PatternMatch.h:317
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes