1 //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cassert>
10 #include <cstdint>
11 #include <functional>
12 #include <utility>
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/UB/IR/UBOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/Value.h"
25 
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 using namespace mlir;
36 using namespace mlir::arith;
37 
38 //===----------------------------------------------------------------------===//
39 // Pattern helpers
40 //===----------------------------------------------------------------------===//
41 
42 static IntegerAttr
43 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44  Attribute rhs,
45  function_ref<APInt(const APInt &, const APInt &)> binFn) {
46  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48  APInt value = binFn(lhsVal, rhsVal);
49  return IntegerAttr::get(res.getType(), value);
50 }
51 
52 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53  Attribute lhs, Attribute rhs) {
54  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55 }
56 
57 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58  Attribute lhs, Attribute rhs) {
59  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60 }
61 
62 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63  Attribute lhs, Attribute rhs) {
64  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65 }
66 
67 // Merge overflow flags from 2 ops, selecting the most conservative combination.
68 static IntegerOverflowFlagsAttr
69 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70  IntegerOverflowFlagsAttr val2) {
71  return IntegerOverflowFlagsAttr::get(val1.getContext(),
72  val1.getValue() & val2.getValue());
73 }
74 
75 /// Invert an integer comparison predicate.
76 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77  switch (pred) {
78  case arith::CmpIPredicate::eq:
79  return arith::CmpIPredicate::ne;
80  case arith::CmpIPredicate::ne:
81  return arith::CmpIPredicate::eq;
82  case arith::CmpIPredicate::slt:
83  return arith::CmpIPredicate::sge;
84  case arith::CmpIPredicate::sle:
85  return arith::CmpIPredicate::sgt;
86  case arith::CmpIPredicate::sgt:
87  return arith::CmpIPredicate::sle;
88  case arith::CmpIPredicate::sge:
89  return arith::CmpIPredicate::slt;
90  case arith::CmpIPredicate::ult:
91  return arith::CmpIPredicate::uge;
92  case arith::CmpIPredicate::ule:
93  return arith::CmpIPredicate::ugt;
94  case arith::CmpIPredicate::ugt:
95  return arith::CmpIPredicate::ule;
96  case arith::CmpIPredicate::uge:
97  return arith::CmpIPredicate::ult;
98  }
99  llvm_unreachable("unknown cmpi predicate kind");
100 }
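
// Example (illustrative, not from the original source): invertPredicate(slt)
// returns sge, so a canonicalization that negates a comparison can rewrite
//   %c = arith.cmpi slt, %a, %b : i32
//   %n = arith.xori %c, %true : i1
// into
//   %n = arith.cmpi sge, %a, %b : i32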
101 
102 /// Equivalent to
103 /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104 ///
105 /// Not possible to implement as chain of calls as this would introduce a
106 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107 /// on the LLVM dialect and on translation to LLVM.
108 static llvm::RoundingMode
109 convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110  switch (roundingMode) {
111  case RoundingMode::downward:
112  return llvm::RoundingMode::TowardNegative;
113  case RoundingMode::to_nearest_away:
114  return llvm::RoundingMode::NearestTiesToAway;
115  case RoundingMode::to_nearest_even:
116  return llvm::RoundingMode::NearestTiesToEven;
117  case RoundingMode::toward_zero:
118  return llvm::RoundingMode::TowardZero;
119  case RoundingMode::upward:
120  return llvm::RoundingMode::TowardPositive;
121  }
122  llvm_unreachable("Unhandled rounding mode");
123 }
124 
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126  return arith::CmpIPredicateAttr::get(pred.getContext(),
127  invertPredicate(pred.getValue()));
128 }
129 
130 static int64_t getScalarOrElementWidth(Type type) {
131  Type elemTy = getElementTypeOrSelf(type);
132  if (elemTy.isIntOrFloat())
133  return elemTy.getIntOrFloatBitWidth();
134 
135  return -1;
136 }
137 
138 static int64_t getScalarOrElementWidth(Value value) {
139  return getScalarOrElementWidth(value.getType());
140 }
141 
142 static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143  APInt value;
144  if (matchPattern(attr, m_ConstantInt(&value)))
145  return value;
146 
147  return failure();
148 }
149 
150 static Attribute getBoolAttribute(Type type, bool value) {
151  auto boolAttr = BoolAttr::get(type.getContext(), value);
152  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153  if (!shapedType)
154  return boolAttr;
155  return DenseElementsAttr::get(shapedType, boolAttr);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // TableGen'd canonicalization patterns
160 //===----------------------------------------------------------------------===//
161 
162 namespace {
163 #include "ArithCanonicalization.inc"
164 } // namespace
165 
166 //===----------------------------------------------------------------------===//
167 // Common helpers
168 //===----------------------------------------------------------------------===//
169 
170 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 static Type getI1SameShape(Type type) {
172  auto i1Type = IntegerType::get(type.getContext(), 1);
173  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174  return shapedType.cloneWith(std::nullopt, i1Type);
175  if (llvm::isa<UnrankedTensorType>(type))
176  return UnrankedTensorType::get(i1Type);
177  return i1Type;
178 }
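
// Example (illustrative, not from the original source):
// getI1SameShape(vector<4xf32>) returns vector<4xi1>,
// getI1SameShape(tensor<?x8xi32>) returns tensor<?x8xi1>, and for a scalar
// such as i64 it simply returns i1.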
179 
180 //===----------------------------------------------------------------------===//
181 // ConstantOp
182 //===----------------------------------------------------------------------===//
183 
184 void arith::ConstantOp::getAsmResultNames(
185  function_ref<void(Value, StringRef)> setNameFn) {
186  auto type = getType();
187  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
188  auto intType = llvm::dyn_cast<IntegerType>(type);
189 
190  // Sugar i1 constants with 'true' and 'false'.
191  if (intType && intType.getWidth() == 1)
192  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 
194  // Otherwise, build a complex name with the value and type.
195  SmallString<32> specialNameBuffer;
196  llvm::raw_svector_ostream specialName(specialNameBuffer);
197  specialName << 'c' << intCst.getValue();
198  if (intType)
199  specialName << '_' << type;
200  setNameFn(getResult(), specialName.str());
201  } else {
202  setNameFn(getResult(), "cst");
203  }
204 }
205 
206 /// TODO: disallow arith.constant to return anything other than signless integer
207 /// or float like.
208 LogicalResult arith::ConstantOp::verify() {
209  auto type = getType();
210  // Integer values must be signless.
211  if (llvm::isa<IntegerType>(type) &&
212  !llvm::cast<IntegerType>(type).isSignless())
213  return emitOpError("integer return type must be signless");
214  // Any float or elements attribute is acceptable.
215  if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216  return emitOpError(
217  "value must be an integer, float, or elements attribute");
218  }
219 
220  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
221  // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
222  // However, this would most likely require updating the lowerings to LLVM.
223  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224  return emitOpError(
225  "intializing scalable vectors with elements attribute is not supported"
226  " unless it's a vector splat");
227  return success();
228 }
229 
230 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
231  // The value's type must be the same as the provided type.
232  auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233  if (!typedAttr || typedAttr.getType() != type)
234  return false;
235  // Integer values must be signless.
236  if (llvm::isa<IntegerType>(type) &&
237  !llvm::cast<IntegerType>(type).isSignless())
238  return false;
239  // Integer, float, and element attributes are buildable.
240  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
241 }
242 
243 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
244  Type type, Location loc) {
245  if (isBuildableWith(value, type))
246  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
247  return nullptr;
248 }
249 
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
251 
252 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
253  int64_t value, unsigned width) {
254  auto type = builder.getIntegerType(width);
255  arith::ConstantOp::build(builder, result, type,
256  builder.getIntegerAttr(type, value));
257 }
258 
259 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
260  Type type, int64_t value) {
261  arith::ConstantOp::build(builder, result, type,
262  builder.getIntegerAttr(type, value));
263 }
264 
265 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
266  Type type, const APInt &value) {
267  arith::ConstantOp::build(builder, result, type,
268  builder.getIntegerAttr(type, value));
269 }
270 
271 bool arith::ConstantIntOp::classof(Operation *op) {
272  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
273  return constOp.getType().isSignlessInteger();
274  return false;
275 }
276 
277 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
278  FloatType type, const APFloat &value) {
279  arith::ConstantOp::build(builder, result, type,
280  builder.getFloatAttr(type, value));
281 }
282 
283 bool arith::ConstantFloatOp::classof(Operation *op) {
284  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
285  return llvm::isa<FloatType>(constOp.getType());
286  return false;
287 }
288 
289 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
290  int64_t value) {
291  arith::ConstantOp::build(builder, result, builder.getIndexType(),
292  builder.getIndexAttr(value));
293 }
294 
295 bool arith::ConstantIndexOp::classof(Operation *op) {
296  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
297  return constOp.getType().isIndex();
298  return false;
299 }
300 
301 Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
302  Type type) {
303  // TODO: Incorporate this check to `FloatAttr::get*`.
304  assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
305  "type doesn't have a zero representation");
306  TypedAttr zeroAttr = builder.getZeroAttr(type);
307  assert(zeroAttr && "unsupported type for zero attribute");
308  return builder.create<arith::ConstantOp>(loc, zeroAttr);
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // AddIOp
313 //===----------------------------------------------------------------------===//
314 
315 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
316  // addi(x, 0) -> x
317  if (matchPattern(adaptor.getRhs(), m_Zero()))
318  return getLhs();
319 
320  // addi(subi(a, b), b) -> a
321  if (auto sub = getLhs().getDefiningOp<SubIOp>())
322  if (getRhs() == sub.getRhs())
323  return sub.getLhs();
324 
325  // addi(b, subi(a, b)) -> a
326  if (auto sub = getRhs().getDefiningOp<SubIOp>())
327  if (getLhs() == sub.getRhs())
328  return sub.getLhs();
329 
330  return constFoldBinaryOp<IntegerAttr>(
331  adaptor.getOperands(),
332  [](APInt a, const APInt &b) { return std::move(a) + b; });
333 }
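
// Example (illustrative, not from the original source): with the folds above,
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32
// folds so that %1 is replaced by %a, and arith.addi %x, %c0_i32 folds to %x.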
334 
335 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
336  MLIRContext *context) {
337  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
338  AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // AddUIExtendedOp
343 //===----------------------------------------------------------------------===//
344 
345 std::optional<SmallVector<int64_t, 4>>
346 arith::AddUIExtendedOp::getShapeForUnroll() {
347  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
348  return llvm::to_vector<4>(vt.getShape());
349  return std::nullopt;
350 }
351 
352 // Returns the overflow bit, assuming that `sum` is the result of unsigned
353 // addition of `operand` and another number.
354 static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
355  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
356 }
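
// Worked example (illustrative, not from the original source): for i8
// operands 250 and 10 the wrapped sum is 4; since 4 u< 250 the helper returns
// an all-ones 1-bit value, i.e. the carry/overflow bit is set.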
357 
358 LogicalResult
359 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
360  SmallVectorImpl<OpFoldResult> &results) {
361  Type overflowTy = getOverflow().getType();
362  // addui_extended(x, 0) -> x, false
363  if (matchPattern(getRhs(), m_Zero())) {
364  Builder builder(getContext());
365  auto falseValue = builder.getZeroAttr(overflowTy);
366 
367  results.push_back(getLhs());
368  results.push_back(falseValue);
369  return success();
370  }
371 
372  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
373  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
374  // operands. If that succeeds, calculate the overflow bit based on the sum
375  // and the first (constant) operand, `lhs`.
376  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
377  adaptor.getOperands(),
378  [](APInt a, const APInt &b) { return std::move(a) + b; })) {
379  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
380  ArrayRef({sumAttr, adaptor.getLhs()}),
381  getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
382  calculateUnsignedOverflow);
383  if (!overflowAttr)
384  return failure();
385 
386  results.push_back(sumAttr);
387  results.push_back(overflowAttr);
388  return success();
389  }
390 
391  return failure();
392 }
393 
394 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
395  RewritePatternSet &patterns, MLIRContext *context) {
396  patterns.add<AddUIExtendedToAddI>(context);
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // SubIOp
401 //===----------------------------------------------------------------------===//
402 
403 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
404  // subi(x,x) -> 0
405  if (getOperand(0) == getOperand(1)) {
406  auto shapedType = dyn_cast<ShapedType>(getType());
407  // We can't generate a constant with a dynamic shaped tensor.
408  if (!shapedType || shapedType.hasStaticShape())
409  return Builder(getContext()).getZeroAttr(getType());
410  }
411  // subi(x,0) -> x
412  if (matchPattern(adaptor.getRhs(), m_Zero()))
413  return getLhs();
414 
415  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
416  // subi(addi(a, b), b) -> a
417  if (getRhs() == add.getRhs())
418  return add.getLhs();
419  // subi(addi(a, b), a) -> b
420  if (getRhs() == add.getLhs())
421  return add.getRhs();
422  }
423 
424  return constFoldBinaryOp<IntegerAttr>(
425  adaptor.getOperands(),
426  [](APInt a, const APInt &b) { return std::move(a) - b; });
427 }
428 
429 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
430  MLIRContext *context) {
431  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
432  SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
433  SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // MulIOp
438 //===----------------------------------------------------------------------===//
439 
440 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
441  // muli(x, 0) -> 0
442  if (matchPattern(adaptor.getRhs(), m_Zero()))
443  return getRhs();
444  // muli(x, 1) -> x
445  if (matchPattern(adaptor.getRhs(), m_One()))
446  return getLhs();
447  // TODO: Handle the overflow case.
448 
449  // default folder
450  return constFoldBinaryOp<IntegerAttr>(
451  adaptor.getOperands(),
452  [](const APInt &a, const APInt &b) { return a * b; });
453 }
454 
455 void arith::MulIOp::getAsmResultNames(
456  function_ref<void(Value, StringRef)> setNameFn) {
457  if (!isa<IndexType>(getType()))
458  return;
459 
460  // Match vector.vscale by name to avoid depending on the vector dialect
461  // (which would create a circular dependency).
462  auto isVscale = [](Operation *op) {
463  return op && op->getName().getStringRef() == "vector.vscale";
464  };
465 
466  IntegerAttr baseValue;
467  auto isVscaleExpr = [&](Value a, Value b) {
468  return matchPattern(a, m_Constant(&baseValue)) &&
469  isVscale(b.getDefiningOp());
470  };
471 
472  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
473  return;
474 
475  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
476  SmallString<32> specialNameBuffer;
477  llvm::raw_svector_ostream specialName(specialNameBuffer);
478  specialName << 'c' << baseValue.getInt() << "_vscale";
479  setNameFn(getResult(), specialName.str());
480 }
481 
482 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
483  MLIRContext *context) {
484  patterns.add<MulIMulIConstant>(context);
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // MulSIExtendedOp
489 //===----------------------------------------------------------------------===//
490 
491 std::optional<SmallVector<int64_t, 4>>
492 arith::MulSIExtendedOp::getShapeForUnroll() {
493  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
494  return llvm::to_vector<4>(vt.getShape());
495  return std::nullopt;
496 }
497 
498 LogicalResult
499 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
500  SmallVectorImpl<OpFoldResult> &results) {
501  // mulsi_extended(x, 0) -> 0, 0
502  if (matchPattern(adaptor.getRhs(), m_Zero())) {
503  Attribute zero = adaptor.getRhs();
504  results.push_back(zero);
505  results.push_back(zero);
506  return success();
507  }
508 
509  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
510  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
511  adaptor.getOperands(),
512  [](const APInt &a, const APInt &b) { return a * b; })) {
513  // Invoke the constant fold helper again to calculate the 'high' result.
514  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
515  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
516  return llvm::APIntOps::mulhs(a, b);
517  });
518  assert(highAttr && "Unexpected constant-folding failure");
519 
520  results.push_back(lowAttr);
521  results.push_back(highAttr);
522  return success();
523  }
524 
525  return failure();
526 }
527 
528 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
529  RewritePatternSet &patterns, MLIRContext *context) {
530  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
531 }
532 
533 //===----------------------------------------------------------------------===//
534 // MulUIExtendedOp
535 //===----------------------------------------------------------------------===//
536 
537 std::optional<SmallVector<int64_t, 4>>
538 arith::MulUIExtendedOp::getShapeForUnroll() {
539  if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
540  return llvm::to_vector<4>(vt.getShape());
541  return std::nullopt;
542 }
543 
544 LogicalResult
545 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
546  SmallVectorImpl<OpFoldResult> &results) {
547  // mului_extended(x, 0) -> 0, 0
548  if (matchPattern(adaptor.getRhs(), m_Zero())) {
549  Attribute zero = adaptor.getRhs();
550  results.push_back(zero);
551  results.push_back(zero);
552  return success();
553  }
554 
555  // mului_extended(x, 1) -> x, 0
556  if (matchPattern(adaptor.getRhs(), m_One())) {
557  Builder builder(getContext());
558  Attribute zero = builder.getZeroAttr(getLhs().getType());
559  results.push_back(getLhs());
560  results.push_back(zero);
561  return success();
562  }
563 
564  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
565  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
566  adaptor.getOperands(),
567  [](const APInt &a, const APInt &b) { return a * b; })) {
568  // Invoke the constant fold helper again to calculate the 'high' result.
569  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
570  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
571  return llvm::APIntOps::mulhu(a, b);
572  });
573  assert(highAttr && "Unexpected constant-folding failure");
574 
575  results.push_back(lowAttr);
576  results.push_back(highAttr);
577  return success();
578  }
579 
580  return failure();
581 }
582 
583 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
584  RewritePatternSet &patterns, MLIRContext *context) {
585  patterns.add<MulUIExtendedToMulI>(context);
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // DivUIOp
590 //===----------------------------------------------------------------------===//
591 
592 /// Fold `(a * b) / b -> a`
593 static Value foldDivMul(Value lhs, Value rhs,
594  arith::IntegerOverflowFlags ovfFlags) {
595  auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
596  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
597  return {};
598 
599  if (mul.getLhs() == rhs)
600  return mul.getRhs();
601 
602  if (mul.getRhs() == rhs)
603  return mul.getLhs();
604 
605  return {};
606 }
607 
608 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
609  // divui (x, 1) -> x.
610  if (matchPattern(adaptor.getRhs(), m_One()))
611  return getLhs();
612 
613  // (a * b) / b -> a
614  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
615  return val;
616 
617  // Don't fold if it would require a division by zero.
618  bool div0 = false;
619  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
620  [&](APInt a, const APInt &b) {
621  if (div0 || !b) {
622  div0 = true;
623  return a;
624  }
625  return a.udiv(b);
626  });
627 
628  return div0 ? Attribute() : result;
629 }
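
// Example (illustrative, not from the original source): because the
// (a * b) / b fold requires the nuw flag,
//   %p = arith.muli %a, %b overflow<nuw> : i32
//   %q = arith.divui %p, %b : i32
// folds %q to %a, while the same pattern without nuw is left untouched.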
630 
631 /// Returns whether an unsigned division by `divisor` is speculatable.
632 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
633  // X / 0 => UB
634  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
635  return Speculation::Speculatable;
636 
637  return Speculation::NotSpeculatable;
638 }
639 
640 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
641  return getDivUISpeculatability(getRhs());
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // DivSIOp
646 //===----------------------------------------------------------------------===//
647 
648 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
649  // divsi (x, 1) -> x.
650  if (matchPattern(adaptor.getRhs(), m_One()))
651  return getLhs();
652 
653  // (a * b) / b -> a
654  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
655  return val;
656 
657  // Don't fold if it would overflow or if it requires a division by zero.
658  bool overflowOrDiv0 = false;
659  auto result = constFoldBinaryOp<IntegerAttr>(
660  adaptor.getOperands(), [&](APInt a, const APInt &b) {
661  if (overflowOrDiv0 || !b) {
662  overflowOrDiv0 = true;
663  return a;
664  }
665  return a.sdiv_ov(b, overflowOrDiv0);
666  });
667 
668  return overflowOrDiv0 ? Attribute() : result;
669 }
670 
671 /// Returns whether a signed division by `divisor` is speculatable. This
672 /// function conservatively assumes that all signed divisions by -1 are not
673 /// speculatable.
674 static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
675  // X / 0 => UB
676  // INT_MIN / -1 => UB
677  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
678  matchPattern(divisor, m_IntRangeWithoutNegOneS()))
679  return Speculation::Speculatable;
680 
681  return Speculation::NotSpeculatable;
682 }
683 
684 Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
685  return getDivSISpeculatability(getRhs());
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // Ceil and floor division folding helpers
690 //===----------------------------------------------------------------------===//
691 
692 static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
693  bool &overflow) {
694  // Returns (a-1)/b + 1
695  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
696  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
697  return val.sadd_ov(one, overflow);
698 }
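
// Worked example (illustrative, not from the original source): for a = 7,
// b = 2 the helper computes (7 - 1) / 2 + 1 = 4, which is ceil(7 / 2); for
// a = 6, b = 2 it computes (6 - 1) / 2 + 1 = 3, i.e. the exact quotient.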
699 
700 //===----------------------------------------------------------------------===//
701 // CeilDivUIOp
702 //===----------------------------------------------------------------------===//
703 
704 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
705  // ceildivui (x, 1) -> x.
706  if (matchPattern(adaptor.getRhs(), m_One()))
707  return getLhs();
708 
709  bool overflowOrDiv0 = false;
710  auto result = constFoldBinaryOp<IntegerAttr>(
711  adaptor.getOperands(), [&](APInt a, const APInt &b) {
712  if (overflowOrDiv0 || !b) {
713  overflowOrDiv0 = true;
714  return a;
715  }
716  APInt quotient = a.udiv(b);
717  if (!a.urem(b))
718  return quotient;
719  APInt one(a.getBitWidth(), 1, true);
720  return quotient.uadd_ov(one, overflowOrDiv0);
721  });
722 
723  return overflowOrDiv0 ? Attribute() : result;
724 }
725 
726 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
727  return getDivUISpeculatability(getRhs());
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // CeilDivSIOp
732 //===----------------------------------------------------------------------===//
733 
734 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
735  // ceildivsi (x, 1) -> x.
736  if (matchPattern(adaptor.getRhs(), m_One()))
737  return getLhs();
738 
739  // Don't fold if it would overflow or if it requires a division by zero.
740  // TODO: This hook won't fold operations where a = MININT, because
741  // negating MININT overflows. This can be improved.
742  bool overflowOrDiv0 = false;
743  auto result = constFoldBinaryOp<IntegerAttr>(
744  adaptor.getOperands(), [&](APInt a, const APInt &b) {
745  if (overflowOrDiv0 || !b) {
746  overflowOrDiv0 = true;
747  return a;
748  }
749  if (!a)
750  return a;
751  // After this point we know that neither a nor b is zero.
752  unsigned bits = a.getBitWidth();
753  APInt zero = APInt::getZero(bits);
754  bool aGtZero = a.sgt(zero);
755  bool bGtZero = b.sgt(zero);
756  if (aGtZero && bGtZero) {
757  // Both positive, return ceil(a, b).
758  return signedCeilNonnegInputs(a, b, overflowOrDiv0);
759  }
760 
761  // No folding happens if any of the intermediate arithmetic operations
762  // overflows.
763  bool overflowNegA = false;
764  bool overflowNegB = false;
765  bool overflowDiv = false;
766  bool overflowNegRes = false;
767  if (!aGtZero && !bGtZero) {
768  // Both negative, return ceil(-a, -b).
769  APInt posA = zero.ssub_ov(a, overflowNegA);
770  APInt posB = zero.ssub_ov(b, overflowNegB);
771  APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
772  overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
773  return res;
774  }
775  if (!aGtZero && bGtZero) {
776  // A is negative, b is positive, return - ( -a / b).
777  APInt posA = zero.ssub_ov(a, overflowNegA);
778  APInt div = posA.sdiv_ov(b, overflowDiv);
779  APInt res = zero.ssub_ov(div, overflowNegRes);
780  overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
781  return res;
782  }
783  // A is positive, b is negative, return - (a / -b).
784  APInt posB = zero.ssub_ov(b, overflowNegB);
785  APInt div = a.sdiv_ov(posB, overflowDiv);
786  APInt res = zero.ssub_ov(div, overflowNegRes);
787 
788  overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
789  return res;
790  });
791 
792  return overflowOrDiv0 ? Attribute() : result;
793 }
794 
795 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
796  return getDivSISpeculatability(getRhs());
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // FloorDivSIOp
801 //===----------------------------------------------------------------------===//
802 
803 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
804  // floordivsi (x, 1) -> x.
805  if (matchPattern(adaptor.getRhs(), m_One()))
806  return getLhs();
807 
808  // Don't fold if it would overflow or if it requires a division by zero.
809  bool overflowOrDiv = false;
810  auto result = constFoldBinaryOp<IntegerAttr>(
811  adaptor.getOperands(), [&](APInt a, const APInt &b) {
812  if (b.isZero()) {
813  overflowOrDiv = true;
814  return a;
815  }
816  return a.sfloordiv_ov(b, overflowOrDiv);
817  });
818 
819  return overflowOrDiv ? Attribute() : result;
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // RemUIOp
824 //===----------------------------------------------------------------------===//
825 
826 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
827  // remui (x, 1) -> 0.
828  if (matchPattern(adaptor.getRhs(), m_One()))
829  return Builder(getContext()).getZeroAttr(getType());
830 
831  // Don't fold if it would require a division by zero.
832  bool div0 = false;
833  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
834  [&](APInt a, const APInt &b) {
835  if (div0 || b.isZero()) {
836  div0 = true;
837  return a;
838  }
839  return a.urem(b);
840  });
841 
842  return div0 ? Attribute() : result;
843 }
844 
845 //===----------------------------------------------------------------------===//
846 // RemSIOp
847 //===----------------------------------------------------------------------===//
848 
849 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
850  // remsi (x, 1) -> 0.
851  if (matchPattern(adaptor.getRhs(), m_One()))
852  return Builder(getContext()).getZeroAttr(getType());
853 
854  // Don't fold if it would require a division by zero.
855  bool div0 = false;
856  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
857  [&](APInt a, const APInt &b) {
858  if (div0 || b.isZero()) {
859  div0 = true;
860  return a;
861  }
862  return a.srem(b);
863  });
864 
865  return div0 ? Attribute() : result;
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // AndIOp
870 //===----------------------------------------------------------------------===//
871 
872 /// Fold `and(a, and(a, b))` to `and(a, b)`
873 static Value foldAndIofAndI(arith::AndIOp op) {
874  for (bool reversePrev : {false, true}) {
875  auto prev = (reversePrev ? op.getRhs() : op.getLhs())
876  .getDefiningOp<arith::AndIOp>();
877  if (!prev)
878  continue;
879 
880  Value other = (reversePrev ? op.getLhs() : op.getRhs());
881  if (other != prev.getLhs() && other != prev.getRhs())
882  continue;
883 
884  return prev.getResult();
885  }
886  return {};
887 }
888 
889 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
890  /// and(x, 0) -> 0
891  if (matchPattern(adaptor.getRhs(), m_Zero()))
892  return getRhs();
893  /// and(x, allOnes) -> x
894  APInt intValue;
895  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
896  intValue.isAllOnes())
897  return getLhs();
898  /// and(x, not(x)) -> 0
899  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
900  m_ConstantInt(&intValue))) &&
901  intValue.isAllOnes())
902  return Builder(getContext()).getZeroAttr(getType());
903  /// and(not(x), x) -> 0
904  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
905  m_ConstantInt(&intValue))) &&
906  intValue.isAllOnes())
907  return Builder(getContext()).getZeroAttr(getType());
908 
909  /// and(a, and(a, b)) -> and(a, b)
910  if (Value result = foldAndIofAndI(*this))
911  return result;
912 
913  return constFoldBinaryOp<IntegerAttr>(
914  adaptor.getOperands(),
915  [](APInt a, const APInt &b) { return std::move(a) & b; });
916 }
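
// Example (illustrative, not from the original source): the and(x, not(x))
// fold recognizes "not" written as xor with all-ones, so
//   %n = arith.xori %x, %c-1_i8 : i8
//   %r = arith.andi %x, %n : i8
// folds %r to the zero constant of type i8.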
917 
918 //===----------------------------------------------------------------------===//
919 // OrIOp
920 //===----------------------------------------------------------------------===//
921 
922 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
923  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
924  /// or(x, 0) -> x
925  if (rhsVal.isZero())
926  return getLhs();
927  /// or(x, <all ones>) -> <all ones>
928  if (rhsVal.isAllOnes())
929  return adaptor.getRhs();
930  }
931 
932  APInt intValue;
933  /// or(x, xor(x, 1)) -> 1
934  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
935  m_ConstantInt(&intValue))) &&
936  intValue.isAllOnes())
937  return getRhs().getDefiningOp<XOrIOp>().getRhs();
938  /// or(xor(x, 1), x) -> 1
939  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
940  m_ConstantInt(&intValue))) &&
941  intValue.isAllOnes())
942  return getLhs().getDefiningOp<XOrIOp>().getRhs();
943 
944  return constFoldBinaryOp<IntegerAttr>(
945  adaptor.getOperands(),
946  [](APInt a, const APInt &b) { return std::move(a) | b; });
947 }
948 
949 //===----------------------------------------------------------------------===//
950 // XOrIOp
951 //===----------------------------------------------------------------------===//
952 
953 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
954  /// xor(x, 0) -> x
955  if (matchPattern(adaptor.getRhs(), m_Zero()))
956  return getLhs();
957  /// xor(x, x) -> 0
958  if (getLhs() == getRhs())
959  return Builder(getContext()).getZeroAttr(getType());
960  /// xor(xor(x, a), a) -> x
961  /// xor(xor(a, x), a) -> x
962  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
963  if (prev.getRhs() == getRhs())
964  return prev.getLhs();
965  if (prev.getLhs() == getRhs())
966  return prev.getRhs();
967  }
968  /// xor(a, xor(x, a)) -> x
969  /// xor(a, xor(a, x)) -> x
970  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
971  if (prev.getRhs() == getLhs())
972  return prev.getLhs();
973  if (prev.getLhs() == getLhs())
974  return prev.getRhs();
975  }
976 
977  return constFoldBinaryOp<IntegerAttr>(
978  adaptor.getOperands(),
979  [](APInt a, const APInt &b) { return std::move(a) ^ b; });
980 }
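
// Example (illustrative, not from the original source): the
// xor(xor(x, a), a) fold turns
//   %t = arith.xori %x, %a : i32
//   %r = arith.xori %t, %a : i32
// into a direct use of %x, since xor-ing twice with the same value cancels.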
981 
982 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
983  MLIRContext *context) {
984  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
985 }
986 
987 //===----------------------------------------------------------------------===//
988 // NegFOp
989 //===----------------------------------------------------------------------===//
990 
991 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
992  /// negf(negf(x)) -> x
993  if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
994  return op.getOperand();
995  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
996  [](const APFloat &a) { return -a; });
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // AddFOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1004  // addf(x, -0) -> x
1005  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1006  return getLhs();
1007 
1008  return constFoldBinaryOp<FloatAttr>(
1009  adaptor.getOperands(),
1010  [](const APFloat &a, const APFloat &b) { return a + b; });
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // SubFOp
1015 //===----------------------------------------------------------------------===//
1016 
1017 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1018  // subf(x, +0) -> x
1019  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1020  return getLhs();
1021 
1022  return constFoldBinaryOp<FloatAttr>(
1023  adaptor.getOperands(),
1024  [](const APFloat &a, const APFloat &b) { return a - b; });
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // MaximumFOp
1029 //===----------------------------------------------------------------------===//
1030 
1031 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1032  // maximumf(x,x) -> x
1033  if (getLhs() == getRhs())
1034  return getRhs();
1035 
1036  // maximumf(x, -inf) -> x
1037  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1038  return getLhs();
1039 
1040  return constFoldBinaryOp<FloatAttr>(
1041  adaptor.getOperands(),
1042  [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1043 }
1044 
1045 //===----------------------------------------------------------------------===//
1046 // MaxNumFOp
1047 //===----------------------------------------------------------------------===//
1048 
1049 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1050  // maxnumf(x,x) -> x
1051  if (getLhs() == getRhs())
1052  return getRhs();
1053 
1054  // maxnumf(x, NaN) -> x
1055  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1056  return getLhs();
1057 
1058  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // MaxSIOp
1063 //===----------------------------------------------------------------------===//
1064 
1065 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1066  // maxsi(x,x) -> x
1067  if (getLhs() == getRhs())
1068  return getRhs();
1069 
1070  if (APInt intValue;
1071  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1072  // maxsi(x,MAX_INT) -> MAX_INT
1073  if (intValue.isMaxSignedValue())
1074  return getRhs();
1075  // maxsi(x, MIN_INT) -> x
1076  if (intValue.isMinSignedValue())
1077  return getLhs();
1078  }
1079 
1080  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1081  [](const APInt &a, const APInt &b) {
1082  return llvm::APIntOps::smax(a, b);
1083  });
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // MaxUIOp
1088 //===----------------------------------------------------------------------===//
1089 
1090 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1091  // maxui(x,x) -> x
1092  if (getLhs() == getRhs())
1093  return getRhs();
1094 
1095  if (APInt intValue;
1096  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1097  // maxui(x,MAX_INT) -> MAX_INT
1098  if (intValue.isMaxValue())
1099  return getRhs();
1100  // maxui(x, MIN_INT) -> x
1101  if (intValue.isMinValue())
1102  return getLhs();
1103  }
1104 
1105  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1106  [](const APInt &a, const APInt &b) {
1107  return llvm::APIntOps::umax(a, b);
1108  });
1109 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // MinimumFOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1116  // minimumf(x,x) -> x
1117  if (getLhs() == getRhs())
1118  return getRhs();
1119 
1120  // minimumf(x, +inf) -> x
1121  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1122  return getLhs();
1123 
1124  return constFoldBinaryOp<FloatAttr>(
1125  adaptor.getOperands(),
1126  [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // MinNumFOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1134  // minnumf(x,x) -> x
1135  if (getLhs() == getRhs())
1136  return getRhs();
1137 
1138  // minnumf(x, NaN) -> x
1139  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1140  return getLhs();
1141 
1142  return constFoldBinaryOp<FloatAttr>(
1143  adaptor.getOperands(),
1144  [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // MinSIOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1152  // minsi(x,x) -> x
1153  if (getLhs() == getRhs())
1154  return getRhs();
1155 
1156  if (APInt intValue;
1157  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1158  // minsi(x,MIN_INT) -> MIN_INT
1159  if (intValue.isMinSignedValue())
1160  return getRhs();
1161  // minsi(x, MAX_INT) -> x
1162  if (intValue.isMaxSignedValue())
1163  return getLhs();
1164  }
1165 
1166  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1167  [](const APInt &a, const APInt &b) {
1168  return llvm::APIntOps::smin(a, b);
1169  });
1170 }
1171 
1172 //===----------------------------------------------------------------------===//
1173 // MinUIOp
1174 //===----------------------------------------------------------------------===//
1175 
1176 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1177  // minui(x,x) -> x
1178  if (getLhs() == getRhs())
1179  return getRhs();
1180 
1181  if (APInt intValue;
1182  matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1183  // minui(x,MIN_INT) -> MIN_INT
1184  if (intValue.isMinValue())
1185  return getRhs();
1186  // minui(x, MAX_INT) -> x
1187  if (intValue.isMaxValue())
1188  return getLhs();
1189  }
1190 
1191  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1192  [](const APInt &a, const APInt &b) {
1193  return llvm::APIntOps::umin(a, b);
1194  });
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // MulFOp
1199 //===----------------------------------------------------------------------===//
1200 
1201 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1202  // mulf(x, 1) -> x
1203  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1204  return getLhs();
1205 
1206  return constFoldBinaryOp<FloatAttr>(
1207  adaptor.getOperands(),
1208  [](const APFloat &a, const APFloat &b) { return a * b; });
1209 }
1210 
1211 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1212  MLIRContext *context) {
1213  patterns.add<MulFOfNegF>(context);
1214 }
1215 
1216 //===----------------------------------------------------------------------===//
1217 // DivFOp
1218 //===----------------------------------------------------------------------===//
1219 
1220 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1221  // divf(x, 1) -> x
1222  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1223  return getLhs();
1224 
1225  return constFoldBinaryOp<FloatAttr>(
1226  adaptor.getOperands(),
1227  [](const APFloat &a, const APFloat &b) { return a / b; });
1228 }
1229 
1230 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1231  MLIRContext *context) {
1232  patterns.add<DivFOfNegF>(context);
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // RemFOp
1237 //===----------------------------------------------------------------------===//
1238 
1239 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1240  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1241  [](const APFloat &a, const APFloat &b) {
1242  APFloat result(a);
1243  // APFloat::mod() offers the remainder
1244  // behavior we want, i.e. the result has
1245  // the sign of LHS operand.
1246  (void)result.mod(b);
1247  return result;
1248  });
1249 }
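
// Worked example (illustrative, not from the original source): APFloat::mod
// gives the remainder the sign of the dividend, so remf(-7.5, 2.0) folds to
// -1.5 and remf(7.5, -2.0) folds to 1.5.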
1250 
1251 //===----------------------------------------------------------------------===//
1252 // Utility functions for verifying cast ops
1253 //===----------------------------------------------------------------------===//
1254 
1255 template <typename... Types>
1256 using type_list = std::tuple<Types...> *;
1257 
1258 /// Returns a non-null type only if the provided type is one of the allowed
1259 /// types or one of the allowed shaped types of the allowed types. Returns the
1260 /// element type if a valid shaped type is provided.
1261 template <typename... ShapedTypes, typename... ElementTypes>
1262 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1263  type_list<ElementTypes...>) {
1264  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1265  return {};
1266 
1267  auto underlyingType = getElementTypeOrSelf(type);
1268  if (!llvm::isa<ElementTypes...>(underlyingType))
1269  return {};
1270 
1271  return underlyingType;
1272 }
1273 
1274 /// Get allowed underlying types for vectors and tensors.
1275 template <typename... ElementTypes>
1276 static Type getTypeIfLike(Type type) {
1277  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1278  type_list<ElementTypes...>());
1279 }
1280 
1281 /// Get allowed underlying types for vectors, tensors, and memrefs.
1282 template <typename... ElementTypes>
1283 static Type getTypeIfLikeOrMemRef(Type type) {
1284  return getUnderlyingType(type,
1285  type_list<VectorType, TensorType, MemRefType>(),
1286  type_list<ElementTypes...>());
1287 }
1288 
1289 /// Return false if both types are ranked tensor with mismatching encoding.
1290 static bool hasSameEncoding(Type typeA, Type typeB) {
1291  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1292  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1293  if (!rankedTensorA || !rankedTensorB)
1294  return true;
1295  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1296 }
1297 
1298 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1299  if (inputs.size() != 1 || outputs.size() != 1)
1300  return false;
1301  if (!hasSameEncoding(inputs.front(), outputs.front()))
1302  return false;
1303  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // Verifiers for integer and floating point extension/truncation ops
1308 //===----------------------------------------------------------------------===//
1309 
1310 // Extend ops can only extend to a wider type.
1311 template <typename ValType, typename Op>
1312 static LogicalResult verifyExtOp(Op op) {
1313  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1314  Type dstType = getElementTypeOrSelf(op.getType());
1315 
1316  if (llvm::cast<ValType>(srcType).getWidth() >=
1317  llvm::cast<ValType>(dstType).getWidth())
1318  return op.emitError("result type ")
1319  << dstType << " must be wider than operand type " << srcType;
1320 
1321  return success();
1322 }
1323 
1324 // Truncate ops can only truncate to a shorter type.
1325 template <typename ValType, typename Op>
1326 static LogicalResult verifyTruncateOp(Op op) {
1327  Type srcType = getElementTypeOrSelf(op.getIn().getType());
1328  Type dstType = getElementTypeOrSelf(op.getType());
1329 
1330  if (llvm::cast<ValType>(srcType).getWidth() <=
1331  llvm::cast<ValType>(dstType).getWidth())
1332  return op.emitError("result type ")
1333  << dstType << " must be shorter than operand type " << srcType;
1334 
1335  return success();
1336 }
1337 
1338 /// Validate a cast that changes the width of a type.
1339 template <template <typename> class WidthComparator, typename... ElementTypes>
1340 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1341  if (!areValidCastInputsAndOutputs(inputs, outputs))
1342  return false;
1343 
1344  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1345  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1346  if (!srcType || !dstType)
1347  return false;
1348 
1349  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1350  srcType.getIntOrFloatBitWidth());
1351 }
1352 
1353 /// Attempts to convert `sourceValue` to an APFloat value with
1354 /// `targetSemantics` and `roundingMode`, without any information loss.
1355 static FailureOr<APFloat> convertFloatValue(
1356  APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1357  llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1358  bool losesInfo = false;
1359  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1360  if (losesInfo || status != APFloat::opOK)
1361  return failure();
1362 
1363  return sourceValue;
1364 }
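
// Example (illustrative, not from the original source): converting the f64
// value 0.5 to f16 is exact, so the helper returns the converted value;
// converting 0.1 to f16 loses information (0.1 has no exact binary
// representation at that precision), so it returns failure and the
// surrounding cast folds are skipped.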
1365 
1366 //===----------------------------------------------------------------------===//
1367 // ExtUIOp
1368 //===----------------------------------------------------------------------===//
1369 
1370 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1371  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1372  getInMutable().assign(lhs.getIn());
1373  return getResult();
1374  }
1375 
1376  Type resType = getElementTypeOrSelf(getType());
1377  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1378  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1379  adaptor.getOperands(), getType(),
1380  [bitWidth](const APInt &a, bool &castStatus) {
1381  return a.zext(bitWidth);
1382  });
1383 }
1384 
1385 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1386  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1387 }
1388 
1389 LogicalResult arith::ExtUIOp::verify() {
1390  return verifyExtOp<IntegerType>(*this);
1391 }
1392 
1393 //===----------------------------------------------------------------------===//
1394 // ExtSIOp
1395 //===----------------------------------------------------------------------===//
1396 
1397 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1398  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1399  getInMutable().assign(lhs.getIn());
1400  return getResult();
1401  }
1402 
1403  Type resType = getElementTypeOrSelf(getType());
1404  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1405  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1406  adaptor.getOperands(), getType(),
1407  [bitWidth](const APInt &a, bool &castStatus) {
1408  return a.sext(bitWidth);
1409  });
1410 }
1411 
1412 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1413  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1414 }
1415 
1416 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1417  MLIRContext *context) {
1418  patterns.add<ExtSIOfExtUI>(context);
1419 }
1420 
1421 LogicalResult arith::ExtSIOp::verify() {
1422  return verifyExtOp<IntegerType>(*this);
1423 }
1424 
1425 //===----------------------------------------------------------------------===//
1426 // ExtFOp
1427 //===----------------------------------------------------------------------===//
1428 
1429 /// Fold extension of float constants when there is no information loss due to
1430 /// difference in fp semantics.
1431 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1432  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1433  if (truncFOp.getOperand().getType() == getType()) {
1434  arith::FastMathFlags truncFMF =
1435  truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1436  bool isTruncContract =
1437  bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1438  arith::FastMathFlags extFMF =
1439  getFastmath().value_or(arith::FastMathFlags::none);
1440  bool isExtContract =
1441  bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1442  if (isTruncContract && isExtContract) {
1443  return truncFOp.getOperand();
1444  }
1445  }
1446  }
1447 
1448  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1449  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1450  return constFoldCastOp<FloatAttr, FloatAttr>(
1451  adaptor.getOperands(), getType(),
1452  [&targetSemantics](const APFloat &a, bool &castStatus) {
1453  FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1454  if (failed(result)) {
1455  castStatus = false;
1456  return a;
1457  }
1458  return *result;
1459  });
1460 }
1461 
1462 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1463  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1464 }
1465 
1466 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1467 
1468 //===----------------------------------------------------------------------===//
1469 // ScalingExtFOp
1470 //===----------------------------------------------------------------------===//
1471 
1472 bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1473  TypeRange outputs) {
1474  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1475 }
1476 
1477 LogicalResult arith::ScalingExtFOp::verify() {
1478  return verifyExtOp<FloatType>(*this);
1479 }
1480 
1481 //===----------------------------------------------------------------------===//
1482 // TruncIOp
1483 //===----------------------------------------------------------------------===//
1484 
1485 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1486  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1487  matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1488  Value src = getOperand().getDefiningOp()->getOperand(0);
1489  Type srcType = getElementTypeOrSelf(src.getType());
1490  Type dstType = getElementTypeOrSelf(getType());
1491  // trunci(zexti(a)) -> trunci(a)
1492  // trunci(sexti(a)) -> trunci(a)
1493  if (llvm::cast<IntegerType>(srcType).getWidth() >
1494  llvm::cast<IntegerType>(dstType).getWidth()) {
1495  setOperand(src);
1496  return getResult();
1497  }
1498 
1499  // trunci(zexti(a)) -> a
1500  // trunci(sexti(a)) -> a
1501  if (srcType == dstType)
1502  return src;
1503  }
1504 
1505  // trunci(trunci(a)) -> trunci(a)
1506  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1507  setOperand(getOperand().getDefiningOp()->getOperand(0));
1508  return getResult();
1509  }
1510 
1511  Type resType = getElementTypeOrSelf(getType());
1512  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1513  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1514  adaptor.getOperands(), getType(),
1515  [bitWidth](const APInt &a, bool &castStatus) {
1516  return a.trunc(bitWidth);
1517  });
1518 }
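
// Example (illustrative, not from the original source): with the folds above,
//   %e = arith.extsi %a : i16 to i64
//   %t = arith.trunci %e : i64 to i8
// re-targets the trunci to truncate %a directly (16 > 8), and if the result
// type were i16 the trunci would fold away to %a.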
1519 
1520 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1521  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1522 }
1523 
1524 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1525  MLIRContext *context) {
1526  patterns
1527  .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1528  context);
1529 }
1530 
1531 LogicalResult arith::TruncIOp::verify() {
1532  return verifyTruncateOp<IntegerType>(*this);
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // TruncFOp
1537 //===----------------------------------------------------------------------===//
1538 
1539 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
1540 /// can be represented without precision loss.
1541 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1542  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1543  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1544  Value src = extOp.getIn();
1545  auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1546  auto intermediateType =
1547  cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1548  // Check if the srcType is representable in the intermediateType.
1549  if (llvm::APFloatBase::isRepresentableBy(
1550  srcType.getFloatSemantics(),
1551  intermediateType.getFloatSemantics())) {
1552  // truncf(extf(a)) -> truncf(a)
1553  if (srcType.getWidth() > resElemType.getWidth()) {
1554  setOperand(src);
1555  return getResult();
1556  }
1557 
1558  // truncf(extf(a)) -> a
1559  if (srcType == resElemType)
1560  return src;
1561  }
1562  }
1563 
1564  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1565  return constFoldCastOp<FloatAttr, FloatAttr>(
1566  adaptor.getOperands(), getType(),
1567  [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1568  RoundingMode roundingMode =
1569  getRoundingmode().value_or(RoundingMode::to_nearest_even);
1570  llvm::RoundingMode llvmRoundingMode =
1571  convertArithRoundingModeToLLVMIR(roundingMode);
1572  FailureOr<APFloat> result =
1573  convertFloatValue(a, targetSemantics, llvmRoundingMode);
1574  if (failed(result)) {
1575  castStatus = false;
1576  return a;
1577  }
1578  return *result;
1579  });
1580 }
1581 
1582 void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1583  MLIRContext *context) {
1584  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1585 }
1586 
1587 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1588  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1589 }
1590 
1591 LogicalResult arith::TruncFOp::verify() {
1592  return verifyTruncateOp<FloatType>(*this);
1593 }
1594 
1595 //===----------------------------------------------------------------------===//
1596 // ScalingTruncFOp
1597 //===----------------------------------------------------------------------===//
1598 
1599 bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1600  TypeRange outputs) {
1601  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1602 }
1603 
1604 LogicalResult arith::ScalingTruncFOp::verify() {
1605  return verifyTruncateOp<FloatType>(*this);
1606 }
1607 
1608 //===----------------------------------------------------------------------===//
1609 // AndIOp
1610 //===----------------------------------------------------------------------===//
1611 
1612 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1613  MLIRContext *context) {
1614  patterns.add<AndOfExtUI, AndOfExtSI>(context);
1615 }
1616 
1617 //===----------------------------------------------------------------------===//
1618 // OrIOp
1619 //===----------------------------------------------------------------------===//
1620 
1621 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1622  MLIRContext *context) {
1623  patterns.add<OrOfExtUI, OrOfExtSI>(context);
1624 }
1625 
1626 //===----------------------------------------------------------------------===//
1627 // Verifiers for casts between integers and floats.
1628 //===----------------------------------------------------------------------===//
1629 
1630 template <typename From, typename To>
1631 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1632  if (!areValidCastInputsAndOutputs(inputs, outputs))
1633  return false;
1634 
1635  auto srcType = getTypeIfLike<From>(inputs.front());
1636  auto dstType = getTypeIfLike<To>(outputs.back());
1637 
1638  return srcType && dstType;
1639 }
1640 
1641 //===----------------------------------------------------------------------===//
1642 // UIToFPOp
1643 //===----------------------------------------------------------------------===//
1644 
1645 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1646  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1647 }
1648 
1649 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1650  Type resEleType = getElementTypeOrSelf(getType());
1651  return constFoldCastOp<IntegerAttr, FloatAttr>(
1652  adaptor.getOperands(), getType(),
1653  [&resEleType](const APInt &a, bool &castStatus) {
1654  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1655  APFloat apf(floatTy.getFloatSemantics(),
1656  APInt::getZero(floatTy.getWidth()));
1657  apf.convertFromAPInt(a, /*IsSigned=*/false,
1658  APFloat::rmNearestTiesToEven);
1659  return apf;
1660  });
1661 }
1662 
1663 //===----------------------------------------------------------------------===//
1664 // SIToFPOp
1665 //===----------------------------------------------------------------------===//
1666 
1667 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1668  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1669 }
1670 
1671 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1672  Type resEleType = getElementTypeOrSelf(getType());
1673  return constFoldCastOp<IntegerAttr, FloatAttr>(
1674  adaptor.getOperands(), getType(),
1675  [&resEleType](const APInt &a, bool &castStatus) {
1676  FloatType floatTy = llvm::cast<FloatType>(resEleType);
1677  APFloat apf(floatTy.getFloatSemantics(),
1678  APInt::getZero(floatTy.getWidth()));
1679  apf.convertFromAPInt(a, /*IsSigned=*/true,
1680  APFloat::rmNearestTiesToEven);
1681  return apf;
1682  });
1683 }
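// Illustrative example: for a constant with bit pattern 0xFF stored in an i8,
//   arith.uitofp %c : i8 to f32   folds to 255.0 : f32, while
//   arith.sitofp %c : i8 to f32   folds to -1.0 : f32,
// since only the signedness of the conversion differs.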
1684 
1685 //===----------------------------------------------------------------------===//
1686 // FPToUIOp
1687 //===----------------------------------------------------------------------===//
1688 
1689 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1690  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1691 }
1692 
1693 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1694  Type resType = getElementTypeOrSelf(getType());
1695  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1696  return constFoldCastOp<FloatAttr, IntegerAttr>(
1697  adaptor.getOperands(), getType(),
1698  [&bitWidth](const APFloat &a, bool &castStatus) {
1699  bool ignored;
1700  APSInt api(bitWidth, /*isUnsigned=*/true);
1701  castStatus = APFloat::opInvalidOp !=
1702  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1703  return api;
1704  });
1705 }
1706 
1707 //===----------------------------------------------------------------------===//
1708 // FPToSIOp
1709 //===----------------------------------------------------------------------===//
1710 
1711 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1712  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1713 }
1714 
1715 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1716  Type resType = getElementTypeOrSelf(getType());
1717  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1718  return constFoldCastOp<FloatAttr, IntegerAttr>(
1719  adaptor.getOperands(), getType(),
1720  [&bitWidth](const APFloat &a, bool &castStatus) {
1721  bool ignored;
1722  APSInt api(bitWidth, /*isUnsigned=*/false);
1723  castStatus = APFloat::opInvalidOp !=
1724  a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1725  return api;
1726  });
1727 }
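// Illustrative example: a constant 42.9 : f32 folds through arith.fptosi (and
// arith.fptoui) to 42 : i32 because the conversion rounds toward zero. A value
// that does not fit the destination, e.g. 1.0e20 : f32 to i32, reports an
// invalid-op status, castStatus is cleared, and no fold is produced.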
1728 
1729 //===----------------------------------------------------------------------===//
1730 // IndexCastOp
1731 //===----------------------------------------------------------------------===//
1732 
1733 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1734  if (!areValidCastInputsAndOutputs(inputs, outputs))
1735  return false;
1736 
1737  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1738  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1739  if (!srcType || !dstType)
1740  return false;
1741 
1742  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1743  (srcType.isSignlessInteger() && dstType.isIndex());
1744 }
1745 
1746 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1747  TypeRange outputs) {
1748  return areIndexCastCompatible(inputs, outputs);
1749 }
1750 
1751 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1752  // index_cast(constant) -> constant
1753  unsigned resultBitwidth = 64; // Default for index integer attributes.
1754  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1755  resultBitwidth = intTy.getWidth();
1756 
1757  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1758  adaptor.getOperands(), getType(),
1759  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1760  return a.sextOrTrunc(resultBitwidth);
1761  });
1762 }
1763 
1764 void arith::IndexCastOp::getCanonicalizationPatterns(
1765  RewritePatternSet &patterns, MLIRContext *context) {
1766  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1767 }
1768 
1769 //===----------------------------------------------------------------------===//
1770 // IndexCastUIOp
1771 //===----------------------------------------------------------------------===//
1772 
1773 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1774  TypeRange outputs) {
1775  return areIndexCastCompatible(inputs, outputs);
1776 }
1777 
1778 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1779  // index_castui(constant) -> constant
1780  unsigned resultBitwidth = 64; // Default for index integer attributes.
1781  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1782  resultBitwidth = intTy.getWidth();
1783 
1784  return constFoldCastOp<IntegerAttr, IntegerAttr>(
1785  adaptor.getOperands(), getType(),
1786  [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1787  return a.zextOrTrunc(resultBitwidth);
1788  });
1789 }
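// Illustrative example contrasting the two casts on a constant -1 : i32:
//   arith.index_cast %c-1 : i32 to index     sign-extends and yields -1,
//   arith.index_castui %c-1 : i32 to index   zero-extends and yields 4294967295,
// matching the sextOrTrunc / zextOrTrunc calls in the folders above.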
1790 
1791 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1792  RewritePatternSet &patterns, MLIRContext *context) {
1793  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1794 }
1795 
1796 //===----------------------------------------------------------------------===//
1797 // BitcastOp
1798 //===----------------------------------------------------------------------===//
1799 
1800 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1801  if (!areValidCastInputsAndOutputs(inputs, outputs))
1802  return false;
1803 
1804  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1805  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1806  if (!srcType || !dstType)
1807  return false;
1808 
1809  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1810 }
1811 
1812 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1813  auto resType = getType();
1814  auto operand = adaptor.getIn();
1815  if (!operand)
1816  return {};
1817 
1818  /// Bitcast dense elements.
1819  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1820  return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1821  /// Other shaped types unhandled.
1822  if (llvm::isa<ShapedType>(resType))
1823  return {};
1824 
1825  /// Bitcast poison.
1826  if (llvm::isa<ub::PoisonAttr>(operand))
1827  return ub::PoisonAttr::get(getContext());
1828 
1829  /// Bitcast integer or float to integer or float.
1830  APInt bits = llvm::isa<FloatAttr>(operand)
1831  ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1832  : llvm::cast<IntegerAttr>(operand).getValue();
1833  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1834  "trying to fold on broken IR: operands have incompatible types");
1835 
1836  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1837  return FloatAttr::get(resType,
1838  APFloat(resFloatType.getFloatSemantics(), bits));
1839  return IntegerAttr::get(resType, bits);
1840 }
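// Illustrative example: `arith.bitcast` of the constant 1.0 : f32 to i32 folds
// to the constant 1065353216 (0x3F800000) : i32, i.e. the raw IEEE-754 bits,
// and the reverse direction reconstructs 1.0 : f32 from that integer.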
1841 
1842 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1843  MLIRContext *context) {
1844  patterns.add<BitcastOfBitcast>(context);
1845 }
1846 
1847 //===----------------------------------------------------------------------===//
1848 // CmpIOp
1849 //===----------------------------------------------------------------------===//
1850 
1851 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1852 /// comparison predicates.
1853 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1854  const APInt &lhs, const APInt &rhs) {
1855  switch (predicate) {
1856  case arith::CmpIPredicate::eq:
1857  return lhs.eq(rhs);
1858  case arith::CmpIPredicate::ne:
1859  return lhs.ne(rhs);
1860  case arith::CmpIPredicate::slt:
1861  return lhs.slt(rhs);
1862  case arith::CmpIPredicate::sle:
1863  return lhs.sle(rhs);
1864  case arith::CmpIPredicate::sgt:
1865  return lhs.sgt(rhs);
1866  case arith::CmpIPredicate::sge:
1867  return lhs.sge(rhs);
1868  case arith::CmpIPredicate::ult:
1869  return lhs.ult(rhs);
1870  case arith::CmpIPredicate::ule:
1871  return lhs.ule(rhs);
1872  case arith::CmpIPredicate::ugt:
1873  return lhs.ugt(rhs);
1874  case arith::CmpIPredicate::uge:
1875  return lhs.uge(rhs);
1876  }
1877  llvm_unreachable("unknown cmpi predicate kind");
1878 }
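// Worked example of the signed/unsigned split: with 8-bit operands lhs = -1
// (bit pattern 0xFF) and rhs = 1, `slt` is true (-1 < 1) while `ult` is false
// (255 < 1 does not hold); `eq` and `ne` are signedness-agnostic.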
1879 
1880 /// Returns true if the predicate is true for two equal operands.
1881 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1882  switch (predicate) {
1883  case arith::CmpIPredicate::eq:
1884  case arith::CmpIPredicate::sle:
1885  case arith::CmpIPredicate::sge:
1886  case arith::CmpIPredicate::ule:
1887  case arith::CmpIPredicate::uge:
1888  return true;
1889  case arith::CmpIPredicate::ne:
1890  case arith::CmpIPredicate::slt:
1891  case arith::CmpIPredicate::sgt:
1892  case arith::CmpIPredicate::ult:
1893  case arith::CmpIPredicate::ugt:
1894  return false;
1895  }
1896  llvm_unreachable("unknown cmpi predicate kind");
1897 }
1898 
1899 static std::optional<int64_t> getIntegerWidth(Type t) {
1900  if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1901  return intType.getWidth();
1902  }
1903  if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1904  return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1905  }
1906  return std::nullopt;
1907 }
1908 
1909 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1910  // cmpi(pred, x, x)
1911  if (getLhs() == getRhs()) {
1912  auto val = applyCmpPredicateToEqualOperands(getPredicate());
1913  return getBoolAttribute(getType(), val);
1914  }
1915 
1916  if (matchPattern(adaptor.getRhs(), m_Zero())) {
1917  if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1918  // extsi(%x : i1 -> iN) != 0 -> %x
1919  std::optional<int64_t> integerWidth =
1920  getIntegerWidth(extOp.getOperand().getType());
1921  if (integerWidth && integerWidth.value() == 1 &&
1922  getPredicate() == arith::CmpIPredicate::ne)
1923  return extOp.getOperand();
1924  }
1925  if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1926  // extui(%x : i1 -> iN) != 0 -> %x
1927  std::optional<int64_t> integerWidth =
1928  getIntegerWidth(extOp.getOperand().getType());
1929  if (integerWidth && integerWidth.value() == 1 &&
1930  getPredicate() == arith::CmpIPredicate::ne)
1931  return extOp.getOperand();
1932  }
1933 
1934  // arith.cmpi ne, %val, %zero : i1 -> %val
1935  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1936  getPredicate() == arith::CmpIPredicate::ne)
1937  return getLhs();
1938  }
1939 
1940  if (matchPattern(adaptor.getRhs(), m_One())) {
1941  // arith.cmpi eq, %val, %one : i1 -> %val
1942  if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
1943  getPredicate() == arith::CmpIPredicate::eq)
1944  return getLhs();
1945  }
1946 
1947  // Move constant to the right side.
1948  if (adaptor.getLhs() && !adaptor.getRhs()) {
1949  // Do not use invertPredicate, as it will change eq to ne and vice versa.
1950  using Pred = CmpIPredicate;
1951  const std::pair<Pred, Pred> invPreds[] = {
1952  {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1953  {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1954  {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1955  {Pred::ne, Pred::ne},
1956  };
1957  Pred origPred = getPredicate();
1958  for (auto pred : invPreds) {
1959  if (origPred == pred.first) {
1960  setPredicate(pred.second);
1961  Value lhs = getLhs();
1962  Value rhs = getRhs();
1963  getLhsMutable().assign(rhs);
1964  getRhsMutable().assign(lhs);
1965  return getResult();
1966  }
1967  }
1968  llvm_unreachable("unknown cmpi predicate kind");
1969  }
1970 
1971  // Constants are moved to the right side above, so if the lhs is constant
1972  // the rhs is guaranteed to be a constant as well.
1973  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1974  return constFoldBinaryOp<IntegerAttr>(
1975  adaptor.getOperands(), getI1SameShape(lhs.getType()),
1976  [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1977  return APInt(1,
1978  static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1979  });
1980  }
1981 
1982  return {};
1983 }
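// Illustrative examples (SSA names are placeholders):
//   arith.cmpi sge, %x, %x            folds to true,
//   arith.cmpi slt, %c5_i32, %x       is rewritten to arith.cmpi sgt, %x, %c5_i32
//                                     (the constant moves to the right side),
//   arith.cmpi ne, %b, %false : i1    folds to %b.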
1984 
1985 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1986  MLIRContext *context) {
1987  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1988 }
1989 
1990 //===----------------------------------------------------------------------===//
1991 // CmpFOp
1992 //===----------------------------------------------------------------------===//
1993 
1994 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1995 /// comparison predicates.
1996 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1997  const APFloat &lhs, const APFloat &rhs) {
1998  auto cmpResult = lhs.compare(rhs);
1999  switch (predicate) {
2000  case arith::CmpFPredicate::AlwaysFalse:
2001  return false;
2002  case arith::CmpFPredicate::OEQ:
2003  return cmpResult == APFloat::cmpEqual;
2004  case arith::CmpFPredicate::OGT:
2005  return cmpResult == APFloat::cmpGreaterThan;
2006  case arith::CmpFPredicate::OGE:
2007  return cmpResult == APFloat::cmpGreaterThan ||
2008  cmpResult == APFloat::cmpEqual;
2009  case arith::CmpFPredicate::OLT:
2010  return cmpResult == APFloat::cmpLessThan;
2011  case arith::CmpFPredicate::OLE:
2012  return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2013  case arith::CmpFPredicate::ONE:
2014  return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2015  case arith::CmpFPredicate::ORD:
2016  return cmpResult != APFloat::cmpUnordered;
2017  case arith::CmpFPredicate::UEQ:
2018  return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2019  case arith::CmpFPredicate::UGT:
2020  return cmpResult == APFloat::cmpUnordered ||
2021  cmpResult == APFloat::cmpGreaterThan;
2022  case arith::CmpFPredicate::UGE:
2023  return cmpResult == APFloat::cmpUnordered ||
2024  cmpResult == APFloat::cmpGreaterThan ||
2025  cmpResult == APFloat::cmpEqual;
2026  case arith::CmpFPredicate::ULT:
2027  return cmpResult == APFloat::cmpUnordered ||
2028  cmpResult == APFloat::cmpLessThan;
2029  case arith::CmpFPredicate::ULE:
2030  return cmpResult == APFloat::cmpUnordered ||
2031  cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2032  case arith::CmpFPredicate::UNE:
2033  return cmpResult != APFloat::cmpEqual;
2034  case arith::CmpFPredicate::UNO:
2035  return cmpResult == APFloat::cmpUnordered;
2036  case arith::CmpFPredicate::AlwaysTrue:
2037  return true;
2038  }
2039  llvm_unreachable("unknown cmpf predicate kind");
2040 }
2041 
2042 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2043  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2044  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2045 
2046  // If one operand is NaN, making them both NaN does not change the result.
2047  if (lhs && lhs.getValue().isNaN())
2048  rhs = lhs;
2049  if (rhs && rhs.getValue().isNaN())
2050  lhs = rhs;
2051 
2052  if (!lhs || !rhs)
2053  return {};
2054 
2055  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2056  return BoolAttr::get(getContext(), val);
2057 }
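// Illustrative example: if either operand is a constant NaN, the comparison
// folds to the predicate's value on (NaN, NaN), so
//   arith.cmpf oeq, %x, %nan : f32    folds to false, and
//   arith.cmpf une, %x, %nan : f32    folds to true,
// regardless of %x.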
2058 
2059 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2060 public:
2061  using OpRewritePattern<CmpFOp>::OpRewritePattern;
2062 
2063  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2064  bool isUnsigned) {
2065  using namespace arith;
2066  switch (pred) {
2067  case CmpFPredicate::UEQ:
2068  case CmpFPredicate::OEQ:
2069  return CmpIPredicate::eq;
2070  case CmpFPredicate::UGT:
2071  case CmpFPredicate::OGT:
2072  return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2073  case CmpFPredicate::UGE:
2074  case CmpFPredicate::OGE:
2075  return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2076  case CmpFPredicate::ULT:
2077  case CmpFPredicate::OLT:
2078  return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2079  case CmpFPredicate::ULE:
2080  case CmpFPredicate::OLE:
2081  return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2082  case CmpFPredicate::UNE:
2083  case CmpFPredicate::ONE:
2084  return CmpIPredicate::ne;
2085  default:
2086  llvm_unreachable("Unexpected predicate!");
2087  }
2088  }
2089 
2090  LogicalResult matchAndRewrite(CmpFOp op,
2091  PatternRewriter &rewriter) const override {
2092  FloatAttr flt;
2093  if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2094  return failure();
2095 
2096  const APFloat &rhs = flt.getValue();
2097 
2098  // Don't attempt to fold a nan.
2099  if (rhs.isNaN())
2100  return failure();
2101 
2102  // Get the width of the mantissa. We don't want to touch conversions that
2103  // might lose information from the integer, e.g. "i64 -> float".
2104  FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2105  int mantissaWidth = floatTy.getFPMantissaWidth();
2106  if (mantissaWidth <= 0)
2107  return failure();
2108 
2109  bool isUnsigned;
2110  Value intVal;
2111 
2112  if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2113  isUnsigned = false;
2114  intVal = si.getIn();
2115  } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2116  isUnsigned = true;
2117  intVal = ui.getIn();
2118  } else {
2119  return failure();
2120  }
2121 
2122  // Check to see that the input is converted from an integer type that is
2123  // small enough that all bits are preserved.
2124  auto intTy = llvm::cast<IntegerType>(intVal.getType());
2125  auto intWidth = intTy.getWidth();
2126 
2127  // Number of bits representing values, as opposed to the sign
2128  auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2129 
2130  // The following test does NOT adjust intWidth downwards for signed inputs,
2131  // because the most negative value still requires all the mantissa bits
2132  // to distinguish it from one less than that value.
2133  if ((int)intWidth > mantissaWidth) {
2134  // Conversion would lose accuracy. Check if loss can impact comparison.
2135  int exponent = ilogb(rhs);
2136  if (exponent == APFloat::IEK_Inf) {
2137  int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2138  if (maxExponent < (int)valueBits) {
2139  // Conversion could create infinity.
2140  return failure();
2141  }
2142  } else {
2143  // Note that if rhs is zero or NaN, then the exponent is negative
2144  // and the first condition is trivially false.
2145  if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2146  // Conversion could affect comparison.
2147  return failure();
2148  }
2149  }
2150  }
2151 
2152  // Convert to equivalent cmpi predicate
2153  CmpIPredicate pred;
2154  switch (op.getPredicate()) {
2155  case CmpFPredicate::ORD:
2156  // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2157  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2158  /*width=*/1);
2159  return success();
2160  case CmpFPredicate::UNO:
2161  // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2162  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2163  /*width=*/1);
2164  return success();
2165  default:
2166  pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2167  break;
2168  }
2169 
2170  if (!isUnsigned) {
2171  // If the rhs value is > SignedMax, fold the comparison. This handles
2172  // +INF and large values.
2173  APFloat signedMax(rhs.getSemantics());
2174  signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2175  APFloat::rmNearestTiesToEven);
2176  if (signedMax < rhs) { // smax < 13123.0
2177  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2178  pred == CmpIPredicate::sle)
2179  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2180  /*width=*/1);
2181  else
2182  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2183  /*width=*/1);
2184  return success();
2185  }
2186  } else {
2187  // If the rhs value is > UnsignedMax, fold the comparison. This handles
2188  // +INF and large values.
2189  APFloat unsignedMax(rhs.getSemantics());
2190  unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2191  APFloat::rmNearestTiesToEven);
2192  if (unsignedMax < rhs) { // umax < 13123.0
2193  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2194  pred == CmpIPredicate::ule)
2195  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2196  /*width=*/1);
2197  else
2198  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2199  /*width=*/1);
2200  return success();
2201  }
2202  }
2203 
2204  if (!isUnsigned) {
2205  // See if the rhs value is < SignedMin.
2206  APFloat signedMin(rhs.getSemantics());
2207  signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2208  APFloat::rmNearestTiesToEven);
2209  if (signedMin > rhs) { // smin > 12312.0
2210  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2211  pred == CmpIPredicate::sge)
2212  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2213  /*width=*/1);
2214  else
2215  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2216  /*width=*/1);
2217  return success();
2218  }
2219  } else {
2220  // See if the rhs value is < UnsignedMin.
2221  APFloat unsignedMin(rhs.getSemantics());
2222  unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2223  APFloat::rmNearestTiesToEven);
2224  if (unsignedMin > rhs) { // umin > 12312.0
2225  if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2226  pred == CmpIPredicate::uge)
2227  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2228  /*width=*/1);
2229  else
2230  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2231  /*width=*/1);
2232  return success();
2233  }
2234  }
2235 
2236  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2237  // [0, UMAX], but it may still be fractional. See if it is fractional by
2238  // casting the FP value to the integer value and back, checking for
2239  // equality. Don't do this for zero, because -0.0 is not fractional.
2240  bool ignored;
2241  APSInt rhsInt(intWidth, isUnsigned);
2242  if (APFloat::opInvalidOp ==
2243  rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2244  // Undefined behavior invoked - the destination type can't represent
2245  // the input constant.
2246  return failure();
2247  }
2248 
2249  if (!rhs.isZero()) {
2250  APFloat apf(floatTy.getFloatSemantics(),
2251  APInt::getZero(floatTy.getWidth()));
2252  apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2253 
2254  bool equal = apf == rhs;
2255  if (!equal) {
2256  // If we had a comparison against a fractional value, we have to adjust
2257  // the compare predicate and sometimes the value. rhsInt is rounded
2258  // towards zero at this point.
2259  switch (pred) {
2260  case CmpIPredicate::ne: // (float)int != 4.4 --> true
2261  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2262  /*width=*/1);
2263  return success();
2264  case CmpIPredicate::eq: // (float)int == 4.4 --> false
2265  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2266  /*width=*/1);
2267  return success();
2268  case CmpIPredicate::ule:
2269  // (float)int <= 4.4 --> int <= 4
2270  // (float)int <= -4.4 --> false
2271  if (rhs.isNegative()) {
2272  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2273  /*width=*/1);
2274  return success();
2275  }
2276  break;
2277  case CmpIPredicate::sle:
2278  // (float)int <= 4.4 --> int <= 4
2279  // (float)int <= -4.4 --> int < -4
2280  if (rhs.isNegative())
2281  pred = CmpIPredicate::slt;
2282  break;
2283  case CmpIPredicate::ult:
2284  // (float)int < -4.4 --> false
2285  // (float)int < 4.4 --> int <= 4
2286  if (rhs.isNegative()) {
2287  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2288  /*width=*/1);
2289  return success();
2290  }
2291  pred = CmpIPredicate::ule;
2292  break;
2293  case CmpIPredicate::slt:
2294  // (float)int < -4.4 --> int < -4
2295  // (float)int < 4.4 --> int <= 4
2296  if (!rhs.isNegative())
2297  pred = CmpIPredicate::sle;
2298  break;
2299  case CmpIPredicate::ugt:
2300  // (float)int > 4.4 --> int > 4
2301  // (float)int > -4.4 --> true
2302  if (rhs.isNegative()) {
2303  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2304  /*width=*/1);
2305  return success();
2306  }
2307  break;
2308  case CmpIPredicate::sgt:
2309  // (float)int > 4.4 --> int > 4
2310  // (float)int > -4.4 --> int >= -4
2311  if (rhs.isNegative())
2312  pred = CmpIPredicate::sge;
2313  break;
2314  case CmpIPredicate::uge:
2315  // (float)int >= -4.4 --> true
2316  // (float)int >= 4.4 --> int > 4
2317  if (rhs.isNegative()) {
2318  rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2319  /*width=*/1);
2320  return success();
2321  }
2322  pred = CmpIPredicate::ugt;
2323  break;
2324  case CmpIPredicate::sge:
2325  // (float)int >= -4.4 --> int >= -4
2326  // (float)int >= 4.4 --> int > 4
2327  if (!rhs.isNegative())
2328  pred = CmpIPredicate::sgt;
2329  break;
2330  }
2331  }
2332  }
2333 
2334  // Lower this FP comparison into an appropriate integer version of the
2335  // comparison.
2336  rewriter.replaceOpWithNewOp<CmpIOp>(
2337  op, pred, intVal,
2338  rewriter.create<ConstantOp>(
2339  op.getLoc(), intVal.getType(),
2340  rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2341  return success();
2342  }
2343 };
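// Illustrative example of the pattern above (constants shown symbolically):
//   %f = arith.sitofp %i : i32 to f32
//   %r = arith.cmpf olt, %f, %cst_4.4 : f32
// is rewritten to `arith.cmpi sle, %i, %c4_i32`: 4.4 is fractional, rounds
// toward zero to 4, and the `slt` predicate is adjusted to `sle`.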
2344 
2345 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2346  MLIRContext *context) {
2347  patterns.insert<CmpFIntToFPConst>(context);
2348 }
2349 
2350 //===----------------------------------------------------------------------===//
2351 // SelectOp
2352 //===----------------------------------------------------------------------===//
2353 
2354 // select %arg, %c1, %c0 => extui %arg
2355 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2356  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2357 
2358  LogicalResult matchAndRewrite(arith::SelectOp op,
2359  PatternRewriter &rewriter) const override {
2360  // Cannot extui i1 to i1, or i1 to f32
2361  if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2362  return failure();
2363 
2364  // select %x, %c1, %c0 => extui %x
2365  if (matchPattern(op.getTrueValue(), m_One()) &&
2366  matchPattern(op.getFalseValue(), m_Zero())) {
2367  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2368  op.getCondition());
2369  return success();
2370  }
2371 
2372  // select %x, %c0, %c1 => extui (xor %x, true)
2373  if (matchPattern(op.getTrueValue(), m_Zero()) &&
2374  matchPattern(op.getFalseValue(), m_One())) {
2375  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2376  op, op.getType(),
2377  rewriter.create<arith::XOrIOp>(
2378  op.getLoc(), op.getCondition(),
2379  rewriter.create<arith::ConstantIntOp>(
2380  op.getLoc(), op.getCondition().getType(), 1)));
2381  return success();
2382  }
2383 
2384  return failure();
2385  }
2386 };
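// Illustrative example: with %c : i1,
//   arith.select %c, %c1_i32, %c0_i32   becomes   arith.extui %c : i1 to i32,
// and with the constants swapped the condition is first inverted with
// `arith.xori %c, %true` before the extui.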
2387 
2388 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2389  MLIRContext *context) {
2390  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2391  SelectI1ToNot, SelectToExtUI>(context);
2392 }
2393 
2394 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2395  Value trueVal = getTrueValue();
2396  Value falseVal = getFalseValue();
2397  if (trueVal == falseVal)
2398  return trueVal;
2399 
2400  Value condition = getCondition();
2401 
2402  // select true, %0, %1 => %0
2403  if (matchPattern(adaptor.getCondition(), m_One()))
2404  return trueVal;
2405 
2406  // select false, %0, %1 => %1
2407  if (matchPattern(adaptor.getCondition(), m_Zero()))
2408  return falseVal;
2409 
2410  // If either operand is fully poisoned, return the other.
2411  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2412  return falseVal;
2413 
2414  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2415  return trueVal;
2416 
2417  // select %x, true, false => %x
2418  if (getType().isSignlessInteger(1) &&
2419  matchPattern(adaptor.getTrueValue(), m_One()) &&
2420  matchPattern(adaptor.getFalseValue(), m_Zero()))
2421  return condition;
2422 
2423  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2424  auto pred = cmp.getPredicate();
2425  if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2426  auto cmpLhs = cmp.getLhs();
2427  auto cmpRhs = cmp.getRhs();
2428 
2429  // %0 = arith.cmpi eq, %arg0, %arg1
2430  // %1 = arith.select %0, %arg0, %arg1 => %arg1
2431 
2432  // %0 = arith.cmpi ne, %arg0, %arg1
2433  // %1 = arith.select %0, %arg0, %arg1 => %arg0
2434 
2435  if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2436  (cmpRhs == trueVal && cmpLhs == falseVal))
2437  return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2438  }
2439  }
2440 
2441  // Constant-fold constant operands over non-splat constant condition.
2442  // select %cst_vec, %cst0, %cst1 => %cst2
2443  if (auto cond =
2444  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2445  if (auto lhs =
2446  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2447  if (auto rhs =
2448  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2449  SmallVector<Attribute> results;
2450  results.reserve(static_cast<size_t>(cond.getNumElements()));
2451  auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2452  cond.value_end<BoolAttr>());
2453  auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2454  lhs.value_end<Attribute>());
2455  auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2456  rhs.value_end<Attribute>());
2457 
2458  for (auto [condVal, lhsVal, rhsVal] :
2459  llvm::zip_equal(condVals, lhsVals, rhsVals))
2460  results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2461 
2462  return DenseElementsAttr::get(lhs.getType(), results);
2463  }
2464  }
2465  }
2466 
2467  return nullptr;
2468 }
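// Illustrative examples of the folds above:
//   arith.select %true, %a, %b      -> %a
//   arith.select %c, %x, %x         -> %x
//   arith.select %c, %true, %false  -> %c        (i1 result only)
// and a select whose condition is `arith.cmpi eq, %a, %b` with the same two
// values as branches folds to %b (or to %a for the `ne` predicate).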
2469 
2470 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2471  Type conditionType, resultType;
2472  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2473  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2474  parser.parseOptionalAttrDict(result.attributes) ||
2475  parser.parseColonType(resultType))
2476  return failure();
2477 
2478  // Check for the explicit condition type if this is a masked tensor or vector.
2479  if (succeeded(parser.parseOptionalComma())) {
2480  conditionType = resultType;
2481  if (parser.parseType(resultType))
2482  return failure();
2483  } else {
2484  conditionType = parser.getBuilder().getI1Type();
2485  }
2486 
2487  result.addTypes(resultType);
2488  return parser.resolveOperands(operands,
2489  {conditionType, resultType, resultType},
2490  parser.getNameLoc(), result.operands);
2491 }
2492 
2493 void arith::SelectOp::print(OpAsmPrinter &p) {
2494  p << " " << getOperands();
2495  p.printOptionalAttrDict((*this)->getAttrs());
2496  p << " : ";
2497  if (ShapedType condType =
2498  llvm::dyn_cast<ShapedType>(getCondition().getType()))
2499  p << condType << ", ";
2500  p << getType();
2501 }
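// The custom parser/printer above accept both the scalar form and the form
// with an explicit mask type, e.g. (types are illustrative):
//   %r = arith.select %cond, %a, %b : i64
//   %v = arith.select %mask, %va, %vb : vector<4xi1>, vector<4xf32>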
2502 
2503 LogicalResult arith::SelectOp::verify() {
2504  Type conditionType = getCondition().getType();
2505  if (conditionType.isSignlessInteger(1))
2506  return success();
2507 
2508  // If the result type is a vector or tensor, the condition can be an i1 mask
2509  // with the same shape.
2510  Type resultType = getType();
2511  if (!llvm::isa<TensorType, VectorType>(resultType))
2512  return emitOpError() << "expected condition to be a signless i1, but got "
2513  << conditionType;
2514  Type shapedConditionType = getI1SameShape(resultType);
2515  if (conditionType != shapedConditionType) {
2516  return emitOpError() << "expected condition type to have the same shape "
2517  "as the result type, expected "
2518  << shapedConditionType << ", but got "
2519  << conditionType;
2520  }
2521  return success();
2522 }
2523 //===----------------------------------------------------------------------===//
2524 // ShLIOp
2525 //===----------------------------------------------------------------------===//
2526 
2527 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2528  // shli(x, 0) -> x
2529  if (matchPattern(adaptor.getRhs(), m_Zero()))
2530  return getLhs();
2531  // Don't fold if the shift amount is greater than or equal to the bit width.
2532  bool bounded = false;
2533  auto result = constFoldBinaryOp<IntegerAttr>(
2534  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2535  bounded = b.ult(b.getBitWidth());
2536  return a.shl(b);
2537  });
2538  return bounded ? result : Attribute();
2539 }
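// Illustrative example: `arith.shli %c1_i8, %c3_i8` folds to 8 : i8, whereas a
// shift amount of 8 or more on an i8 value leaves the op untouched because
// `bounded` stays false. The shrui/shrsi folders below apply the same guard.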
2540 
2541 //===----------------------------------------------------------------------===//
2542 // ShRUIOp
2543 //===----------------------------------------------------------------------===//
2544 
2545 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2546  // shrui(x, 0) -> x
2547  if (matchPattern(adaptor.getRhs(), m_Zero()))
2548  return getLhs();
2549  // Don't fold if the shift amount is greater than or equal to the bit width.
2550  bool bounded = false;
2551  auto result = constFoldBinaryOp<IntegerAttr>(
2552  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2553  bounded = b.ult(b.getBitWidth());
2554  return a.lshr(b);
2555  });
2556  return bounded ? result : Attribute();
2557 }
2558 
2559 //===----------------------------------------------------------------------===//
2560 // ShRSIOp
2561 //===----------------------------------------------------------------------===//
2562 
2563 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2564  // shrsi(x, 0) -> x
2565  if (matchPattern(adaptor.getRhs(), m_Zero()))
2566  return getLhs();
2567  // Don't fold if the shift amount is greater than or equal to the bit width.
2568  bool bounded = false;
2569  auto result = constFoldBinaryOp<IntegerAttr>(
2570  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2571  bounded = b.ult(b.getBitWidth());
2572  return a.ashr(b);
2573  });
2574  return bounded ? result : Attribute();
2575 }
2576 
2577 //===----------------------------------------------------------------------===//
2578 // Atomic Enum
2579 //===----------------------------------------------------------------------===//
2580 
2581 /// Returns the identity value attribute associated with an AtomicRMWKind op.
2582 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2583  OpBuilder &builder, Location loc,
2584  bool useOnlyFiniteValue) {
2585  switch (kind) {
2586  case AtomicRMWKind::maximumf: {
2587  const llvm::fltSemantics &semantic =
2588  llvm::cast<FloatType>(resultType).getFloatSemantics();
2589  APFloat identity = useOnlyFiniteValue
2590  ? APFloat::getLargest(semantic, /*Negative=*/true)
2591  : APFloat::getInf(semantic, /*Negative=*/true);
2592  return builder.getFloatAttr(resultType, identity);
2593  }
2594  case AtomicRMWKind::maxnumf: {
2595  const llvm::fltSemantics &semantic =
2596  llvm::cast<FloatType>(resultType).getFloatSemantics();
2597  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2598  return builder.getFloatAttr(resultType, identity);
2599  }
2600  case AtomicRMWKind::addf:
2601  case AtomicRMWKind::addi:
2602  case AtomicRMWKind::maxu:
2603  case AtomicRMWKind::ori:
2604  return builder.getZeroAttr(resultType);
2605  case AtomicRMWKind::andi:
2606  return builder.getIntegerAttr(
2607  resultType,
2608  APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2609  case AtomicRMWKind::maxs:
2610  return builder.getIntegerAttr(
2611  resultType, APInt::getSignedMinValue(
2612  llvm::cast<IntegerType>(resultType).getWidth()));
2613  case AtomicRMWKind::minimumf: {
2614  const llvm::fltSemantics &semantic =
2615  llvm::cast<FloatType>(resultType).getFloatSemantics();
2616  APFloat identity = useOnlyFiniteValue
2617  ? APFloat::getLargest(semantic, /*Negative=*/false)
2618  : APFloat::getInf(semantic, /*Negative=*/false);
2619 
2620  return builder.getFloatAttr(resultType, identity);
2621  }
2622  case AtomicRMWKind::minnumf: {
2623  const llvm::fltSemantics &semantic =
2624  llvm::cast<FloatType>(resultType).getFloatSemantics();
2625  APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2626  return builder.getFloatAttr(resultType, identity);
2627  }
2628  case AtomicRMWKind::mins:
2629  return builder.getIntegerAttr(
2630  resultType, APInt::getSignedMaxValue(
2631  llvm::cast<IntegerType>(resultType).getWidth()));
2632  case AtomicRMWKind::minu:
2633  return builder.getIntegerAttr(
2634  resultType,
2635  APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2636  case AtomicRMWKind::muli:
2637  return builder.getIntegerAttr(resultType, 1);
2638  case AtomicRMWKind::mulf:
2639  return builder.getFloatAttr(resultType, 1);
2640  // TODO: Add remaining reduction operations.
2641  default:
2642  (void)emitOptionalError(loc, "Reduction operation type not supported");
2643  break;
2644  }
2645  return nullptr;
2646 }
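// Usage sketch (an OpBuilder `b` and a Location `loc` are assumed to be in
// scope):
//   TypedAttr id = arith::getIdentityValueAttr(AtomicRMWKind::maximumf,
//                                              b.getF32Type(), b, loc,
//                                              /*useOnlyFiniteValue=*/true);
// yields the largest negative finite f32; with the flag false it yields -inf.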
2647 
2648 /// Return the identity numeric value associated with the given op.
2649 std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2650  std::optional<AtomicRMWKind> maybeKind =
2651      llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2652  // Floating-point operations.
2653  .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2654  .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2655  .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2656  .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2657  .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2658  .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2659  // Integer operations.
2660  .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2661  .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2662  .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2663  .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2664  .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2665  .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2666  .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2667  .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2668  .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2669  .Default([](Operation *op) { return std::nullopt; });
2670  if (!maybeKind) {
2671  return std::nullopt;
2672  }
2673 
2674  bool useOnlyFiniteValue = false;
2675  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2676  if (fmfOpInterface) {
2677  arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2678  useOnlyFiniteValue =
2679  bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2680  }
2681 
2682  // Builder only used as helper for attribute creation.
2683  OpBuilder b(op->getContext());
2684  Type resultType = op->getResult(0).getType();
2685 
2686  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2687  useOnlyFiniteValue);
2688 }
2689 
2690 /// Returns the identity value associated with an AtomicRMWKind op.
2691 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2692  OpBuilder &builder, Location loc,
2693  bool useOnlyFiniteValue) {
2694  auto attr =
2695  getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2696  return builder.create<arith::ConstantOp>(loc, attr);
2697 }
2698 
2699 /// Return the value obtained by applying the reduction operation kind
2700 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2701 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2702  Location loc, Value lhs, Value rhs) {
2703  switch (op) {
2704  case AtomicRMWKind::addf:
2705  return builder.create<arith::AddFOp>(loc, lhs, rhs);
2706  case AtomicRMWKind::addi:
2707  return builder.create<arith::AddIOp>(loc, lhs, rhs);
2708  case AtomicRMWKind::mulf:
2709  return builder.create<arith::MulFOp>(loc, lhs, rhs);
2710  case AtomicRMWKind::muli:
2711  return builder.create<arith::MulIOp>(loc, lhs, rhs);
2712  case AtomicRMWKind::maximumf:
2713  return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2714  case AtomicRMWKind::minimumf:
2715  return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2716  case AtomicRMWKind::maxnumf:
2717  return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2718  case AtomicRMWKind::minnumf:
2719  return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2720  case AtomicRMWKind::maxs:
2721  return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2722  case AtomicRMWKind::mins:
2723  return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2724  case AtomicRMWKind::maxu:
2725  return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2726  case AtomicRMWKind::minu:
2727  return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2728  case AtomicRMWKind::ori:
2729  return builder.create<arith::OrIOp>(loc, lhs, rhs);
2730  case AtomicRMWKind::andi:
2731  return builder.create<arith::AndIOp>(loc, lhs, rhs);
2732  // TODO: Add remaining reduction operations.
2733  default:
2734  (void)emitOptionalError(loc, "Reduction operation type not supported");
2735  break;
2736  }
2737  return nullptr;
2738 }
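// Usage sketch (assumes an OpBuilder `b`, a Location `loc`, and f32 values
// `acc` and `x` are in scope):
//   Value id  = arith::getIdentityValue(AtomicRMWKind::addf, b.getF32Type(), b, loc);
//   Value sum = arith::getReductionOp(AtomicRMWKind::addf, b, loc, acc, x);
// which materializes an `arith.constant 0.0 : f32` and an `arith.addf`.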
2739 
2740 //===----------------------------------------------------------------------===//
2741 // TableGen'd op method definitions
2742 //===----------------------------------------------------------------------===//
2743 
2744 #define GET_OP_CLASSES
2745 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2746 
2747 //===----------------------------------------------------------------------===//
2748 // TableGen'd enum attribute definitions
2749 //===----------------------------------------------------------------------===//
2750 
2751 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition: ArithOps.cpp:632
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
Definition: ArithOps.cpp:1340
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:62
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition: ArithOps.cpp:69
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition: ArithOps.cpp:109
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
Definition: ArithOps.cpp:1276
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
Definition: ArithOps.cpp:1881
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition: ArithOps.cpp:593
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
Definition: ArithOps.cpp:1290
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
Definition: ArithOps.cpp:1262
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition: ArithOps.cpp:692
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition: ArithOps.cpp:674
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition: ArithOps.cpp:142
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:52
static Attribute getBoolAttribute(Type type, bool value)
Definition: ArithOps.cpp:150
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
Definition: ArithOps.cpp:1355
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1733
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1631
static LogicalResult verifyExtOp(Op op)
Definition: ArithOps.cpp:1312
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition: ArithOps.cpp:57
static int64_t getScalarOrElementWidth(Type type)
Definition: ArithOps.cpp:130
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition: ArithOps.cpp:873
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
Definition: ArithOps.cpp:1283
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition: ArithOps.cpp:171
static std::optional< int64_t > getIntegerWidth(Type t)
Definition: ArithOps.cpp:1899
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
Definition: ArithOps.cpp:1298
std::tuple< Types... > * type_list
Definition: ArithOps.cpp:1256
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition: ArithOps.cpp:43
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition: ArithOps.cpp:354
static LogicalResult verifyTruncateOp(Op op)
Definition: ArithOps.cpp:1326
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1216::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2090
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
Definition: ArithOps.cpp:2063
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:830
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static bool classof(Operation *op)
Definition: ArithOps.cpp:283
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition: ArithOps.cpp:277
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition: ArithOps.cpp:289
static bool classof(Operation *op)
Definition: ArithOps.cpp:295
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:54
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition: ArithOps.cpp:252
static bool classof(Operation *op)
Definition: ArithOps.cpp:271
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
Definition: ArithOps.cpp:2649
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
Definition: ArithOps.cpp:1853
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2582
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2701
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2691
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type 'type'.
Definition: ArithOps.cpp:301
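A one-line sketch, assuming 'b', 'loc', and an existing Value 'operand':
  Value zero = arith::getZeroConstant(b, loc, operand.getType());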
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
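A sketch of the binding form, in the style of the folders in this file; 'op' is assumed to be a binary arith op with lhs/rhs accessors:
  APInt rhsVal;
  if (matchPattern(op.getRhs(), m_ConstantInt(&rhsVal)) && rhsVal.isZero())
    return op.getLhs(); // e.g. adding or shifting by zero is the identity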
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition: Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair of entries has a compatible shape.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does not contain minus one.
Definition: Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition: Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
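For example, the classic x + 0 -> x fold is a one-liner with this matcher ('addOp' is a hypothetical arith.addi):
  if (matchPattern(addOp.getRhs(), m_Zero()))
    return addOp.getLhs();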
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition: Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition: Matchers.h:409
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition: Matchers.h:462
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition: Matchers.h:427
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition: Matchers.h:414
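The float matchers are used the same way; a sketch of an x * 1.0 -> x style fold ('mulOp' is a hypothetical arith.mulf):
  if (matchPattern(mulOp.getRhs(), m_OneFloat()))
    return mulOp.getLhs();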
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does not contain zero.
Definition: Matchers.h:455
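These range matchers back the speculatability logic for divisions: a division may only be speculated if its divisor is provably nonzero. A sketch, with 'divOp' a hypothetical arith.divui:
  if (matchPattern(divOp.getRhs(), m_IntRangeWithoutZeroU()))
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;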
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
Definition: ArithOps.cpp:2358
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
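A minimal, self-contained sketch of this pattern shape; the rewrite shown (collapsing a select whose branches are identical) is illustrative and not a claim about the patterns defined in this file:
  struct CollapseRedundantSelect : OpRewritePattern<arith::SelectOp> {
    using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(arith::SelectOp op,
                                  PatternRewriter &rewriter) const override {
      if (op.getTrueValue() != op.getFalseValue())
        return failure();
      rewriter.replaceOp(op, op.getTrueValue());
      return success();
    }
  };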
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
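The typed builders above fill an OperationState like this under the hood; a generic-construction sketch assuming 'b', 'loc', and two i32 values 'lhs' and 'rhs':
  OperationState state(loc, arith::AddIOp::getOperationName());
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
  state.addTypes(lhs.getType()); // a single Type converts to ArrayRef<Type>
  Operation *add = b.create(state);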