MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37/// Default rounding mode according to default LLVM floating-point environment.
38static constexpr llvm::RoundingMode kDefaultRoundingMode =
39 llvm::RoundingMode::NearestTiesToEven;
40
41//===----------------------------------------------------------------------===//
42// Pattern helpers
43//===----------------------------------------------------------------------===//
44
45static IntegerAttr
48 function_ref<APInt(const APInt &, const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.getType(), value);
53}
54
55static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
57 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
58}
59
60static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
62 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
63}
64
65static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
67 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
68}
69
70// Merge overflow flags from 2 ops, selecting the most conservative combination.
71static IntegerOverflowFlagsAttr
72mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
73 IntegerOverflowFlagsAttr val2) {
74 return IntegerOverflowFlagsAttr::get(val1.getContext(),
75 val1.getValue() & val2.getValue());
76}
77
78/// Invert an integer comparison predicate.
79arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
80 switch (pred) {
81 case arith::CmpIPredicate::eq:
82 return arith::CmpIPredicate::ne;
83 case arith::CmpIPredicate::ne:
84 return arith::CmpIPredicate::eq;
85 case arith::CmpIPredicate::slt:
86 return arith::CmpIPredicate::sge;
87 case arith::CmpIPredicate::sle:
88 return arith::CmpIPredicate::sgt;
89 case arith::CmpIPredicate::sgt:
90 return arith::CmpIPredicate::sle;
91 case arith::CmpIPredicate::sge:
92 return arith::CmpIPredicate::slt;
93 case arith::CmpIPredicate::ult:
94 return arith::CmpIPredicate::uge;
95 case arith::CmpIPredicate::ule:
96 return arith::CmpIPredicate::ugt;
97 case arith::CmpIPredicate::ugt:
98 return arith::CmpIPredicate::ule;
99 case arith::CmpIPredicate::uge:
100 return arith::CmpIPredicate::ult;
101 }
102 llvm_unreachable("unknown cmpi predicate kind");
103}
104
105/// Equivalent to
106/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
107///
108/// Not possible to implement as chain of calls as this would introduce a
109/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
110/// on the LLVM dialect and on translation to LLVM.
111static llvm::RoundingMode
112convertArithRoundingModeToLLVMIR(std::optional<RoundingMode> roundingMode) {
113 if (!roundingMode)
115 switch (*roundingMode) {
116 case RoundingMode::downward:
117 return llvm::RoundingMode::TowardNegative;
118 case RoundingMode::to_nearest_away:
119 return llvm::RoundingMode::NearestTiesToAway;
120 case RoundingMode::to_nearest_even:
121 return llvm::RoundingMode::NearestTiesToEven;
122 case RoundingMode::toward_zero:
123 return llvm::RoundingMode::TowardZero;
124 case RoundingMode::upward:
125 return llvm::RoundingMode::TowardPositive;
126 }
127 llvm_unreachable("Unhandled rounding mode");
128}
129
130static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
131 return arith::CmpIPredicateAttr::get(pred.getContext(),
132 invertPredicate(pred.getValue()));
133}
134
136 Type elemTy = getElementTypeOrSelf(type);
137 if (elemTy.isIntOrFloat())
138 return elemTy.getIntOrFloatBitWidth();
139
140 return -1;
141}
142
144 return getScalarOrElementWidth(value.getType());
145}
146
147static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
148 APInt value;
149 if (matchPattern(attr, m_ConstantInt(&value)))
150 return value;
151
152 return failure();
153}
154
155static Attribute getBoolAttribute(Type type, bool value) {
156 auto boolAttr = BoolAttr::get(type.getContext(), value);
157 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
158 if (!shapedType)
159 return boolAttr;
160 // DenseElementsAttr requires a static shape.
161 if (!shapedType.hasStaticShape())
162 return {};
163 return DenseElementsAttr::get(shapedType, boolAttr);
164}
165
166//===----------------------------------------------------------------------===//
167// TableGen'd canonicalization patterns
168//===----------------------------------------------------------------------===//
169
170namespace {
171#include "ArithCanonicalization.inc"
172} // namespace
173
174//===----------------------------------------------------------------------===//
175// Common helpers
176//===----------------------------------------------------------------------===//
177
178/// Return the type of the same shape (scalar, vector or tensor) containing i1.
180 auto i1Type = IntegerType::get(type.getContext(), 1);
181 if (auto shapedType = dyn_cast<ShapedType>(type))
182 return shapedType.cloneWith(std::nullopt, i1Type);
183 if (llvm::isa<UnrankedTensorType>(type))
184 return UnrankedTensorType::get(i1Type);
185 return i1Type;
186}
187
188//===----------------------------------------------------------------------===//
189// ConstantOp
190//===----------------------------------------------------------------------===//
191
192void arith::ConstantOp::getAsmResultNames(
193 function_ref<void(Value, StringRef)> setNameFn) {
194 auto type = getType();
195 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
196 auto intType = dyn_cast<IntegerType>(type);
197
198 // Sugar i1 constants with 'true' and 'false'.
199 if (intType && intType.getWidth() == 1)
200 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
201
202 // Otherwise, build a complex name with the value and type.
203 SmallString<32> specialNameBuffer;
204 llvm::raw_svector_ostream specialName(specialNameBuffer);
205 specialName << 'c' << intCst.getValue();
206 if (intType)
207 specialName << '_' << type;
208 setNameFn(getResult(), specialName.str());
209 } else {
210 setNameFn(getResult(), "cst");
211 }
212}
213
214/// TODO: disallow arith.constant to return anything other than signless integer
215/// or float like.
216LogicalResult arith::ConstantOp::verify() {
217 auto type = getType();
218 // Integer values must be signless.
219 if (llvm::isa<IntegerType>(type) &&
220 !llvm::cast<IntegerType>(type).isSignless())
221 return emitOpError("integer return type must be signless");
222 // Any float or elements attribute are acceptable.
223 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
224 return emitOpError(
225 "value must be an integer, float, or elements attribute");
226 }
227
228 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
229 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
230 // However, this would most likely require updating the lowerings to LLVM.
231 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
232 return emitOpError(
233 "initializing scalable vectors with elements attribute is not supported"
234 " unless it's a vector splat");
235 return success();
236}
237
238bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
239 // The value's type must be the same as the provided type.
240 auto typedAttr = dyn_cast<TypedAttr>(value);
241 if (!typedAttr || typedAttr.getType() != type)
242 return false;
243 // Integer values must be signless.
244 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
245 if (!intType.isSignless())
246 return false;
247 }
248 // Integer, float, and element attributes are buildable.
249 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
250}
251
252ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
253 Type type, Location loc) {
254 if (isBuildableWith(value, type))
255 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
256 return nullptr;
257}
258
259OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
260
262 int64_t value, unsigned width) {
263 auto type = builder.getIntegerType(width);
264 arith::ConstantOp::build(builder, result, type,
265 builder.getIntegerAttr(type, value));
266}
267
269 Location location,
271 unsigned width) {
272 mlir::OperationState state(location, getOperationName());
273 build(builder, state, value, width);
274 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
275 assert(result && "builder didn't return the right type");
276 return result;
277}
278
281 unsigned width) {
282 return create(builder, builder.getLoc(), value, width);
283}
284
286 Type type, int64_t value) {
287 arith::ConstantOp::build(builder, result, type,
288 builder.getIntegerAttr(type, value));
289}
290
292 Location location, Type type,
293 int64_t value) {
294 mlir::OperationState state(location, getOperationName());
295 build(builder, state, type, value);
296 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
297 assert(result && "builder didn't return the right type");
298 return result;
299}
300
302 Type type, int64_t value) {
303 return create(builder, builder.getLoc(), type, value);
304}
305
307 Type type, const APInt &value) {
308 arith::ConstantOp::build(builder, result, type,
309 builder.getIntegerAttr(type, value));
310}
311
313 Location location, Type type,
314 const APInt &value) {
315 mlir::OperationState state(location, getOperationName());
316 build(builder, state, type, value);
317 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
318 assert(result && "builder didn't return the right type");
319 return result;
320}
321
323 Type type,
324 const APInt &value) {
325 return create(builder, builder.getLoc(), type, value);
326}
327
329 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
330 return constOp.getType().isSignlessInteger();
331 return false;
332}
333
335 FloatType type, const APFloat &value) {
336 arith::ConstantOp::build(builder, result, type,
337 builder.getFloatAttr(type, value));
338}
339
341 Location location,
342 FloatType type,
343 const APFloat &value) {
344 mlir::OperationState state(location, getOperationName());
345 build(builder, state, type, value);
346 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
347 assert(result && "builder didn't return the right type");
348 return result;
349}
350
353 const APFloat &value) {
354 return create(builder, builder.getLoc(), type, value);
355}
356
358 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
359 return llvm::isa<FloatType>(constOp.getType());
360 return false;
361}
362
364 int64_t value) {
365 arith::ConstantOp::build(builder, result, builder.getIndexType(),
366 builder.getIndexAttr(value));
367}
368
370 Location location,
371 int64_t value) {
372 mlir::OperationState state(location, getOperationName());
373 build(builder, state, value);
374 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
375 assert(result && "builder didn't return the right type");
376 return result;
377}
378
381 return create(builder, builder.getLoc(), value);
382}
383
385 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
386 return constOp.getType().isIndex();
387 return false;
388}
389
391 Type type) {
392 // TODO: Incorporate this check to `FloatAttr::get*`.
393 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
394 "type doesn't have a zero representation");
395 TypedAttr zeroAttr = builder.getZeroAttr(type);
396 assert(zeroAttr && "unsupported type for zero attribute");
397 return arith::ConstantOp::create(builder, loc, zeroAttr);
398}
399
400//===----------------------------------------------------------------------===//
401// AddIOp
402//===----------------------------------------------------------------------===//
403
404OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
405 // addi(x, 0) -> x
406 if (matchPattern(adaptor.getRhs(), m_Zero()))
407 return getLhs();
408
409 // addi(subi(a, b), b) -> a
410 if (auto sub = getLhs().getDefiningOp<SubIOp>())
411 if (getRhs() == sub.getRhs())
412 return sub.getLhs();
413
414 // addi(b, subi(a, b)) -> a
415 if (auto sub = getRhs().getDefiningOp<SubIOp>())
416 if (getLhs() == sub.getRhs())
417 return sub.getLhs();
418
420 adaptor.getOperands(),
421 [](APInt a, const APInt &b) { return std::move(a) + b; });
422}
423
424void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
425 MLIRContext *context) {
426 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
427 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
428}
429
430//===----------------------------------------------------------------------===//
431// AddUIExtendedOp
432//===----------------------------------------------------------------------===//
433
434std::optional<SmallVector<int64_t, 4>>
435arith::AddUIExtendedOp::getShapeForUnroll() {
436 if (auto vt = dyn_cast<VectorType>(getType(0)))
437 return llvm::to_vector<4>(vt.getShape());
438 return std::nullopt;
439}
440
441// Returns the overflow bit, assuming that `sum` is the result of unsigned
442// addition of `operand` and another number.
443static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
444 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
445}
446
447LogicalResult
448arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
449 SmallVectorImpl<OpFoldResult> &results) {
450 Type overflowTy = getOverflow().getType();
451 // addui_extended(x, 0) -> x, false
452 if (matchPattern(getRhs(), m_Zero())) {
453 Builder builder(getContext());
454 auto falseValue = builder.getZeroAttr(overflowTy);
455
456 results.push_back(getLhs());
457 results.push_back(falseValue);
458 return success();
459 }
460
461 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
462 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
463 // operands. If that succeeds, calculate the overflow bit based on the sum
464 // and the first (constant) operand, `lhs`.
465 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
466 adaptor.getOperands(),
467 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
468 // If any operand is poison, propagate poison to both results.
469 if (matchPattern(sumAttr, ub::m_Poison())) {
470 results.push_back(sumAttr);
471 results.push_back(sumAttr);
472 return success();
473 }
474 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
475 ArrayRef({sumAttr, adaptor.getLhs()}),
476 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
478 if (!overflowAttr)
479 return failure();
480
481 results.push_back(sumAttr);
482 results.push_back(overflowAttr);
483 return success();
484 }
485
486 return failure();
487}
488
489void arith::AddUIExtendedOp::getCanonicalizationPatterns(
490 RewritePatternSet &patterns, MLIRContext *context) {
491 patterns.add<AddUIExtendedToAddI>(context);
492}
493
494//===----------------------------------------------------------------------===//
495// SubUIExtendedOp
496//===----------------------------------------------------------------------===//
497
498std::optional<SmallVector<int64_t, 4>>
499arith::SubUIExtendedOp::getShapeForUnroll() {
500 if (auto vt = dyn_cast<VectorType>(getType(0)))
501 return llvm::to_vector<4>(vt.getShape());
502 return std::nullopt;
503}
504
505// Returns the borrow bit, assuming `lhs` and `rhs` are operands of an unsigned
506// subtraction whose mathematical result underflows iff `lhs < rhs`.
507static APInt calculateUnsignedBorrow(const APInt &lhs, const APInt &rhs) {
508 return lhs.ult(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
509}
510
511LogicalResult
512arith::SubUIExtendedOp::fold(FoldAdaptor adaptor,
513 SmallVectorImpl<OpFoldResult> &results) {
514 Type borrowTy = getBorrow().getType();
515 // subui_extended(x, 0) -> x, false
516 if (matchPattern(getRhs(), m_Zero())) {
517 Builder builder(getContext());
518 auto falseValue = builder.getZeroAttr(borrowTy);
519
520 results.push_back(getLhs());
521 results.push_back(falseValue);
522 return success();
523 }
524
525 // subui_extended(x, x) -> 0, false
526 if (getLhs() == getRhs()) {
527 Builder builder(getContext());
528 auto zeroDiff = builder.getZeroAttr(getDiff().getType());
529 auto falseValue = builder.getZeroAttr(borrowTy);
530 if (!zeroDiff)
531 return failure();
532
533 results.push_back(zeroDiff);
534 results.push_back(falseValue);
535 return success();
536 }
537
538 // subui_extended(constant_a, constant_b) -> constant_diff, constant_borrow
539 if (Attribute diffAttr = constFoldBinaryOp<IntegerAttr>(
540 adaptor.getOperands(),
541 [](APInt a, const APInt &b) { return std::move(a) - b; })) {
542 // If any operand is poison, propagate poison to both results.
543 if (matchPattern(diffAttr, ub::m_Poison())) {
544 results.push_back(diffAttr);
545 results.push_back(diffAttr);
546 return success();
547 }
548 Attribute borrowAttr = constFoldBinaryOp<IntegerAttr>(
549 adaptor.getOperands(),
550 getI1SameShape(llvm::cast<TypedAttr>(diffAttr).getType()),
552 if (!borrowAttr)
553 return failure();
554
555 results.push_back(diffAttr);
556 results.push_back(borrowAttr);
557 return success();
558 }
559
560 return failure();
561}
562
563void arith::SubUIExtendedOp::getCanonicalizationPatterns(
564 RewritePatternSet &patterns, MLIRContext *context) {
565 patterns.add<SubUIExtendedToSubI>(context);
566}
567
568//===----------------------------------------------------------------------===//
569// SubIOp
570//===----------------------------------------------------------------------===//
571
572OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
573 // subi(x,x) -> 0
574 if (getOperand(0) == getOperand(1)) {
575 auto shapedType = dyn_cast<ShapedType>(getType());
576 // We can't generate a constant with a dynamic shaped tensor.
577 if (!shapedType || shapedType.hasStaticShape())
578 return Builder(getContext()).getZeroAttr(getType());
579 }
580 // subi(x,0) -> x
581 if (matchPattern(adaptor.getRhs(), m_Zero()))
582 return getLhs();
583
584 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
585 // subi(addi(a, b), b) -> a
586 if (getRhs() == add.getRhs())
587 return add.getLhs();
588 // subi(addi(a, b), a) -> b
589 if (getRhs() == add.getLhs())
590 return add.getRhs();
591 }
592
593 // subi(a, subi(a, b)) -> b
594 if (auto sub = getRhs().getDefiningOp<SubIOp>())
595 if (getLhs() == sub.getLhs())
596 return sub.getRhs();
597
599 adaptor.getOperands(),
600 [](APInt a, const APInt &b) { return std::move(a) - b; });
601}
602
603void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604 MLIRContext *context) {
605 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
606 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
607 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
608}
609
610//===----------------------------------------------------------------------===//
611// MulIOp
612//===----------------------------------------------------------------------===//
613
614OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
615 // muli(x, 0) -> 0
616 if (matchPattern(adaptor.getRhs(), m_Zero()))
617 return getRhs();
618 // muli(x, 1) -> x
619 if (matchPattern(adaptor.getRhs(), m_One()))
620 return getLhs();
621 // TODO: Handle the overflow case.
622
623 // default folder
625 adaptor.getOperands(),
626 [](const APInt &a, const APInt &b) { return a * b; });
627}
628
629void arith::MulIOp::getAsmResultNames(
630 function_ref<void(Value, StringRef)> setNameFn) {
631 if (!isa<IndexType>(getType()))
632 return;
633
634 // Match vector.vscale by name to avoid depending on the vector dialect (which
635 // is a circular dependency).
636 auto isVscale = [](Operation *op) {
637 return op && op->getName().getStringRef() == "vector.vscale";
638 };
639
640 IntegerAttr baseValue;
641 auto isVscaleExpr = [&](Value a, Value b) {
642 return matchPattern(a, m_Constant(&baseValue)) &&
643 isVscale(b.getDefiningOp());
644 };
645
646 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
647 return;
648
649 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
650 SmallString<32> specialNameBuffer;
651 llvm::raw_svector_ostream specialName(specialNameBuffer);
652 specialName << 'c' << baseValue.getInt() << "_vscale";
653 setNameFn(getResult(), specialName.str());
654}
655
656void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
657 MLIRContext *context) {
658 patterns.add<MulIMulIConstant>(context);
659}
660
661//===----------------------------------------------------------------------===//
662// MulSIExtendedOp
663//===----------------------------------------------------------------------===//
664
665std::optional<SmallVector<int64_t, 4>>
666arith::MulSIExtendedOp::getShapeForUnroll() {
667 if (auto vt = dyn_cast<VectorType>(getType(0)))
668 return llvm::to_vector<4>(vt.getShape());
669 return std::nullopt;
670}
671
672LogicalResult
673arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
674 SmallVectorImpl<OpFoldResult> &results) {
675 // mulsi_extended(x, 0) -> 0, 0
676 if (matchPattern(adaptor.getRhs(), m_Zero())) {
677 Attribute zero = adaptor.getRhs();
678 results.push_back(zero);
679 results.push_back(zero);
680 return success();
681 }
682
683 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
684 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
685 adaptor.getOperands(),
686 [](const APInt &a, const APInt &b) { return a * b; })) {
687 // Invoke the constant fold helper again to calculate the 'high' result.
688 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
689 llvm::APIntOps::mulhs);
690 assert(highAttr && "Unexpected constant-folding failure");
691
692 results.push_back(lowAttr);
693 results.push_back(highAttr);
694 return success();
695 }
696
697 return failure();
698}
699
700void arith::MulSIExtendedOp::getCanonicalizationPatterns(
701 RewritePatternSet &patterns, MLIRContext *context) {
702 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
703}
704
705//===----------------------------------------------------------------------===//
706// MulUIExtendedOp
707//===----------------------------------------------------------------------===//
708
709std::optional<SmallVector<int64_t, 4>>
710arith::MulUIExtendedOp::getShapeForUnroll() {
711 if (auto vt = dyn_cast<VectorType>(getType(0)))
712 return llvm::to_vector<4>(vt.getShape());
713 return std::nullopt;
714}
715
716LogicalResult
717arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
718 SmallVectorImpl<OpFoldResult> &results) {
719 // mului_extended(x, 0) -> 0, 0
720 if (matchPattern(adaptor.getRhs(), m_Zero())) {
721 Attribute zero = adaptor.getRhs();
722 results.push_back(zero);
723 results.push_back(zero);
724 return success();
725 }
726
727 // mului_extended(x, 1) -> x, 0
728 if (matchPattern(adaptor.getRhs(), m_One())) {
729 Builder builder(getContext());
730 Attribute zero = builder.getZeroAttr(getLhs().getType());
731 results.push_back(getLhs());
732 results.push_back(zero);
733 return success();
734 }
735
736 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
737 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
738 adaptor.getOperands(),
739 [](const APInt &a, const APInt &b) { return a * b; })) {
740 // Invoke the constant fold helper again to calculate the 'high' result.
741 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
742 llvm::APIntOps::mulhu);
743 assert(highAttr && "Unexpected constant-folding failure");
744
745 results.push_back(lowAttr);
746 results.push_back(highAttr);
747 return success();
748 }
749
750 return failure();
751}
752
753void arith::MulUIExtendedOp::getCanonicalizationPatterns(
754 RewritePatternSet &patterns, MLIRContext *context) {
755 patterns.add<MulUIExtendedToMulI>(context);
756}
757
758//===----------------------------------------------------------------------===//
759// DivUIOp
760//===----------------------------------------------------------------------===//
761
762/// Fold `(a * b) / b -> a`
764 arith::IntegerOverflowFlags ovfFlags) {
765 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
766 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
767 return {};
768
769 if (mul.getLhs() == rhs)
770 return mul.getRhs();
771
772 if (mul.getRhs() == rhs)
773 return mul.getLhs();
774
775 return {};
776}
777
778OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
779 // divui (x, 1) -> x.
780 if (matchPattern(adaptor.getRhs(), m_One()))
781 return getLhs();
782
783 // (a * b) / b -> a
784 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
785 return val;
786
787 // Don't fold if it would require a division by zero.
788 bool div0 = false;
789 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
790 [&](APInt a, const APInt &b) {
791 if (div0 || !b) {
792 div0 = true;
793 return a;
794 }
795 return a.udiv(b);
796 });
797
798 return div0 ? Attribute() : result;
799}
800
801/// Returns whether an unsigned division by `divisor` is speculatable.
803 // X / 0 => UB
804 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
806
808}
809
810Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
811 return getDivUISpeculatability(getRhs());
812}
813
814//===----------------------------------------------------------------------===//
815// DivSIOp
816//===----------------------------------------------------------------------===//
817
818OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
819 // divsi (x, 1) -> x.
820 if (matchPattern(adaptor.getRhs(), m_One()))
821 return getLhs();
822
823 // (a * b) / b -> a
824 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
825 return val;
826
827 // Don't fold if it would overflow or if it requires a division by zero.
828 bool overflowOrDiv0 = false;
830 adaptor.getOperands(), [&](APInt a, const APInt &b) {
831 if (overflowOrDiv0 || !b) {
832 overflowOrDiv0 = true;
833 return a;
834 }
835 return a.sdiv_ov(b, overflowOrDiv0);
836 });
837
838 return overflowOrDiv0 ? Attribute() : result;
839}
840
841/// Returns whether a signed division by `divisor` is speculatable. This
842/// function conservatively assumes that all signed division by -1 are not
843/// speculatable.
845 // X / 0 => UB
846 // INT_MIN / -1 => UB
847 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
850
852}
853
854Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
855 return getDivSISpeculatability(getRhs());
856}
857
858//===----------------------------------------------------------------------===//
859// Ceil and floor division folding helpers
860//===----------------------------------------------------------------------===//
861
862static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
863 bool &overflow) {
864 // Returns (a-1)/b + 1
865 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
866 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
867 return val.sadd_ov(one, overflow);
868}
869
870//===----------------------------------------------------------------------===//
871// CeilDivUIOp
872//===----------------------------------------------------------------------===//
873
874OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
875 // ceildivui (x, 1) -> x.
876 if (matchPattern(adaptor.getRhs(), m_One()))
877 return getLhs();
878
879 bool overflowOrDiv0 = false;
881 adaptor.getOperands(), [&](APInt a, const APInt &b) {
882 if (overflowOrDiv0 || !b) {
883 overflowOrDiv0 = true;
884 return a;
885 }
886 APInt quotient = a.udiv(b);
887 if (!a.urem(b))
888 return quotient;
889 APInt one(a.getBitWidth(), 1, true);
890 return quotient.uadd_ov(one, overflowOrDiv0);
891 });
892
893 return overflowOrDiv0 ? Attribute() : result;
894}
895
896Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
897 return getDivUISpeculatability(getRhs());
898}
899
900//===----------------------------------------------------------------------===//
901// CeilDivSIOp
902//===----------------------------------------------------------------------===//
903
904OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
905 // ceildivsi (x, 1) -> x.
906 if (matchPattern(adaptor.getRhs(), m_One()))
907 return getLhs();
908
909 // Don't fold if it would overflow or if it requires a division by zero.
910 // TODO: This hook won't fold operations where a = MININT, because
911 // negating MININT overflows. This can be improved.
912 bool overflowOrDiv0 = false;
914 adaptor.getOperands(), [&](APInt a, const APInt &b) {
915 if (overflowOrDiv0 || !b) {
916 overflowOrDiv0 = true;
917 return a;
918 }
919 if (!a)
920 return a;
921 // After this point we know that neither a or b are zero.
922 unsigned bits = a.getBitWidth();
923 APInt zero = APInt::getZero(bits);
924 bool aGtZero = a.sgt(zero);
925 bool bGtZero = b.sgt(zero);
926 if (aGtZero && bGtZero) {
927 // Both positive, return ceil(a, b).
928 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
929 }
930
931 // No folding happens if any of the intermediate arithmetic operations
932 // overflows.
933 bool overflowNegA = false;
934 bool overflowNegB = false;
935 bool overflowDiv = false;
936 bool overflowNegRes = false;
937 if (!aGtZero && !bGtZero) {
938 // Both negative, return ceil(-a, -b).
939 APInt posA = zero.ssub_ov(a, overflowNegA);
940 APInt posB = zero.ssub_ov(b, overflowNegB);
941 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
942 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
943 return res;
944 }
945 if (!aGtZero && bGtZero) {
946 // A is negative, b is positive, return - ( -a / b).
947 APInt posA = zero.ssub_ov(a, overflowNegA);
948 APInt div = posA.sdiv_ov(b, overflowDiv);
949 APInt res = zero.ssub_ov(div, overflowNegRes);
950 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
951 return res;
952 }
953 // A is positive, b is negative, return - (a / -b).
954 APInt posB = zero.ssub_ov(b, overflowNegB);
955 APInt div = a.sdiv_ov(posB, overflowDiv);
956 APInt res = zero.ssub_ov(div, overflowNegRes);
957
958 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
959 return res;
960 });
961
962 return overflowOrDiv0 ? Attribute() : result;
963}
964
965Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
966 return getDivSISpeculatability(getRhs());
967}
968
969//===----------------------------------------------------------------------===//
970// FloorDivSIOp
971//===----------------------------------------------------------------------===//
972
973OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
974 // floordivsi (x, 1) -> x.
975 if (matchPattern(adaptor.getRhs(), m_One()))
976 return getLhs();
977
978 // Don't fold if it would overflow or if it requires a division by zero.
979 bool overflowOrDiv = false;
981 adaptor.getOperands(), [&](APInt a, const APInt &b) {
982 if (b.isZero()) {
983 overflowOrDiv = true;
984 return a;
985 }
986 return a.sfloordiv_ov(b, overflowOrDiv);
987 });
988
989 return overflowOrDiv ? Attribute() : result;
990}
991
992//===----------------------------------------------------------------------===//
993// RemUIOp
994//===----------------------------------------------------------------------===//
995
996OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
997 // remui (x, 1) -> 0.
998 if (matchPattern(adaptor.getRhs(), m_One()))
999 return Builder(getContext()).getZeroAttr(getType());
1000
1001 // Don't fold if it would require a division by zero.
1002 bool div0 = false;
1003 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1004 [&](APInt a, const APInt &b) {
1005 if (div0 || b.isZero()) {
1006 div0 = true;
1007 return a;
1008 }
1009 return a.urem(b);
1010 });
1011
1012 return div0 ? Attribute() : result;
1013}
1014
1015Speculation::Speculatability arith::RemUIOp::getSpeculatability() {
1016 return getDivUISpeculatability(getRhs());
1017}
1018
1019//===----------------------------------------------------------------------===//
1020// RemSIOp
1021//===----------------------------------------------------------------------===//
1022
1023OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
1024 // remsi (x, 1) -> 0.
1025 if (matchPattern(adaptor.getRhs(), m_One()))
1026 return Builder(getContext()).getZeroAttr(getType());
1027
1028 // Don't fold if it would require a division by zero.
1029 bool div0 = false;
1030 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1031 [&](APInt a, const APInt &b) {
1032 if (div0 || b.isZero()) {
1033 div0 = true;
1034 return a;
1035 }
1036 return a.srem(b);
1037 });
1038
1039 return div0 ? Attribute() : result;
1040}
1041
1042Speculation::Speculatability arith::RemSIOp::getSpeculatability() {
1043 // X % 0 => UB
1044 // X % -1 is well-defined (always 0), unlike X / -1 which can overflow.
1045 if (matchPattern(getRhs(), m_IntRangeWithoutZeroS()))
1047
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// AndIOp
1053//===----------------------------------------------------------------------===//
1054
1055/// Fold `and(a, and(a, b))` to `and(a, b)`
1056static Value foldAndIofAndI(arith::AndIOp op) {
1057 for (bool reversePrev : {false, true}) {
1058 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
1059 .getDefiningOp<arith::AndIOp>();
1060 if (!prev)
1061 continue;
1062
1063 Value other = (reversePrev ? op.getLhs() : op.getRhs());
1064 if (other != prev.getLhs() && other != prev.getRhs())
1065 continue;
1066
1067 return prev.getResult();
1068 }
1069 return {};
1070}
1071
1072OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
1073 /// and(x, 0) -> 0
1074 if (matchPattern(adaptor.getRhs(), m_Zero()))
1075 return getRhs();
1076 /// and(x, allOnes) -> x
1077 APInt intValue;
1078 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
1079 intValue.isAllOnes())
1080 return getLhs();
1081 /// and(x, not(x)) -> 0
1082 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1083 m_ConstantInt(&intValue))) &&
1084 intValue.isAllOnes())
1085 return Builder(getContext()).getZeroAttr(getType());
1086 /// and(not(x), x) -> 0
1087 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1088 m_ConstantInt(&intValue))) &&
1089 intValue.isAllOnes())
1090 return Builder(getContext()).getZeroAttr(getType());
1091
1092 /// and(a, and(a, b)) -> and(a, b)
1093 if (Value result = foldAndIofAndI(*this))
1094 return result;
1095
1097 adaptor.getOperands(),
1098 [](APInt a, const APInt &b) { return std::move(a) & b; });
1099}
1100
1101//===----------------------------------------------------------------------===//
1102// OrIOp
1103//===----------------------------------------------------------------------===//
1104
1105OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1106 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1107 /// or(x, 0) -> x
1108 if (rhsVal.isZero())
1109 return getLhs();
1110 /// or(x, <all ones>) -> <all ones>
1111 if (rhsVal.isAllOnes())
1112 return adaptor.getRhs();
1113 }
1114
1115 APInt intValue;
1116 /// or(x, xor(x, 1)) -> 1
1117 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1118 m_ConstantInt(&intValue))) &&
1119 intValue.isAllOnes())
1120 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1121 /// or(xor(x, 1), x) -> 1
1122 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1123 m_ConstantInt(&intValue))) &&
1124 intValue.isAllOnes())
1125 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1126
1128 adaptor.getOperands(),
1129 [](APInt a, const APInt &b) { return std::move(a) | b; });
1130}
1131
1132//===----------------------------------------------------------------------===//
1133// XOrIOp
1134//===----------------------------------------------------------------------===//
1135
1136OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1137 /// xor(x, 0) -> x
1138 if (matchPattern(adaptor.getRhs(), m_Zero()))
1139 return getLhs();
1140 /// xor(x, x) -> 0
1141 if (getLhs() == getRhs())
1142 return Builder(getContext()).getZeroAttr(getType());
1143 /// xor(xor(x, a), a) -> x
1144 /// xor(xor(a, x), a) -> x
1145 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1146 if (prev.getRhs() == getRhs())
1147 return prev.getLhs();
1148 if (prev.getLhs() == getRhs())
1149 return prev.getRhs();
1150 }
1151 /// xor(a, xor(x, a)) -> x
1152 /// xor(a, xor(a, x)) -> x
1153 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1154 if (prev.getRhs() == getLhs())
1155 return prev.getLhs();
1156 if (prev.getLhs() == getLhs())
1157 return prev.getRhs();
1158 }
1159
1161 adaptor.getOperands(),
1162 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1163}
1164
1165void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1166 MLIRContext *context) {
1167 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1168}
1169
1170//===----------------------------------------------------------------------===//
1171// NegFOp
1172//===----------------------------------------------------------------------===//
1173
1174OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1175 /// negf(negf(x)) -> x
1176 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1177 return op.getOperand();
1178 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1179 [](const APFloat &a) { return -a; });
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// FlushDenormalsOp
1184//===----------------------------------------------------------------------===//
1185
1186OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1187 // TODO: Fold flush_denormals if the floating-point type does not support
1188 // denormals. There is currently no API to query this information from
1189 // APFloat.
1190
1191 // flush_denormals(flush_denormals(x)) -> flush_denormals(x)
1192 if (auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1193 return op.getResult();
1194
1195 // Constant-fold flush_denormals if the operand is a constant.
1197 adaptor.getOperands(), [](const APFloat &a) {
1198 if (a.isDenormal())
1199 return APFloat::getZero(a.getSemantics(), a.isNegative());
1200 return a;
1201 });
1202}
1203
1204//===----------------------------------------------------------------------===//
1205// AddFOp
1206//===----------------------------------------------------------------------===//
1207
1208OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1209 // addf(x, -0) -> x
1210 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1211 return getLhs();
1212
1213 auto rm = getRoundingmode();
1215 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1216 APFloat result(a);
1217 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1218 return result;
1219 });
1220}
1221
1222//===----------------------------------------------------------------------===//
1223// SubFOp
1224//===----------------------------------------------------------------------===//
1225
1226OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1227 // subf(x, +0) -> x
1228 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1229 return getLhs();
1230
1231 auto rm = getRoundingmode();
1233 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1234 APFloat result(a);
1235 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1236 return result;
1237 });
1238}
1239
1240void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1241 MLIRContext *context) {
1242 patterns.add<SubFOfNegZero>(context);
1243}
1244
1245//===----------------------------------------------------------------------===//
1246// MaximumFOp
1247//===----------------------------------------------------------------------===//
1248
1249OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1250 // maximumf(x,x) -> x
1251 if (getLhs() == getRhs())
1252 return getRhs();
1253
1254 // maximumf(x, -inf) -> x
1255 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1256 return getLhs();
1257
1258 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maximum);
1259}
1260
1261//===----------------------------------------------------------------------===//
1262// MaxNumFOp
1263//===----------------------------------------------------------------------===//
1264
1265OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1266 // maxnumf(x,x) -> x
1267 if (getLhs() == getRhs())
1268 return getRhs();
1269
1270 // maxnumf(x, NaN) -> x
1271 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1272 return getLhs();
1273
1274 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1275}
1276
1277//===----------------------------------------------------------------------===//
1278// MaxSIOp
1279//===----------------------------------------------------------------------===//
1280
1281OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1282 // maxsi(x,x) -> x
1283 if (getLhs() == getRhs())
1284 return getRhs();
1285
1286 if (APInt intValue;
1287 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1288 // maxsi(x,MAX_INT) -> MAX_INT
1289 if (intValue.isMaxSignedValue())
1290 return getRhs();
1291 // maxsi(x, MIN_INT) -> x
1292 if (intValue.isMinSignedValue())
1293 return getLhs();
1294 }
1295
1296 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1297 llvm::APIntOps::smax);
1298}
1299
1300//===----------------------------------------------------------------------===//
1301// MaxUIOp
1302//===----------------------------------------------------------------------===//
1303
1304OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1305 // maxui(x,x) -> x
1306 if (getLhs() == getRhs())
1307 return getRhs();
1308
1309 if (APInt intValue;
1310 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1311 // maxui(x,MAX_INT) -> MAX_INT
1312 if (intValue.isMaxValue())
1313 return getRhs();
1314 // maxui(x, MIN_INT) -> x
1315 if (intValue.isMinValue())
1316 return getLhs();
1317 }
1318
1319 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1320 llvm::APIntOps::umax);
1321}
1322
1323//===----------------------------------------------------------------------===//
1324// MinimumFOp
1325//===----------------------------------------------------------------------===//
1326
1327OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1328 // minimumf(x,x) -> x
1329 if (getLhs() == getRhs())
1330 return getRhs();
1331
1332 // minimumf(x, +inf) -> x
1333 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1334 return getLhs();
1335
1336 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::minimum);
1337}
1338
1339//===----------------------------------------------------------------------===//
1340// MinNumFOp
1341//===----------------------------------------------------------------------===//
1342
1343OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1344 // minnumf(x,x) -> x
1345 if (getLhs() == getRhs())
1346 return getRhs();
1347
1348 // minnumf(x, NaN) -> x
1349 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1350 return getLhs();
1351
1352 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::minnum);
1353}
1354
1355//===----------------------------------------------------------------------===//
1356// MinSIOp
1357//===----------------------------------------------------------------------===//
1358
1359OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1360 // minsi(x,x) -> x
1361 if (getLhs() == getRhs())
1362 return getRhs();
1363
1364 if (APInt intValue;
1365 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1366 // minsi(x,MIN_INT) -> MIN_INT
1367 if (intValue.isMinSignedValue())
1368 return getRhs();
1369 // minsi(x, MAX_INT) -> x
1370 if (intValue.isMaxSignedValue())
1371 return getLhs();
1372 }
1373
1374 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1375 llvm::APIntOps::smin);
1376}
1377
1378//===----------------------------------------------------------------------===//
1379// MinUIOp
1380//===----------------------------------------------------------------------===//
1381
1382OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1383 // minui(x,x) -> x
1384 if (getLhs() == getRhs())
1385 return getRhs();
1386
1387 if (APInt intValue;
1388 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1389 // minui(x,MIN_INT) -> MIN_INT
1390 if (intValue.isMinValue())
1391 return getRhs();
1392 // minui(x, MAX_INT) -> x
1393 if (intValue.isMaxValue())
1394 return getLhs();
1395 }
1396
1397 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1398 llvm::APIntOps::umin);
1399}
1400
1401//===----------------------------------------------------------------------===//
1402// MulFOp
1403//===----------------------------------------------------------------------===//
1404
1405OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1406 // mulf(x, 1) -> x
1407 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1408 return getLhs();
1409
1410 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1411 arith::FastMathFlags::nsz)) {
1412 // mulf(x, 0) -> 0
1413 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1414 return getRhs();
1415 }
1416
1417 auto rm = getRoundingmode();
1419 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1420 APFloat result(a);
1421 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1422 return result;
1423 });
1424}
1425
1426void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1427 MLIRContext *context) {
1428 patterns.add<MulFOfNegF>(context);
1429}
1430
1431//===----------------------------------------------------------------------===//
1432// DivFOp
1433//===----------------------------------------------------------------------===//
1434
1435OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1436 // divf(x, 1) -> x
1437 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1438 return getLhs();
1439
1440 auto rm = getRoundingmode();
1442 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1443 APFloat result(a);
1444 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1445 return result;
1446 });
1447}
1448
1449void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1450 MLIRContext *context) {
1451 patterns.add<DivFOfNegF>(context);
1452}
1453
1454//===----------------------------------------------------------------------===//
1455// RemFOp
1456//===----------------------------------------------------------------------===//
1457
1458OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1459 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1460 [](const APFloat &a, const APFloat &b) {
1461 APFloat result(a);
1462 // APFloat::mod() offers the remainder
1463 // behavior we want, i.e. the result has
1464 // the sign of LHS operand.
1465 (void)result.mod(b);
1466 return result;
1467 });
1468}
1469
1470//===----------------------------------------------------------------------===//
1471// Utility functions for verifying cast ops
1472//===----------------------------------------------------------------------===//
1473
/// Tag type carrying a pack of types; it is a pointer to a tuple, so a
/// value-initialized instance is a null pointer.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
1476
1477/// Returns a non-null type only if the provided type is one of the allowed
1478/// types or one of the allowed shaped types of the allowed types. Returns the
1479/// element type if a valid shaped type is provided.
1480template <typename... ShapedTypes, typename... ElementTypes>
1483 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1484 return {};
1485
1486 auto underlyingType = getElementTypeOrSelf(type);
1487 if (!llvm::isa<ElementTypes...>(underlyingType))
1488 return {};
1489
1490 return underlyingType;
1491}
1492
1493/// Get allowed underlying types for vectors and tensors.
1494template <typename... ElementTypes>
1499
1500/// Get allowed underlying types for vectors, tensors, and memrefs.
1501template <typename... ElementTypes>
1507
1508/// Return false if both types are ranked tensor with mismatching encoding.
1509static bool hasSameEncoding(Type typeA, Type typeB) {
1510 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1511 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1512 if (!rankedTensorA || !rankedTensorB)
1513 return true;
1514 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1515}
1516
1518 if (inputs.size() != 1 || outputs.size() != 1)
1519 return false;
1520 if (!hasSameEncoding(inputs.front(), outputs.front()))
1521 return false;
1522 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1523}
1524
1525//===----------------------------------------------------------------------===//
1526// Verifiers for integer and floating point extension/truncation ops
1527//===----------------------------------------------------------------------===//
1528
1529// Extend ops can only extend to a wider type.
1530template <typename ValType, typename Op>
1531static LogicalResult verifyExtOp(Op op) {
1532 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1533 Type dstType = getElementTypeOrSelf(op.getType());
1534
1535 if (llvm::cast<ValType>(srcType).getWidth() >=
1536 llvm::cast<ValType>(dstType).getWidth())
1537 return op.emitError("result type ")
1538 << dstType << " must be wider than operand type " << srcType;
1539
1540 return success();
1541}
1542
1543// Truncate ops can only truncate to a shorter type.
1544template <typename ValType, typename Op>
1545static LogicalResult verifyTruncateOp(Op op) {
1546 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1547 Type dstType = getElementTypeOrSelf(op.getType());
1548
1549 if (llvm::cast<ValType>(srcType).getWidth() <=
1550 llvm::cast<ValType>(dstType).getWidth())
1551 return op.emitError("result type ")
1552 << dstType << " must be shorter than operand type " << srcType;
1553
1554 return success();
1555}
1556
1557/// Validate a cast that changes the width of a type.
1558template <template <typename> class WidthComparator, typename... ElementTypes>
1559static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1560 if (!areValidCastInputsAndOutputs(inputs, outputs))
1561 return false;
1562
1563 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1564 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1565 if (!srcType || !dstType)
1566 return false;
1567
1568 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1569 srcType.getIntOrFloatBitWidth());
1570}
1571
1572/// Attempts to convert `sourceValue` to an APFloat value with
1573/// `targetSemantics` and `roundingMode`, without any information loss.
1574static FailureOr<APFloat>
1575convertFloatValue(APFloat sourceValue,
1576 const llvm::fltSemantics &targetSemantics,
1577 llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
1578 // Reject special values that are not representable in the target type before
1579 // calling APFloat::convert, which would llvm_unreachable on them.
1580 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1581 if (sourceValue.isInfinity() &&
1582 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1583 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1584 return failure();
1585 if (sourceValue.isNaN() &&
1586 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1587 return failure();
1588
1589 bool losesInfo = false;
1590 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1591 if (losesInfo || status != APFloat::opOK)
1592 return failure();
1593
1594 return sourceValue;
1595}
1596
1597//===----------------------------------------------------------------------===//
1598// ExtUIOp
1599//===----------------------------------------------------------------------===//
1600
1601OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1602 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1603 getInMutable().assign(lhs.getIn());
1604 return getResult();
1605 }
1606
1607 Type resType = getElementTypeOrSelf(getType());
1608 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1610 adaptor.getOperands(), getType(),
1611 [bitWidth](const APInt &a, bool &castStatus) {
1612 return a.zext(bitWidth);
1613 });
1614}
1615
1616bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1618}
1619
1620LogicalResult arith::ExtUIOp::verify() {
1621 return verifyExtOp<IntegerType>(*this);
1622}
1623
1624//===----------------------------------------------------------------------===//
1625// ExtSIOp
1626//===----------------------------------------------------------------------===//
1627
1628OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1629 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1630 getInMutable().assign(lhs.getIn());
1631 return getResult();
1632 }
1633
1634 Type resType = getElementTypeOrSelf(getType());
1635 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1637 adaptor.getOperands(), getType(),
1638 [bitWidth](const APInt &a, bool &castStatus) {
1639 return a.sext(bitWidth);
1640 });
1641}
1642
1643bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1645}
1646
1647void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1648 MLIRContext *context) {
1649 patterns.add<ExtSIOfExtUI>(context);
1650}
1651
1652LogicalResult arith::ExtSIOp::verify() {
1653 return verifyExtOp<IntegerType>(*this);
1654}
1655
1656//===----------------------------------------------------------------------===//
1657// ExtFOp
1658//===----------------------------------------------------------------------===//
1659
1660/// Fold extension of float constants when there is no information loss due the
1661/// difference in fp semantics.
1662OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1663 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1664 if (truncFOp.getOperand().getType() == getType()) {
1665 arith::FastMathFlags truncFMF =
1666 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1667 bool isTruncContract =
1668 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1669 arith::FastMathFlags extFMF =
1670 getFastmath().value_or(arith::FastMathFlags::none);
1671 bool isExtContract =
1672 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1673 if (isTruncContract && isExtContract) {
1674 return truncFOp.getOperand();
1675 }
1676 }
1677 }
1678
1679 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1680 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1682 adaptor.getOperands(), getType(),
1683 [&targetSemantics](const APFloat &a, bool &castStatus) {
1684 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1685 if (failed(result)) {
1686 castStatus = false;
1687 return a;
1688 }
1689 return *result;
1690 });
1691}
1692
1693bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1694 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1695}
1696
1697LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1698
1699//===----------------------------------------------------------------------===//
1700// ScalingExtFOp
1701//===----------------------------------------------------------------------===//
1702
1703bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1704 TypeRange outputs) {
1705 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1706}
1707
1708LogicalResult arith::ScalingExtFOp::verify() {
1709 return verifyExtOp<FloatType>(*this);
1710}
1711
1712//===----------------------------------------------------------------------===//
1713// TruncIOp
1714//===----------------------------------------------------------------------===//
1715
1716OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1717 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1718 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1719 Value src = getOperand().getDefiningOp()->getOperand(0);
1720 Type srcType = getElementTypeOrSelf(src.getType());
1721 Type dstType = getElementTypeOrSelf(getType());
1722 // trunci(zexti(a)) -> trunci(a)
1723 // trunci(sexti(a)) -> trunci(a)
1724 if (llvm::cast<IntegerType>(srcType).getWidth() >
1725 llvm::cast<IntegerType>(dstType).getWidth()) {
1726 setOperand(src);
1727 return getResult();
1728 }
1729
1730 // trunci(zexti(a)) -> a
1731 // trunci(sexti(a)) -> a
1732 if (srcType == dstType)
1733 return src;
1734 }
1735
1736 // trunci(trunci(a)) -> trunci(a))
1737 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1738 setOperand(getOperand().getDefiningOp()->getOperand(0));
1739 return getResult();
1740 }
1741
1742 Type resType = getElementTypeOrSelf(getType());
1743 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1745 adaptor.getOperands(), getType(),
1746 [bitWidth](const APInt &a, bool &castStatus) {
1747 return a.trunc(bitWidth);
1748 });
1749}
1750
1751bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1752 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1753}
1754
1755void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1756 MLIRContext *context) {
1757 patterns
1758 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1759 context);
1760}
1761
1762LogicalResult arith::TruncIOp::verify() {
1763 return verifyTruncateOp<IntegerType>(*this);
1764}
1765
1766//===----------------------------------------------------------------------===//
1767// TruncFOp
1768//===----------------------------------------------------------------------===//
1769
1770/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1771/// can be represented without precision loss.
1772OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1773 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1774 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1775 Value src = extOp.getIn();
1776 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1777 auto intermediateType =
1778 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1779 // Check if the srcType is representable in the intermediateType.
1780 if (llvm::APFloatBase::isRepresentableBy(
1781 srcType.getFloatSemantics(),
1782 intermediateType.getFloatSemantics())) {
1783 // truncf(extf(a)) -> truncf(a)
1784 if (srcType.getWidth() > resElemType.getWidth()) {
1785 setOperand(src);
1786 return getResult();
1787 }
1788
1789 // truncf(extf(a)) -> a
1790 if (srcType == resElemType)
1791 return src;
1792 }
1793 }
1794
1795 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1797 adaptor.getOperands(), getType(),
1798 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1799 llvm::RoundingMode llvmRoundingMode =
1800 convertArithRoundingModeToLLVMIR(getRoundingmode());
1801 FailureOr<APFloat> result =
1802 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1803 if (failed(result)) {
1804 castStatus = false;
1805 return a;
1806 }
1807 return *result;
1808 });
1809}
1810
1811void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1812 MLIRContext *context) {
1813 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1814}
1815
1816bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1817 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1818}
1819
1820LogicalResult arith::TruncFOp::verify() {
1821 return verifyTruncateOp<FloatType>(*this);
1822}
1823
1824//===----------------------------------------------------------------------===//
1825// ConvertFOp
1826//===----------------------------------------------------------------------===//
1827
1828OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1829 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1830 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1832 adaptor.getOperands(), getType(),
1833 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1834 llvm::RoundingMode llvmRoundingMode =
1835 convertArithRoundingModeToLLVMIR(getRoundingmode());
1836 FailureOr<APFloat> result =
1837 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1838 if (failed(result)) {
1839 castStatus = false;
1840 return a;
1841 }
1842 return *result;
1843 });
1844}
1845
1846bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1847 if (!areValidCastInputsAndOutputs(inputs, outputs))
1848 return false;
1849 auto srcType = getTypeIfLike<FloatType>(inputs.front());
1850 auto dstType = getTypeIfLike<FloatType>(outputs.front());
1851 if (!srcType || !dstType)
1852 return false;
1853 return srcType != dstType &&
1854 srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1855}
1856
1857LogicalResult arith::ConvertFOp::verify() {
1858 auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
1859 auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
1860 if (srcType == dstType)
1861 return emitError("result element type ")
1862 << dstType << " must be different from operand element type "
1863 << srcType;
1864 if (srcType.getWidth() != dstType.getWidth())
1865 return emitError("result element type ")
1866 << dstType << " must have the same bitwidth as operand element type "
1867 << srcType;
1868 return success();
1869}
1870
1871//===----------------------------------------------------------------------===//
1872// ScalingTruncFOp
1873//===----------------------------------------------------------------------===//
1874
1875bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1876 TypeRange outputs) {
1877 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1878}
1879
1880LogicalResult arith::ScalingTruncFOp::verify() {
1881 return verifyTruncateOp<FloatType>(*this);
1882}
1883
1884//===----------------------------------------------------------------------===//
1885// AndIOp
1886//===----------------------------------------------------------------------===//
1887
1888void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1889 MLIRContext *context) {
1890 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1891}
1892
1893//===----------------------------------------------------------------------===//
1894// OrIOp
1895//===----------------------------------------------------------------------===//
1896
1897void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1898 MLIRContext *context) {
1899 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1900}
1901
1902//===----------------------------------------------------------------------===//
1903// Verifiers for casts between integers and floats.
1904//===----------------------------------------------------------------------===//
1905
1906template <typename From, typename To>
1907static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1908 if (!areValidCastInputsAndOutputs(inputs, outputs))
1909 return false;
1910
1911 auto srcType = getTypeIfLike<From>(inputs.front());
1912 auto dstType = getTypeIfLike<To>(outputs.back());
1913
1914 return srcType && dstType;
1915}
1916
1917//===----------------------------------------------------------------------===//
1918// UIToFPOp
1919//===----------------------------------------------------------------------===//
1920
1921bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1922 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1923}
1924
1925OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1926 Type resEleType = getElementTypeOrSelf(getType());
1928 adaptor.getOperands(), getType(),
1929 [&resEleType](const APInt &a, bool &castStatus) {
1930 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1931 APFloat apf(floatTy.getFloatSemantics(),
1932 APInt::getZero(floatTy.getWidth()));
1933 apf.convertFromAPInt(a, /*IsSigned=*/false,
1934 APFloat::rmNearestTiesToEven);
1935 return apf;
1936 });
1937}
1938
1939void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1940 MLIRContext *context) {
1941 patterns.add<UIToFPOfExtUI>(context);
1942}
1943
1944//===----------------------------------------------------------------------===//
1945// SIToFPOp
1946//===----------------------------------------------------------------------===//
1947
1948bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1949 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1950}
1951
1952OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1953 Type resEleType = getElementTypeOrSelf(getType());
1955 adaptor.getOperands(), getType(),
1956 [&resEleType](const APInt &a, bool &castStatus) {
1957 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1958 APFloat apf(floatTy.getFloatSemantics(),
1959 APInt::getZero(floatTy.getWidth()));
1960 apf.convertFromAPInt(a, /*IsSigned=*/true,
1961 APFloat::rmNearestTiesToEven);
1962 return apf;
1963 });
1964}
1965
1966void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1967 MLIRContext *context) {
1968 patterns.add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1969}
1970
1971//===----------------------------------------------------------------------===//
1972// FPToUIOp
1973//===----------------------------------------------------------------------===//
1974
1975bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1976 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1977}
1978
1979OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1980 Type resType = getElementTypeOrSelf(getType());
1981 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1983 adaptor.getOperands(), getType(),
1984 [&bitWidth](const APFloat &a, bool &castStatus) {
1985 bool ignored;
1986 APSInt api(bitWidth, /*isUnsigned=*/true);
1987 castStatus = APFloat::opInvalidOp !=
1988 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1989 return api;
1990 });
1991}
1992
1993//===----------------------------------------------------------------------===//
1994// FPToSIOp
1995//===----------------------------------------------------------------------===//
1996
1997bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1998 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1999}
2000
2001OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
2002 Type resType = getElementTypeOrSelf(getType());
2003 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
2005 adaptor.getOperands(), getType(),
2006 [&bitWidth](const APFloat &a, bool &castStatus) {
2007 bool ignored;
2008 APSInt api(bitWidth, /*isUnsigned=*/false);
2009 castStatus = APFloat::opInvalidOp !=
2010 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2011 return api;
2012 });
2013}
2014
2015//===----------------------------------------------------------------------===//
2016// IndexCastOp
2017//===----------------------------------------------------------------------===//
2018
2019/// Return the bit-width of \p t for the purpose of index_cast width checks.
2020/// For vector types use the element type; index maps to its internal storage
2021/// width (64 on all current targets).
2022static unsigned getIndexCastWidth(Type t) {
2023 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(t)))
2024 return intTy.getWidth();
2025 return IndexType::kInternalStorageBitWidth;
2026}
2027
2028static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
2029 if (!areValidCastInputsAndOutputs(inputs, outputs))
2030 return false;
2031
2032 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
2033 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
2034 if (!srcType || !dstType)
2035 return false;
2036
2037 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
2038 (srcType.isSignlessInteger() && dstType.isIndex());
2039}
2040
2041bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
2042 TypeRange outputs) {
2043 return areIndexCastCompatible(inputs, outputs);
2044}
2045
2046OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
2047 // index_cast(constant) -> constant
2048 unsigned resultBitwidth = 64; // Default for index integer attributes.
2049 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
2050 resultBitwidth = intTy.getWidth();
2051
2052 if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
2053 adaptor.getOperands(), getType(),
2054 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
2055 return a.sextOrTrunc(resultBitwidth);
2056 }))
2057 return foldResult;
2058
2059 // index_cast(index_cast(x : A) : B) : A -> x, but only when B is at least
2060 // as wide as A. If B is narrower, the inner cast truncates and the outer
2061 // cast sign-extends, so the round-trip is lossy.
2062 if (auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
2063 Value x = inner.getOperand();
2064 if (x.getType() == getType()) {
2065 if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
2066 return x;
2067 }
2068 }
2069 return {};
2070}
2071
2072void arith::IndexCastOp::getCanonicalizationPatterns(
2073 RewritePatternSet &patterns, MLIRContext *context) {
2074 patterns.add<IndexCastOfExtSI>(context);
2075}
2076
2077//===----------------------------------------------------------------------===//
2078// IndexCastUIOp
2079//===----------------------------------------------------------------------===//
2080
2081bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
2082 TypeRange outputs) {
2083 return areIndexCastCompatible(inputs, outputs);
2084}
2085
2086OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2087 // index_castui(constant) -> constant
2088 unsigned resultBitwidth = 64; // Default for index integer attributes.
2089 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
2090 resultBitwidth = intTy.getWidth();
2091
2092 if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
2093 adaptor.getOperands(), getType(),
2094 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
2095 return a.zextOrTrunc(resultBitwidth);
2096 }))
2097 return foldResult;
2098
2099 // index_castui(index_castui(x : A) : B) : A -> x, but only when B is at
2100 // least as wide as A. If B is narrower, the inner cast truncates and the
2101 // outer cast zero-extends, so the round-trip is lossy.
2102 if (auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2103 Value x = inner.getOperand();
2104 if (x.getType() == getType()) {
2105 if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
2106 return x;
2107 }
2108 }
2109 return {};
2110}
2111
2112void arith::IndexCastUIOp::getCanonicalizationPatterns(
2113 RewritePatternSet &patterns, MLIRContext *context) {
2114 patterns.add<IndexCastUIOfExtUI>(context);
2115}
2116
2117//===----------------------------------------------------------------------===//
2118// BitcastOp
2119//===----------------------------------------------------------------------===//
2120
2121bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2122 if (!areValidCastInputsAndOutputs(inputs, outputs))
2123 return false;
2124
2125 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
2126 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
2127 if (!srcType || !dstType)
2128 return false;
2129
2130 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
2131}
2132
2133OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
2134 auto resType = getType();
2135 auto operand = adaptor.getIn();
2136 if (!operand)
2137 return {};
2138
2139 /// Bitcast dense elements.
2140 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
2141 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
2142 /// Other shaped types unhandled.
2143 if (llvm::isa<ShapedType>(resType))
2144 return {};
2145
2146 /// Bitcast poison.
2147 if (matchPattern(operand, ub::m_Poison()))
2148 return ub::PoisonAttr::get(getContext());
2149
2150 /// Bitcast integer or float to integer or float.
2151 APInt bits = llvm::isa<FloatAttr>(operand)
2152 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2153 : llvm::cast<IntegerAttr>(operand).getValue();
2154 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
2155 "trying to fold on broken IR: operands have incompatible types");
2156
2157 if (auto resFloatType = dyn_cast<FloatType>(resType))
2158 return FloatAttr::get(resType,
2159 APFloat(resFloatType.getFloatSemantics(), bits));
2160 return IntegerAttr::get(resType, bits);
2161}
2162
2163void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2164 MLIRContext *context) {
2165 patterns.add<BitcastOfBitcast>(context);
2166}
2167
2168//===----------------------------------------------------------------------===//
2169// CmpIOp
2170//===----------------------------------------------------------------------===//
2171
2172/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
2173/// comparison predicates.
2174bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
2175 const APInt &lhs, const APInt &rhs) {
2176 switch (predicate) {
2177 case arith::CmpIPredicate::eq:
2178 return lhs.eq(rhs);
2179 case arith::CmpIPredicate::ne:
2180 return lhs.ne(rhs);
2181 case arith::CmpIPredicate::slt:
2182 return lhs.slt(rhs);
2183 case arith::CmpIPredicate::sle:
2184 return lhs.sle(rhs);
2185 case arith::CmpIPredicate::sgt:
2186 return lhs.sgt(rhs);
2187 case arith::CmpIPredicate::sge:
2188 return lhs.sge(rhs);
2189 case arith::CmpIPredicate::ult:
2190 return lhs.ult(rhs);
2191 case arith::CmpIPredicate::ule:
2192 return lhs.ule(rhs);
2193 case arith::CmpIPredicate::ugt:
2194 return lhs.ugt(rhs);
2195 case arith::CmpIPredicate::uge:
2196 return lhs.uge(rhs);
2197 }
2198 llvm_unreachable("unknown cmpi predicate kind");
2199}
2200
2201/// Returns true if the predicate is true for two equal operands.
2202static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
2203 switch (predicate) {
2204 case arith::CmpIPredicate::eq:
2205 case arith::CmpIPredicate::sle:
2206 case arith::CmpIPredicate::sge:
2207 case arith::CmpIPredicate::ule:
2208 case arith::CmpIPredicate::uge:
2209 return true;
2210 case arith::CmpIPredicate::ne:
2211 case arith::CmpIPredicate::slt:
2212 case arith::CmpIPredicate::sgt:
2213 case arith::CmpIPredicate::ult:
2214 case arith::CmpIPredicate::ugt:
2215 return false;
2216 }
2217 llvm_unreachable("unknown cmpi predicate kind");
2218}
2219
2220static std::optional<int64_t> getIntegerWidth(Type t) {
2221 if (auto intType = dyn_cast<IntegerType>(t)) {
2222 return intType.getWidth();
2223 }
2224 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
2225 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2226 }
2227 return std::nullopt;
2228}
2229
2230OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2231 // cmpi(pred, x, x)
2232 if (getLhs() == getRhs()) {
2233 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2234 return getBoolAttribute(getType(), val);
2235 }
2236
2237 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2238 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2239 // extsi(%x : i1 -> iN) != 0 -> %x
2240 std::optional<int64_t> integerWidth =
2241 getIntegerWidth(extOp.getOperand().getType());
2242 if (integerWidth && integerWidth.value() == 1 &&
2243 getPredicate() == arith::CmpIPredicate::ne)
2244 return extOp.getOperand();
2245 }
2246 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2247 // extui(%x : i1 -> iN) != 0 -> %x
2248 std::optional<int64_t> integerWidth =
2249 getIntegerWidth(extOp.getOperand().getType());
2250 if (integerWidth && integerWidth.value() == 1 &&
2251 getPredicate() == arith::CmpIPredicate::ne)
2252 return extOp.getOperand();
2253 }
2254
2255 // arith.cmpi ne, %val, %zero : i1 -> %val
2256 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2257 getPredicate() == arith::CmpIPredicate::ne)
2258 return getLhs();
2259 }
2260
2261 if (matchPattern(adaptor.getRhs(), m_One())) {
2262 // arith.cmpi eq, %val, %one : i1 -> %val
2263 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2264 getPredicate() == arith::CmpIPredicate::eq)
2265 return getLhs();
2266 }
2267
2268 // Move constant to the right side.
2269 if (adaptor.getLhs() && !adaptor.getRhs()) {
2270 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2271 using Pred = CmpIPredicate;
2272 const std::pair<Pred, Pred> invPreds[] = {
2273 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2274 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2275 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2276 {Pred::ne, Pred::ne},
2277 };
2278 Pred origPred = getPredicate();
2279 for (auto pred : invPreds) {
2280 if (origPred == pred.first) {
2281 setPredicate(pred.second);
2282 Value lhs = getLhs();
2283 Value rhs = getRhs();
2284 getLhsMutable().assign(rhs);
2285 getRhsMutable().assign(lhs);
2286 return getResult();
2287 }
2288 }
2289 llvm_unreachable("unknown cmpi predicate kind");
2290 }
2291
2292 // We are moving constants to the right side; So if lhs is constant rhs is
2293 // guaranteed to be a constant.
2294 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2296 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2297 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2298 return APInt(1,
2299 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2300 });
2301 }
2302
2303 return {};
2304}
2305
2306void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2307 MLIRContext *context) {
2308 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2309}
2310
2311//===----------------------------------------------------------------------===//
2312// CmpFOp
2313//===----------------------------------------------------------------------===//
2314
2315/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2316/// comparison predicates.
2317bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2318 const APFloat &lhs, const APFloat &rhs) {
2319 auto cmpResult = lhs.compare(rhs);
2320 switch (predicate) {
2321 case arith::CmpFPredicate::AlwaysFalse:
2322 return false;
2323 case arith::CmpFPredicate::OEQ:
2324 return cmpResult == APFloat::cmpEqual;
2325 case arith::CmpFPredicate::OGT:
2326 return cmpResult == APFloat::cmpGreaterThan;
2327 case arith::CmpFPredicate::OGE:
2328 return cmpResult == APFloat::cmpGreaterThan ||
2329 cmpResult == APFloat::cmpEqual;
2330 case arith::CmpFPredicate::OLT:
2331 return cmpResult == APFloat::cmpLessThan;
2332 case arith::CmpFPredicate::OLE:
2333 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2334 case arith::CmpFPredicate::ONE:
2335 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2336 case arith::CmpFPredicate::ORD:
2337 return cmpResult != APFloat::cmpUnordered;
2338 case arith::CmpFPredicate::UEQ:
2339 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2340 case arith::CmpFPredicate::UGT:
2341 return cmpResult == APFloat::cmpUnordered ||
2342 cmpResult == APFloat::cmpGreaterThan;
2343 case arith::CmpFPredicate::UGE:
2344 return cmpResult == APFloat::cmpUnordered ||
2345 cmpResult == APFloat::cmpGreaterThan ||
2346 cmpResult == APFloat::cmpEqual;
2347 case arith::CmpFPredicate::ULT:
2348 return cmpResult == APFloat::cmpUnordered ||
2349 cmpResult == APFloat::cmpLessThan;
2350 case arith::CmpFPredicate::ULE:
2351 return cmpResult == APFloat::cmpUnordered ||
2352 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2353 case arith::CmpFPredicate::UNE:
2354 return cmpResult != APFloat::cmpEqual;
2355 case arith::CmpFPredicate::UNO:
2356 return cmpResult == APFloat::cmpUnordered;
2357 case arith::CmpFPredicate::AlwaysTrue:
2358 return true;
2359 }
2360 llvm_unreachable("unknown cmpf predicate kind");
2361}
2362
2363OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2364 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2365 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2366
2367 // If one operand is NaN, making them both NaN does not change the result.
2368 if (lhs && lhs.getValue().isNaN())
2369 rhs = lhs;
2370 if (rhs && rhs.getValue().isNaN())
2371 lhs = rhs;
2372
2373 if (!lhs || !rhs)
2374 return {};
2375
2376 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2377 return BoolAttr::get(getContext(), val);
2378}
2379
2380class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2381public:
2382 using Base::Base;
2383
2384 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2385 bool isUnsigned) {
2386 using namespace arith;
2387 switch (pred) {
2388 case CmpFPredicate::UEQ:
2389 case CmpFPredicate::OEQ:
2390 return CmpIPredicate::eq;
2391 case CmpFPredicate::UGT:
2392 case CmpFPredicate::OGT:
2393 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2394 case CmpFPredicate::UGE:
2395 case CmpFPredicate::OGE:
2396 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2397 case CmpFPredicate::ULT:
2398 case CmpFPredicate::OLT:
2399 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2400 case CmpFPredicate::ULE:
2401 case CmpFPredicate::OLE:
2402 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2403 case CmpFPredicate::UNE:
2404 case CmpFPredicate::ONE:
2405 return CmpIPredicate::ne;
2406 default:
2407 llvm_unreachable("Unexpected predicate!");
2408 }
2409 }
2410
2411 LogicalResult matchAndRewrite(CmpFOp op,
2412 PatternRewriter &rewriter) const override {
2413 FloatAttr flt;
2414 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2415 return failure();
2416
2417 const APFloat &rhs = flt.getValue();
2418
2419 // Don't attempt to fold a nan.
2420 if (rhs.isNaN())
2421 return failure();
2422
2423 // Get the width of the mantissa. We don't want to hack on conversions that
2424 // might lose information from the integer, e.g. "i64 -> float"
2425 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2426 int mantissaWidth = floatTy.getFPMantissaWidth();
2427 if (mantissaWidth <= 0)
2428 return failure();
2429
2430 bool isUnsigned;
2431 Value intVal;
2432
2433 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2434 isUnsigned = false;
2435 intVal = si.getIn();
2436 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2437 isUnsigned = true;
2438 intVal = ui.getIn();
2439 } else {
2440 return failure();
2441 }
2442
2443 // Check to see that the input is converted from an integer type that is
2444 // small enough that preserves all bits.
2445 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2446 auto intWidth = intTy.getWidth();
2447
2448 // Number of bits representing values, as opposed to the sign
2449 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2450
2451 // Following test does NOT adjust intWidth downwards for signed inputs,
2452 // because the most negative value still requires all the mantissa bits
2453 // to distinguish it from one less than that value.
2454 if ((int)intWidth > mantissaWidth) {
2455 // Conversion would lose accuracy. Check if loss can impact comparison.
2456 int exponent = ilogb(rhs);
2457 if (exponent == APFloat::IEK_Inf) {
2458 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2459 if (maxExponent < (int)valueBits) {
2460 // Conversion could create infinity.
2461 return failure();
2462 }
2463 } else {
2464 // Note that if rhs is zero or NaN, then Exp is negative
2465 // and first condition is trivially false.
2466 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2467 // Conversion could affect comparison.
2468 return failure();
2469 }
2470 }
2471 }
2472
2473 // Convert to equivalent cmpi predicate
2474 CmpIPredicate pred;
2475 switch (op.getPredicate()) {
2476 case CmpFPredicate::ORD:
2477 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2478 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2479 /*width=*/1);
2480 return success();
2481 case CmpFPredicate::UNO:
2482 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2483 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2484 /*width=*/1);
2485 return success();
2486 default:
2487 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2488 break;
2489 }
2490
2491 if (!isUnsigned) {
2492 // If the rhs value is > SignedMax, fold the comparison. This handles
2493 // +INF and large values.
2494 APFloat signedMax(rhs.getSemantics());
2495 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2496 APFloat::rmNearestTiesToEven);
2497 if (signedMax < rhs) { // smax < 13123.0
2498 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2499 pred == CmpIPredicate::sle)
2500 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2501 /*width=*/1);
2502 else
2503 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2504 /*width=*/1);
2505 return success();
2506 }
2507 } else {
2508 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2509 // +INF and large values.
2510 APFloat unsignedMax(rhs.getSemantics());
2511 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2512 APFloat::rmNearestTiesToEven);
2513 if (unsignedMax < rhs) { // umax < 13123.0
2514 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2515 pred == CmpIPredicate::ule)
2516 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2517 /*width=*/1);
2518 else
2519 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2520 /*width=*/1);
2521 return success();
2522 }
2523 }
2524
2525 if (!isUnsigned) {
2526 // See if the rhs value is < SignedMin.
2527 APFloat signedMin(rhs.getSemantics());
2528 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2529 APFloat::rmNearestTiesToEven);
2530 if (signedMin > rhs) { // smin > 12312.0
2531 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2532 pred == CmpIPredicate::sge)
2533 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2534 /*width=*/1);
2535 else
2536 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2537 /*width=*/1);
2538 return success();
2539 }
2540 } else {
2541 // See if the rhs value is < UnsignedMin.
2542 APFloat unsignedMin(rhs.getSemantics());
2543 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2544 APFloat::rmNearestTiesToEven);
2545 if (unsignedMin > rhs) { // umin > 12312.0
2546 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2547 pred == CmpIPredicate::uge)
2548 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2549 /*width=*/1);
2550 else
2551 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2552 /*width=*/1);
2553 return success();
2554 }
2555 }
2556
2557 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2558 // [0, UMAX], but it may still be fractional. See if it is fractional by
2559 // casting the FP value to the integer value and back, checking for
2560 // equality. Don't do this for zero, because -0.0 is not fractional.
2561 bool ignored;
2562 APSInt rhsInt(intWidth, isUnsigned);
2563 if (APFloat::opInvalidOp ==
2564 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2565 // Undefined behavior invoked - the destination type can't represent
2566 // the input constant.
2567 return failure();
2568 }
2569
2570 if (!rhs.isZero()) {
2571 APFloat apf(floatTy.getFloatSemantics(),
2572 APInt::getZero(floatTy.getWidth()));
2573 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2574
2575 bool equal = apf == rhs;
2576 if (!equal) {
2577 // If we had a comparison against a fractional value, we have to adjust
2578 // the compare predicate and sometimes the value. rhsInt is rounded
2579 // towards zero at this point.
2580 switch (pred) {
2581 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2582 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2583 /*width=*/1);
2584 return success();
2585 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2586 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2587 /*width=*/1);
2588 return success();
2589 case CmpIPredicate::ule:
2590 // (float)int <= 4.4 --> int <= 4
2591 // (float)int <= -4.4 --> false
2592 if (rhs.isNegative()) {
2593 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2594 /*width=*/1);
2595 return success();
2596 }
2597 break;
2598 case CmpIPredicate::sle:
2599 // (float)int <= 4.4 --> int <= 4
2600 // (float)int <= -4.4 --> int < -4
2601 if (rhs.isNegative())
2602 pred = CmpIPredicate::slt;
2603 break;
2604 case CmpIPredicate::ult:
2605 // (float)int < -4.4 --> false
2606 // (float)int < 4.4 --> int <= 4
2607 if (rhs.isNegative()) {
2608 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2609 /*width=*/1);
2610 return success();
2611 }
2612 pred = CmpIPredicate::ule;
2613 break;
2614 case CmpIPredicate::slt:
2615 // (float)int < -4.4 --> int < -4
2616 // (float)int < 4.4 --> int <= 4
2617 if (!rhs.isNegative())
2618 pred = CmpIPredicate::sle;
2619 break;
2620 case CmpIPredicate::ugt:
2621 // (float)int > 4.4 --> int > 4
2622 // (float)int > -4.4 --> true
2623 if (rhs.isNegative()) {
2624 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2625 /*width=*/1);
2626 return success();
2627 }
2628 break;
2629 case CmpIPredicate::sgt:
2630 // (float)int > 4.4 --> int > 4
2631 // (float)int > -4.4 --> int >= -4
2632 if (rhs.isNegative())
2633 pred = CmpIPredicate::sge;
2634 break;
2635 case CmpIPredicate::uge:
2636 // (float)int >= -4.4 --> true
2637 // (float)int >= 4.4 --> int > 4
2638 if (rhs.isNegative()) {
2639 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2640 /*width=*/1);
2641 return success();
2642 }
2643 pred = CmpIPredicate::ugt;
2644 break;
2645 case CmpIPredicate::sge:
2646 // (float)int >= -4.4 --> int >= -4
2647 // (float)int >= 4.4 --> int > 4
2648 if (!rhs.isNegative())
2649 pred = CmpIPredicate::sgt;
2650 break;
2651 }
2652 }
2653 }
2654
2655 // Lower this FP comparison into an appropriate integer version of the
2656 // comparison.
2657 rewriter.replaceOpWithNewOp<CmpIOp>(
2658 op, pred, intVal,
2659 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2660 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2661 return success();
2662 }
2663};
2664
2665void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2666 MLIRContext *context) {
2667 patterns.insert<CmpFIntToFPConst>(context);
2668}
2669
2670//===----------------------------------------------------------------------===//
2671// SelectOp
2672//===----------------------------------------------------------------------===//
2673
2674// select %arg, %c1, %c0 => extui %arg
2675struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2676 using Base::Base;
2677
2678 LogicalResult matchAndRewrite(arith::SelectOp op,
2679 PatternRewriter &rewriter) const override {
2680 // Cannot extui i1 to i1, or i1 to f32
2681 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2682 return failure();
2683
2684 // select %x, c1, %c0 => extui %arg
2685 if (matchPattern(op.getTrueValue(), m_One()) &&
2686 matchPattern(op.getFalseValue(), m_Zero())) {
2687 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2688 op.getCondition());
2689 return success();
2690 }
2691
2692 // select %x, c0, %c1 => extui (xor %arg, true)
2693 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2694 matchPattern(op.getFalseValue(), m_One())) {
2695 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2696 op, op.getType(),
2697 arith::XOrIOp::create(
2698 rewriter, op.getLoc(), op.getCondition(),
2699 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2700 op.getCondition().getType(), 1)));
2701 return success();
2702 }
2703
2704 return failure();
2705 }
2706};
2707
2708void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2709 MLIRContext *context) {
2710 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2711 SelectI1ToNot, SelectToExtUI>(context);
2712}
2713
2714OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2715 Value trueVal = getTrueValue();
2716 Value falseVal = getFalseValue();
2717 if (trueVal == falseVal)
2718 return trueVal;
2719
2720 Value condition = getCondition();
2721
2722 // select true, %0, %1 => %0
2723 if (matchPattern(adaptor.getCondition(), m_One()))
2724 return trueVal;
2725
2726 // select false, %0, %1 => %1
2727 if (matchPattern(adaptor.getCondition(), m_Zero()))
2728 return falseVal;
2729
2730 // If either operand is fully poisoned, return the other.
2731 if (matchPattern(adaptor.getTrueValue(), ub::m_Poison()))
2732 return falseVal;
2733
2734 if (matchPattern(adaptor.getFalseValue(), ub::m_Poison()))
2735 return trueVal;
2736
2737 // select %x, true, false => %x
2738 if (getType().isSignlessInteger(1) &&
2739 matchPattern(adaptor.getTrueValue(), m_One()) &&
2740 matchPattern(adaptor.getFalseValue(), m_Zero()))
2741 return condition;
2742
2743 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2744 auto pred = cmp.getPredicate();
2745 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2746 auto cmpLhs = cmp.getLhs();
2747 auto cmpRhs = cmp.getRhs();
2748
2749 // %0 = arith.cmpi eq, %arg0, %arg1
2750 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2751
2752 // %0 = arith.cmpi ne, %arg0, %arg1
2753 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2754
2755 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2756 (cmpRhs == trueVal && cmpLhs == falseVal))
2757 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2758 }
2759 }
2760
2761 // Constant-fold constant operands over non-splat constant condition.
2762 // select %cst_vec, %cst0, %cst1 => %cst2
2763 if (auto cond =
2764 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2765 // DenseElementsAttr by construction always has a static shape.
2766 assert(cond.getType().hasStaticShape() &&
2767 "DenseElementsAttr must have static shape");
2768 if (auto lhs =
2769 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2770 if (auto rhs =
2771 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2772 SmallVector<Attribute> results;
2773 results.reserve(static_cast<size_t>(cond.getNumElements()));
2774 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2775 cond.value_end<BoolAttr>());
2776 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2777 lhs.value_end<Attribute>());
2778 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2779 rhs.value_end<Attribute>());
2780
2781 for (auto [condVal, lhsVal, rhsVal] :
2782 llvm::zip_equal(condVals, lhsVals, rhsVals))
2783 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2784
2785 return DenseElementsAttr::get(lhs.getType(), results);
2786 }
2787 }
2788 }
2789
2790 return nullptr;
2791}
2792
2793ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2794 Type conditionType, resultType;
2795 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2796 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2797 parser.parseOptionalAttrDict(result.attributes) ||
2798 parser.parseColonType(resultType))
2799 return failure();
2800
2801 // Check for the explicit condition type if this is a masked tensor or vector.
2802 if (succeeded(parser.parseOptionalComma())) {
2803 conditionType = resultType;
2804 if (parser.parseType(resultType))
2805 return failure();
2806 } else {
2807 conditionType = parser.getBuilder().getI1Type();
2808 }
2809
2810 result.addTypes(resultType);
2811 return parser.resolveOperands(operands,
2812 {conditionType, resultType, resultType},
2813 parser.getNameLoc(), result.operands);
2814}
2815
2816void arith::SelectOp::print(OpAsmPrinter &p) {
2817 p << " " << getOperands();
2818 p.printOptionalAttrDict((*this)->getAttrs());
2819 p << " : ";
2820 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2821 p << condType << ", ";
2822 p << getType();
2823}
2824
2825LogicalResult arith::SelectOp::verify() {
2826 Type conditionType = getCondition().getType();
2827 if (conditionType.isSignlessInteger(1))
2828 return success();
2829
2830 // If the result type is a vector or tensor, the type can be a mask with the
2831 // same elements.
2832 Type resultType = getType();
2833 if (!llvm::isa<TensorType, VectorType>(resultType))
2834 return emitOpError() << "expected condition to be a signless i1, but got "
2835 << conditionType;
2836 Type shapedConditionType = getI1SameShape(resultType);
2837 if (conditionType != shapedConditionType) {
2838 return emitOpError() << "expected condition type to have the same shape "
2839 "as the result type, expected "
2840 << shapedConditionType << ", but got "
2841 << conditionType;
2842 }
2843 return success();
2844}
2845//===----------------------------------------------------------------------===//
2846// ShLIOp
2847//===----------------------------------------------------------------------===//
2848
2849OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2850 // shli(x, 0) -> x
2851 if (matchPattern(adaptor.getRhs(), m_Zero()))
2852 return getLhs();
2853 // Don't fold if shifting more or equal than the bit width.
2854 bool bounded = false;
2856 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2857 bounded = b.ult(b.getBitWidth());
2858 return a.shl(b);
2859 });
2860 return bounded ? result : Attribute();
2861}
2862
2863//===----------------------------------------------------------------------===//
2864// ShRUIOp
2865//===----------------------------------------------------------------------===//
2866
2867OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2868 // shrui(x, 0) -> x
2869 if (matchPattern(adaptor.getRhs(), m_Zero()))
2870 return getLhs();
2871 // Don't fold if shifting more or equal than the bit width.
2872 bool bounded = false;
2874 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2875 bounded = b.ult(b.getBitWidth());
2876 return a.lshr(b);
2877 });
2878 return bounded ? result : Attribute();
2879}
2880
2881//===----------------------------------------------------------------------===//
2882// ShRSIOp
2883//===----------------------------------------------------------------------===//
2884
2885OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2886 // shrsi(x, 0) -> x
2887 if (matchPattern(adaptor.getRhs(), m_Zero()))
2888 return getLhs();
2889 // Don't fold if shifting more or equal than the bit width.
2890 bool bounded = false;
2892 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2893 bounded = b.ult(b.getBitWidth());
2894 return a.ashr(b);
2895 });
2896 return bounded ? result : Attribute();
2897}
2898
2899//===----------------------------------------------------------------------===//
2900// Atomic Enum
2901//===----------------------------------------------------------------------===//
2902
2903/// Returns the identity value attribute associated with an AtomicRMWKind op.
2904TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2905 OpBuilder &builder, Location loc,
2906 bool useOnlyFiniteValue) {
2907 switch (kind) {
2908 case AtomicRMWKind::maximumf: {
2909 const llvm::fltSemantics &semantic =
2910 llvm::cast<FloatType>(resultType).getFloatSemantics();
2911 APFloat identity = useOnlyFiniteValue
2912 ? APFloat::getLargest(semantic, /*Negative=*/true)
2913 : APFloat::getInf(semantic, /*Negative=*/true);
2914 return builder.getFloatAttr(resultType, identity);
2915 }
2916 case AtomicRMWKind::maxnumf: {
2917 const llvm::fltSemantics &semantic =
2918 llvm::cast<FloatType>(resultType).getFloatSemantics();
2919 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2920 return builder.getFloatAttr(resultType, identity);
2921 }
2922 case AtomicRMWKind::addf:
2923 case AtomicRMWKind::addi:
2924 case AtomicRMWKind::maxu:
2925 case AtomicRMWKind::ori:
2926 case AtomicRMWKind::xori:
2927 return builder.getZeroAttr(resultType);
2928 case AtomicRMWKind::andi:
2929 return builder.getIntegerAttr(
2930 resultType,
2931 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2932 case AtomicRMWKind::maxs:
2933 return builder.getIntegerAttr(
2934 resultType, APInt::getSignedMinValue(
2935 llvm::cast<IntegerType>(resultType).getWidth()));
2936 case AtomicRMWKind::minimumf: {
2937 const llvm::fltSemantics &semantic =
2938 llvm::cast<FloatType>(resultType).getFloatSemantics();
2939 APFloat identity = useOnlyFiniteValue
2940 ? APFloat::getLargest(semantic, /*Negative=*/false)
2941 : APFloat::getInf(semantic, /*Negative=*/false);
2942
2943 return builder.getFloatAttr(resultType, identity);
2944 }
2945 case AtomicRMWKind::minnumf: {
2946 const llvm::fltSemantics &semantic =
2947 llvm::cast<FloatType>(resultType).getFloatSemantics();
2948 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2949 return builder.getFloatAttr(resultType, identity);
2950 }
2951 case AtomicRMWKind::mins:
2952 return builder.getIntegerAttr(
2953 resultType, APInt::getSignedMaxValue(
2954 llvm::cast<IntegerType>(resultType).getWidth()));
2955 case AtomicRMWKind::minu:
2956 return builder.getIntegerAttr(
2957 resultType,
2958 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2959 case AtomicRMWKind::muli:
2960 return builder.getIntegerAttr(resultType, 1);
2961 case AtomicRMWKind::mulf:
2962 return builder.getFloatAttr(resultType, 1);
2963 // TODO: Add remaining reduction operations.
2964 default:
2965 (void)emitOptionalError(loc, "Reduction operation type not supported");
2966 break;
2967 }
2968 return nullptr;
2969}
2970
2971/// Returns the identity numeric value of the given op.
2972std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2973 std::optional<AtomicRMWKind> maybeKind =
2975 // Floating-point operations.
2976 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2977 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2978 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2979 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2980 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2981 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2982 // Integer operations.
2983 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2984 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2985 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2986 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2987 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2988 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2989 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2990 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2991 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2992 .Default(std::nullopt);
2993 if (!maybeKind) {
2994 return std::nullopt;
2995 }
2996
2997 bool useOnlyFiniteValue = false;
2998 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2999 if (fmfOpInterface) {
3000 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
3001 useOnlyFiniteValue =
3002 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
3003 }
3004
3005 // Builder only used as helper for attribute creation.
3006 OpBuilder b(op->getContext());
3007 Type resultType = op->getResult(0).getType();
3008
3009 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
3010 useOnlyFiniteValue);
3011}
3012
3013/// Returns the identity value associated with an AtomicRMWKind op.
3014Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
3015 OpBuilder &builder, Location loc,
3016 bool useOnlyFiniteValue) {
3017 if (auto attr = getIdentityValueAttr(op, resultType, builder, loc,
3018 useOnlyFiniteValue))
3019 return arith::ConstantOp::create(builder, loc, attr);
3020 return {};
3021}
3022
3023/// Return the value obtained by applying the reduction operation kind
3024/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
3026 Location loc, Value lhs, Value rhs) {
3027 switch (op) {
3028 case AtomicRMWKind::addf:
3029 return arith::AddFOp::create(builder, loc, lhs, rhs);
3030 case AtomicRMWKind::addi:
3031 return arith::AddIOp::create(builder, loc, lhs, rhs);
3032 case AtomicRMWKind::mulf:
3033 return arith::MulFOp::create(builder, loc, lhs, rhs);
3034 case AtomicRMWKind::muli:
3035 return arith::MulIOp::create(builder, loc, lhs, rhs);
3036 case AtomicRMWKind::maximumf:
3037 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
3038 case AtomicRMWKind::minimumf:
3039 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
3040 case AtomicRMWKind::maxnumf:
3041 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
3042 case AtomicRMWKind::minnumf:
3043 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
3044 case AtomicRMWKind::maxs:
3045 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
3046 case AtomicRMWKind::mins:
3047 return arith::MinSIOp::create(builder, loc, lhs, rhs);
3048 case AtomicRMWKind::maxu:
3049 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
3050 case AtomicRMWKind::minu:
3051 return arith::MinUIOp::create(builder, loc, lhs, rhs);
3052 case AtomicRMWKind::ori:
3053 return arith::OrIOp::create(builder, loc, lhs, rhs);
3054 case AtomicRMWKind::andi:
3055 return arith::AndIOp::create(builder, loc, lhs, rhs);
3056 case AtomicRMWKind::xori:
3057 return arith::XOrIOp::create(builder, loc, lhs, rhs);
3058 // TODO: Add remaining reduction operations.
3059 default:
3060 (void)emitOptionalError(loc, "Reduction operation type not supported");
3061 break;
3062 }
3063 return nullptr;
3064}
3065
3066//===----------------------------------------------------------------------===//
3067// TableGen'd op method definitions
3068//===----------------------------------------------------------------------===//
3069
3070#define GET_OP_CLASSES
3071#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3072
3073//===----------------------------------------------------------------------===//
3074// TableGen'd enum attribute definitions
3075//===----------------------------------------------------------------------===//
3076
3077#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:802
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:65
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:72
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
Definition ArithOps.cpp:38
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:763
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:112
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:862
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:844
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:55
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:155
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:60
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:135
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:179
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
static APInt calculateUnsignedBorrow(const APInt &lhs, const APInt &rhs)
Definition ArithOps.cpp:507
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:46
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:443
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:147
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:93
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:340
static bool classof(Operation *op)
Definition ArithOps.cpp:357
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:334
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:363
static bool classof(Operation *op)
Definition ArithOps.cpp:384
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:55
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:268
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:261
static bool classof(Operation *op)
Definition ArithOps.cpp:328
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:79
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:390
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.