MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37/// Default rounding mode according to default LLVM floating-point environment.
38static constexpr llvm::RoundingMode kDefaultRoundingMode =
39 llvm::RoundingMode::NearestTiesToEven;
40
41//===----------------------------------------------------------------------===//
42// Pattern helpers
43//===----------------------------------------------------------------------===//
44
45static IntegerAttr
48 function_ref<APInt(const APInt &, const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.getType(), value);
53}
54
55static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
57 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
58}
59
60static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
62 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
63}
64
65static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
67 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
68}
69
70static IntegerAttr andIntegerAttrs(PatternRewriter &builder, Value res,
72 return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_and<APInt>());
73}
74
75static IntegerAttr orIntegerAttrs(PatternRewriter &builder, Value res,
77 return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_or<APInt>());
78}
79
80static IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res,
82 return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_xor<APInt>());
83}
84
85// Merge overflow flags from 2 ops, selecting the most conservative combination.
86static IntegerOverflowFlagsAttr
87mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
88 IntegerOverflowFlagsAttr val2) {
89 return IntegerOverflowFlagsAttr::get(val1.getContext(),
90 val1.getValue() & val2.getValue());
91}
92
93/// Invert an integer comparison predicate.
94arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
95 switch (pred) {
96 case arith::CmpIPredicate::eq:
97 return arith::CmpIPredicate::ne;
98 case arith::CmpIPredicate::ne:
99 return arith::CmpIPredicate::eq;
100 case arith::CmpIPredicate::slt:
101 return arith::CmpIPredicate::sge;
102 case arith::CmpIPredicate::sle:
103 return arith::CmpIPredicate::sgt;
104 case arith::CmpIPredicate::sgt:
105 return arith::CmpIPredicate::sle;
106 case arith::CmpIPredicate::sge:
107 return arith::CmpIPredicate::slt;
108 case arith::CmpIPredicate::ult:
109 return arith::CmpIPredicate::uge;
110 case arith::CmpIPredicate::ule:
111 return arith::CmpIPredicate::ugt;
112 case arith::CmpIPredicate::ugt:
113 return arith::CmpIPredicate::ule;
114 case arith::CmpIPredicate::uge:
115 return arith::CmpIPredicate::ult;
116 }
117 llvm_unreachable("unknown cmpi predicate kind");
118}
119
120/// Equivalent to
121/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
122///
123/// Not possible to implement as chain of calls as this would introduce a
124/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
125/// on the LLVM dialect and on translation to LLVM.
126static llvm::RoundingMode
127convertArithRoundingModeToLLVMIR(std::optional<RoundingMode> roundingMode) {
128 if (!roundingMode)
130 switch (*roundingMode) {
131 case RoundingMode::downward:
132 return llvm::RoundingMode::TowardNegative;
133 case RoundingMode::to_nearest_away:
134 return llvm::RoundingMode::NearestTiesToAway;
135 case RoundingMode::to_nearest_even:
136 return llvm::RoundingMode::NearestTiesToEven;
137 case RoundingMode::toward_zero:
138 return llvm::RoundingMode::TowardZero;
139 case RoundingMode::upward:
140 return llvm::RoundingMode::TowardPositive;
141 }
142 llvm_unreachable("Unhandled rounding mode");
143}
144
145static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
146 return arith::CmpIPredicateAttr::get(pred.getContext(),
147 invertPredicate(pred.getValue()));
148}
149
151 Type elemTy = getElementTypeOrSelf(type);
152 if (elemTy.isIntOrFloat())
153 return elemTy.getIntOrFloatBitWidth();
154
155 return -1;
156}
157
159 return getScalarOrElementWidth(value.getType());
160}
161
162static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
163 APInt value;
164 if (matchPattern(attr, m_ConstantInt(&value)))
165 return value;
166
167 return failure();
168}
169
170static Attribute getBoolAttribute(Type type, bool value) {
171 auto boolAttr = BoolAttr::get(type.getContext(), value);
172 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
173 if (!shapedType)
174 return boolAttr;
175 // DenseElementsAttr requires a static shape.
176 if (!shapedType.hasStaticShape())
177 return {};
178 return DenseElementsAttr::get(shapedType, boolAttr);
179}
180
181//===----------------------------------------------------------------------===//
182// TableGen'd canonicalization patterns
183//===----------------------------------------------------------------------===//
184
185namespace {
186#include "ArithCanonicalization.inc"
187} // namespace
188
189//===----------------------------------------------------------------------===//
190// Common helpers
191//===----------------------------------------------------------------------===//
192
193/// Return the type of the same shape (scalar, vector or tensor) containing i1.
195 auto i1Type = IntegerType::get(type.getContext(), 1);
196 if (auto shapedType = dyn_cast<ShapedType>(type))
197 return shapedType.cloneWith(std::nullopt, i1Type);
198 if (llvm::isa<UnrankedTensorType>(type))
199 return UnrankedTensorType::get(i1Type);
200 return i1Type;
201}
202
203//===----------------------------------------------------------------------===//
204// ConstantOp
205//===----------------------------------------------------------------------===//
206
207void arith::ConstantOp::getAsmResultNames(
208 function_ref<void(Value, StringRef)> setNameFn) {
209 auto type = getType();
210 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
211 auto intType = dyn_cast<IntegerType>(type);
212
213 // Sugar i1 constants with 'true' and 'false'.
214 if (intType && intType.getWidth() == 1)
215 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
216
217 // Otherwise, build a complex name with the value and type.
218 SmallString<32> specialNameBuffer;
219 llvm::raw_svector_ostream specialName(specialNameBuffer);
220 specialName << 'c' << intCst.getValue();
221 if (intType)
222 specialName << '_' << type;
223 setNameFn(getResult(), specialName.str());
224 } else {
225 setNameFn(getResult(), "cst");
226 }
227}
228
229/// TODO: disallow arith.constant to return anything other than signless integer
230/// or float like.
231LogicalResult arith::ConstantOp::verify() {
232 auto type = getType();
233 // Integer values must be signless.
234 if (llvm::isa<IntegerType>(type) &&
235 !llvm::cast<IntegerType>(type).isSignless())
236 return emitOpError("integer return type must be signless");
237 // Any float or elements attribute are acceptable.
238 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
239 return emitOpError(
240 "value must be an integer, float, or elements attribute");
241 }
242
243 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
244 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
245 // However, this would most likely require updating the lowerings to LLVM.
246 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
247 return emitOpError(
248 "initializing scalable vectors with elements attribute is not supported"
249 " unless it's a vector splat");
250 return success();
251}
252
253bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
254 // The value's type must be the same as the provided type.
255 auto typedAttr = dyn_cast<TypedAttr>(value);
256 if (!typedAttr || typedAttr.getType() != type)
257 return false;
258 // Integer values must be signless.
259 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
260 if (!intType.isSignless())
261 return false;
262 }
263 // Integer, float, and element attributes are buildable.
264 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
265}
266
267ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
268 Type type, Location loc) {
269 if (isBuildableWith(value, type))
270 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
271 return nullptr;
272}
273
274OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
275
277 int64_t value, unsigned width) {
278 auto type = builder.getIntegerType(width);
279 arith::ConstantOp::build(builder, result, type,
280 builder.getIntegerAttr(type, value));
281}
282
284 Location location,
286 unsigned width) {
287 mlir::OperationState state(location, getOperationName());
288 build(builder, state, value, width);
289 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
290 assert(result && "builder didn't return the right type");
291 return result;
292}
293
296 unsigned width) {
297 return create(builder, builder.getLoc(), value, width);
298}
299
301 Type type, int64_t value) {
302 arith::ConstantOp::build(builder, result, type,
303 builder.getIntegerAttr(type, value));
304}
305
307 Location location, Type type,
308 int64_t value) {
309 mlir::OperationState state(location, getOperationName());
310 build(builder, state, type, value);
311 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
312 assert(result && "builder didn't return the right type");
313 return result;
314}
315
317 Type type, int64_t value) {
318 return create(builder, builder.getLoc(), type, value);
319}
320
322 Type type, const APInt &value) {
323 arith::ConstantOp::build(builder, result, type,
324 builder.getIntegerAttr(type, value));
325}
326
328 Location location, Type type,
329 const APInt &value) {
330 mlir::OperationState state(location, getOperationName());
331 build(builder, state, type, value);
332 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
333 assert(result && "builder didn't return the right type");
334 return result;
335}
336
338 Type type,
339 const APInt &value) {
340 return create(builder, builder.getLoc(), type, value);
341}
342
344 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
345 return constOp.getType().isSignlessInteger();
346 return false;
347}
348
350 FloatType type, const APFloat &value) {
351 arith::ConstantOp::build(builder, result, type,
352 builder.getFloatAttr(type, value));
353}
354
356 Location location,
357 FloatType type,
358 const APFloat &value) {
359 mlir::OperationState state(location, getOperationName());
360 build(builder, state, type, value);
361 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
362 assert(result && "builder didn't return the right type");
363 return result;
364}
365
368 const APFloat &value) {
369 return create(builder, builder.getLoc(), type, value);
370}
371
373 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
374 return llvm::isa<FloatType>(constOp.getType());
375 return false;
376}
377
379 int64_t value) {
380 arith::ConstantOp::build(builder, result, builder.getIndexType(),
381 builder.getIndexAttr(value));
382}
383
385 Location location,
386 int64_t value) {
387 mlir::OperationState state(location, getOperationName());
388 build(builder, state, value);
389 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
390 assert(result && "builder didn't return the right type");
391 return result;
392}
393
396 return create(builder, builder.getLoc(), value);
397}
398
400 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
401 return constOp.getType().isIndex();
402 return false;
403}
404
406 Type type) {
407 // TODO: Incorporate this check to `FloatAttr::get*`.
408 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
409 "type doesn't have a zero representation");
410 TypedAttr zeroAttr = builder.getZeroAttr(type);
411 assert(zeroAttr && "unsupported type for zero attribute");
412 return arith::ConstantOp::create(builder, loc, zeroAttr);
413}
414
415//===----------------------------------------------------------------------===//
416// AddIOp
417//===----------------------------------------------------------------------===//
418
419OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
420 // addi(x, 0) -> x
421 if (matchPattern(adaptor.getRhs(), m_Zero()))
422 return getLhs();
423
424 // addi(subi(a, b), b) -> a
425 if (auto sub = getLhs().getDefiningOp<SubIOp>())
426 if (getRhs() == sub.getRhs())
427 return sub.getLhs();
428
429 // addi(b, subi(a, b)) -> a
430 if (auto sub = getRhs().getDefiningOp<SubIOp>())
431 if (getLhs() == sub.getRhs())
432 return sub.getLhs();
433
435 adaptor.getOperands(),
436 [](APInt a, const APInt &b) { return std::move(a) + b; });
437}
438
439void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
440 MLIRContext *context) {
441 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
442 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
443}
444
445//===----------------------------------------------------------------------===//
446// AddUIExtendedOp
447//===----------------------------------------------------------------------===//
448
449std::optional<SmallVector<int64_t, 4>>
450arith::AddUIExtendedOp::getShapeForUnroll() {
451 if (auto vt = dyn_cast<VectorType>(getType(0)))
452 return llvm::to_vector<4>(vt.getShape());
453 return std::nullopt;
454}
455
456// Returns the overflow bit, assuming that `sum` is the result of unsigned
457// addition of `operand` and another number.
458static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
459 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
460}
461
462LogicalResult
463arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
464 SmallVectorImpl<OpFoldResult> &results) {
465 Type overflowTy = getOverflow().getType();
466 // addui_extended(x, 0) -> x, false
467 if (matchPattern(getRhs(), m_Zero())) {
468 Builder builder(getContext());
469 auto falseValue = builder.getZeroAttr(overflowTy);
470
471 results.push_back(getLhs());
472 results.push_back(falseValue);
473 return success();
474 }
475
476 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
477 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
478 // operands. If that succeeds, calculate the overflow bit based on the sum
479 // and the first (constant) operand, `lhs`.
480 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
481 adaptor.getOperands(),
482 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
483 // If any operand is poison, propagate poison to both results.
484 if (matchPattern(sumAttr, ub::m_Poison())) {
485 results.push_back(sumAttr);
486 results.push_back(sumAttr);
487 return success();
488 }
489 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
490 ArrayRef({sumAttr, adaptor.getLhs()}),
491 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
493 if (!overflowAttr)
494 return failure();
495
496 results.push_back(sumAttr);
497 results.push_back(overflowAttr);
498 return success();
499 }
500
501 return failure();
502}
503
504void arith::AddUIExtendedOp::getCanonicalizationPatterns(
505 RewritePatternSet &patterns, MLIRContext *context) {
506 patterns.add<AddUIExtendedToAddI>(context);
507}
508
509//===----------------------------------------------------------------------===//
510// SubUIExtendedOp
511//===----------------------------------------------------------------------===//
512
513std::optional<SmallVector<int64_t, 4>>
514arith::SubUIExtendedOp::getShapeForUnroll() {
515 if (auto vt = dyn_cast<VectorType>(getType(0)))
516 return llvm::to_vector<4>(vt.getShape());
517 return std::nullopt;
518}
519
520// Returns the borrow bit, assuming `lhs` and `rhs` are operands of an unsigned
521// subtraction whose mathematical result underflows iff `lhs < rhs`.
522static APInt calculateUnsignedBorrow(const APInt &lhs, const APInt &rhs) {
523 return lhs.ult(rhs) ? APInt::getAllOnes(1) : APInt::getZero(1);
524}
525
526LogicalResult
527arith::SubUIExtendedOp::fold(FoldAdaptor adaptor,
528 SmallVectorImpl<OpFoldResult> &results) {
529 Type borrowTy = getBorrow().getType();
530 // subui_extended(x, 0) -> x, false
531 if (matchPattern(getRhs(), m_Zero())) {
532 Builder builder(getContext());
533 auto falseValue = builder.getZeroAttr(borrowTy);
534
535 results.push_back(getLhs());
536 results.push_back(falseValue);
537 return success();
538 }
539
540 // subui_extended(x, x) -> 0, false
541 if (getLhs() == getRhs()) {
542 Builder builder(getContext());
543 auto zeroDiff = builder.getZeroAttr(getDiff().getType());
544 auto falseValue = builder.getZeroAttr(borrowTy);
545 if (!zeroDiff)
546 return failure();
547
548 results.push_back(zeroDiff);
549 results.push_back(falseValue);
550 return success();
551 }
552
553 // subui_extended(constant_a, constant_b) -> constant_diff, constant_borrow
554 if (Attribute diffAttr = constFoldBinaryOp<IntegerAttr>(
555 adaptor.getOperands(),
556 [](APInt a, const APInt &b) { return std::move(a) - b; })) {
557 // If any operand is poison, propagate poison to both results.
558 if (matchPattern(diffAttr, ub::m_Poison())) {
559 results.push_back(diffAttr);
560 results.push_back(diffAttr);
561 return success();
562 }
563 Attribute borrowAttr = constFoldBinaryOp<IntegerAttr>(
564 adaptor.getOperands(),
565 getI1SameShape(llvm::cast<TypedAttr>(diffAttr).getType()),
567 if (!borrowAttr)
568 return failure();
569
570 results.push_back(diffAttr);
571 results.push_back(borrowAttr);
572 return success();
573 }
574
575 return failure();
576}
577
578void arith::SubUIExtendedOp::getCanonicalizationPatterns(
579 RewritePatternSet &patterns, MLIRContext *context) {
580 patterns.add<SubUIExtendedToSubI>(context);
581}
582
583//===----------------------------------------------------------------------===//
584// SubIOp
585//===----------------------------------------------------------------------===//
586
587OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
588 // subi(x,x) -> 0
589 if (getOperand(0) == getOperand(1)) {
590 auto shapedType = dyn_cast<ShapedType>(getType());
591 // We can't generate a constant with a dynamic shaped tensor.
592 if (!shapedType || shapedType.hasStaticShape())
593 return Builder(getContext()).getZeroAttr(getType());
594 }
595 // subi(x,0) -> x
596 if (matchPattern(adaptor.getRhs(), m_Zero()))
597 return getLhs();
598
599 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
600 // subi(addi(a, b), b) -> a
601 if (getRhs() == add.getRhs())
602 return add.getLhs();
603 // subi(addi(a, b), a) -> b
604 if (getRhs() == add.getLhs())
605 return add.getRhs();
606 }
607
608 // subi(a, subi(a, b)) -> b
609 if (auto sub = getRhs().getDefiningOp<SubIOp>())
610 if (getLhs() == sub.getLhs())
611 return sub.getRhs();
612
614 adaptor.getOperands(),
615 [](APInt a, const APInt &b) { return std::move(a) - b; });
616}
617
618void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
619 MLIRContext *context) {
620 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
621 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
622 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
623}
624
625//===----------------------------------------------------------------------===//
626// MulIOp
627//===----------------------------------------------------------------------===//
628
629OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
630 // muli(x, 0) -> 0
631 if (matchPattern(adaptor.getRhs(), m_Zero()))
632 return getRhs();
633 // muli(x, 1) -> x
634 if (matchPattern(adaptor.getRhs(), m_One()))
635 return getLhs();
636 // TODO: Handle the overflow case.
637
638 // default folder
640 adaptor.getOperands(),
641 [](const APInt &a, const APInt &b) { return a * b; });
642}
643
644void arith::MulIOp::getAsmResultNames(
645 function_ref<void(Value, StringRef)> setNameFn) {
646 if (!isa<IndexType>(getType()))
647 return;
648
649 // Match vector.vscale by name to avoid depending on the vector dialect (which
650 // is a circular dependency).
651 auto isVscale = [](Operation *op) {
652 return op && op->getName().getStringRef() == "vector.vscale";
653 };
654
655 IntegerAttr baseValue;
656 auto isVscaleExpr = [&](Value a, Value b) {
657 return matchPattern(a, m_Constant(&baseValue)) &&
658 isVscale(b.getDefiningOp());
659 };
660
661 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
662 return;
663
664 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
665 SmallString<32> specialNameBuffer;
666 llvm::raw_svector_ostream specialName(specialNameBuffer);
667 specialName << 'c' << baseValue.getInt() << "_vscale";
668 setNameFn(getResult(), specialName.str());
669}
670
671void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
672 MLIRContext *context) {
673 patterns.add<MulIMulIConstant>(context);
674}
675
676//===----------------------------------------------------------------------===//
677// MulSIExtendedOp
678//===----------------------------------------------------------------------===//
679
680std::optional<SmallVector<int64_t, 4>>
681arith::MulSIExtendedOp::getShapeForUnroll() {
682 if (auto vt = dyn_cast<VectorType>(getType(0)))
683 return llvm::to_vector<4>(vt.getShape());
684 return std::nullopt;
685}
686
687LogicalResult
688arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
689 SmallVectorImpl<OpFoldResult> &results) {
690 // mulsi_extended(x, 0) -> 0, 0
691 if (matchPattern(adaptor.getRhs(), m_Zero())) {
692 Attribute zero = adaptor.getRhs();
693 results.push_back(zero);
694 results.push_back(zero);
695 return success();
696 }
697
698 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
699 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
700 adaptor.getOperands(),
701 [](const APInt &a, const APInt &b) { return a * b; })) {
702 // Invoke the constant fold helper again to calculate the 'high' result.
703 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
704 llvm::APIntOps::mulhs);
705 assert(highAttr && "Unexpected constant-folding failure");
706
707 results.push_back(lowAttr);
708 results.push_back(highAttr);
709 return success();
710 }
711
712 return failure();
713}
714
715void arith::MulSIExtendedOp::getCanonicalizationPatterns(
716 RewritePatternSet &patterns, MLIRContext *context) {
717 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
718}
719
720//===----------------------------------------------------------------------===//
721// MulUIExtendedOp
722//===----------------------------------------------------------------------===//
723
724std::optional<SmallVector<int64_t, 4>>
725arith::MulUIExtendedOp::getShapeForUnroll() {
726 if (auto vt = dyn_cast<VectorType>(getType(0)))
727 return llvm::to_vector<4>(vt.getShape());
728 return std::nullopt;
729}
730
731LogicalResult
732arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
733 SmallVectorImpl<OpFoldResult> &results) {
734 // mului_extended(x, 0) -> 0, 0
735 if (matchPattern(adaptor.getRhs(), m_Zero())) {
736 Attribute zero = adaptor.getRhs();
737 results.push_back(zero);
738 results.push_back(zero);
739 return success();
740 }
741
742 // mului_extended(x, 1) -> x, 0
743 if (matchPattern(adaptor.getRhs(), m_One())) {
744 Builder builder(getContext());
745 Attribute zero = builder.getZeroAttr(getLhs().getType());
746 results.push_back(getLhs());
747 results.push_back(zero);
748 return success();
749 }
750
751 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
752 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
753 adaptor.getOperands(),
754 [](const APInt &a, const APInt &b) { return a * b; })) {
755 // Invoke the constant fold helper again to calculate the 'high' result.
756 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
757 llvm::APIntOps::mulhu);
758 assert(highAttr && "Unexpected constant-folding failure");
759
760 results.push_back(lowAttr);
761 results.push_back(highAttr);
762 return success();
763 }
764
765 return failure();
766}
767
768void arith::MulUIExtendedOp::getCanonicalizationPatterns(
769 RewritePatternSet &patterns, MLIRContext *context) {
770 patterns.add<MulUIExtendedToMulI>(context);
771}
772
773//===----------------------------------------------------------------------===//
774// DivUIOp
775//===----------------------------------------------------------------------===//
776
777/// Fold `(a * b) / b -> a`
779 arith::IntegerOverflowFlags ovfFlags) {
780 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
781 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
782 return {};
783
784 if (mul.getLhs() == rhs)
785 return mul.getRhs();
786
787 if (mul.getRhs() == rhs)
788 return mul.getLhs();
789
790 return {};
791}
792
793OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
794 // divui (x, 1) -> x.
795 if (matchPattern(adaptor.getRhs(), m_One()))
796 return getLhs();
797
798 // (a * b) / b -> a
799 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
800 return val;
801
802 // Don't fold if it would require a division by zero.
803 bool div0 = false;
804 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
805 [&](APInt a, const APInt &b) {
806 if (div0 || !b) {
807 div0 = true;
808 return a;
809 }
810 return a.udiv(b);
811 });
812
813 return div0 ? Attribute() : result;
814}
815
816/// Returns whether an unsigned division by `divisor` is speculatable.
818 // X / 0 => UB
819 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
821
823}
824
825Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
826 return getDivUISpeculatability(getRhs());
827}
828
829//===----------------------------------------------------------------------===//
830// DivSIOp
831//===----------------------------------------------------------------------===//
832
833OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
834 // divsi (x, 1) -> x.
835 if (matchPattern(adaptor.getRhs(), m_One()))
836 return getLhs();
837
838 // (a * b) / b -> a
839 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
840 return val;
841
842 // Don't fold if it would overflow or if it requires a division by zero.
843 bool overflowOrDiv0 = false;
845 adaptor.getOperands(), [&](APInt a, const APInt &b) {
846 if (overflowOrDiv0 || !b) {
847 overflowOrDiv0 = true;
848 return a;
849 }
850 return a.sdiv_ov(b, overflowOrDiv0);
851 });
852
853 return overflowOrDiv0 ? Attribute() : result;
854}
855
856/// Returns whether a signed division by `divisor` is speculatable. This
857/// function conservatively assumes that all signed division by -1 are not
858/// speculatable.
860 // X / 0 => UB
861 // INT_MIN / -1 => UB
862 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
865
867}
868
869Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
870 return getDivSISpeculatability(getRhs());
871}
872
873//===----------------------------------------------------------------------===//
874// Ceil and floor division folding helpers
875//===----------------------------------------------------------------------===//
876
877static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
878 bool &overflow) {
879 // Returns (a-1)/b + 1
880 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
881 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
882 return val.sadd_ov(one, overflow);
883}
884
885//===----------------------------------------------------------------------===//
886// CeilDivUIOp
887//===----------------------------------------------------------------------===//
888
889OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
890 // ceildivui (x, 1) -> x.
891 if (matchPattern(adaptor.getRhs(), m_One()))
892 return getLhs();
893
894 bool overflowOrDiv0 = false;
896 adaptor.getOperands(), [&](APInt a, const APInt &b) {
897 if (overflowOrDiv0 || !b) {
898 overflowOrDiv0 = true;
899 return a;
900 }
901 APInt quotient = a.udiv(b);
902 if (!a.urem(b))
903 return quotient;
904 APInt one(a.getBitWidth(), 1, true);
905 return quotient.uadd_ov(one, overflowOrDiv0);
906 });
907
908 return overflowOrDiv0 ? Attribute() : result;
909}
910
911Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
912 return getDivUISpeculatability(getRhs());
913}
914
915//===----------------------------------------------------------------------===//
916// CeilDivSIOp
917//===----------------------------------------------------------------------===//
918
919OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
920 // ceildivsi (x, 1) -> x.
921 if (matchPattern(adaptor.getRhs(), m_One()))
922 return getLhs();
923
924 // Don't fold if it would overflow or if it requires a division by zero.
925 // TODO: This hook won't fold operations where a = MININT, because
926 // negating MININT overflows. This can be improved.
927 bool overflowOrDiv0 = false;
929 adaptor.getOperands(), [&](APInt a, const APInt &b) {
930 if (overflowOrDiv0 || !b) {
931 overflowOrDiv0 = true;
932 return a;
933 }
934 if (!a)
935 return a;
936 // After this point we know that neither a or b are zero.
937 unsigned bits = a.getBitWidth();
938 APInt zero = APInt::getZero(bits);
939 bool aGtZero = a.sgt(zero);
940 bool bGtZero = b.sgt(zero);
941 if (aGtZero && bGtZero) {
942 // Both positive, return ceil(a, b).
943 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
944 }
945
946 // No folding happens if any of the intermediate arithmetic operations
947 // overflows.
948 bool overflowNegA = false;
949 bool overflowNegB = false;
950 bool overflowDiv = false;
951 bool overflowNegRes = false;
952 if (!aGtZero && !bGtZero) {
953 // Both negative, return ceil(-a, -b).
954 APInt posA = zero.ssub_ov(a, overflowNegA);
955 APInt posB = zero.ssub_ov(b, overflowNegB);
956 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
957 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
958 return res;
959 }
960 if (!aGtZero && bGtZero) {
961 // A is negative, b is positive, return - ( -a / b).
962 APInt posA = zero.ssub_ov(a, overflowNegA);
963 APInt div = posA.sdiv_ov(b, overflowDiv);
964 APInt res = zero.ssub_ov(div, overflowNegRes);
965 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
966 return res;
967 }
968 // A is positive, b is negative, return - (a / -b).
969 APInt posB = zero.ssub_ov(b, overflowNegB);
970 APInt div = a.sdiv_ov(posB, overflowDiv);
971 APInt res = zero.ssub_ov(div, overflowNegRes);
972
973 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
974 return res;
975 });
976
977 return overflowOrDiv0 ? Attribute() : result;
978}
979
980Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
981 return getDivSISpeculatability(getRhs());
982}
983
984//===----------------------------------------------------------------------===//
985// FloorDivSIOp
986//===----------------------------------------------------------------------===//
987
988OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
989 // floordivsi (x, 1) -> x.
990 if (matchPattern(adaptor.getRhs(), m_One()))
991 return getLhs();
992
993 // Don't fold if it would overflow or if it requires a division by zero.
994 bool overflowOrDiv = false;
996 adaptor.getOperands(), [&](APInt a, const APInt &b) {
997 if (b.isZero()) {
998 overflowOrDiv = true;
999 return a;
1000 }
1001 return a.sfloordiv_ov(b, overflowOrDiv);
1002 });
1003
1004 return overflowOrDiv ? Attribute() : result;
1005}
1006
1007//===----------------------------------------------------------------------===//
1008// RemUIOp
1009//===----------------------------------------------------------------------===//
1010
1011OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
1012 // remui (x, 1) -> 0.
1013 if (matchPattern(adaptor.getRhs(), m_One()))
1014 return Builder(getContext()).getZeroAttr(getType());
1015
1016 // Don't fold if it would require a division by zero.
1017 bool div0 = false;
1018 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1019 [&](APInt a, const APInt &b) {
1020 if (div0 || b.isZero()) {
1021 div0 = true;
1022 return a;
1023 }
1024 return a.urem(b);
1025 });
1026
1027 return div0 ? Attribute() : result;
1028}
1029
1030Speculation::Speculatability arith::RemUIOp::getSpeculatability() {
1031 return getDivUISpeculatability(getRhs());
1032}
1033
1034//===----------------------------------------------------------------------===//
1035// RemSIOp
1036//===----------------------------------------------------------------------===//
1037
1038OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
1039 // remsi (x, 1) -> 0.
1040 if (matchPattern(adaptor.getRhs(), m_One()))
1041 return Builder(getContext()).getZeroAttr(getType());
1042
1043 // Don't fold if it would require a division by zero.
1044 bool div0 = false;
1045 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1046 [&](APInt a, const APInt &b) {
1047 if (div0 || b.isZero()) {
1048 div0 = true;
1049 return a;
1050 }
1051 return a.srem(b);
1052 });
1053
1054 return div0 ? Attribute() : result;
1055}
1056
1057Speculation::Speculatability arith::RemSIOp::getSpeculatability() {
1058 // X % 0 => UB
1059 // X % -1 is well-defined (always 0), unlike X / -1 which can overflow.
1060 if (matchPattern(getRhs(), m_IntRangeWithoutZeroS()))
1062
1064}
1065
1066//===----------------------------------------------------------------------===//
1067// AndIOp
1068//===----------------------------------------------------------------------===//
1069
1070/// Fold `and(a, and(a, b))` to `and(a, b)`
1071static Value foldAndIofAndI(arith::AndIOp op) {
1072 for (bool reversePrev : {false, true}) {
1073 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
1074 .getDefiningOp<arith::AndIOp>();
1075 if (!prev)
1076 continue;
1077
1078 Value other = (reversePrev ? op.getLhs() : op.getRhs());
1079 if (other != prev.getLhs() && other != prev.getRhs())
1080 continue;
1081
1082 return prev.getResult();
1083 }
1084 return {};
1085}
1086
1087OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
1088 /// and(x, 0) -> 0
1089 if (matchPattern(adaptor.getRhs(), m_Zero()))
1090 return getRhs();
1091 /// and(x, allOnes) -> x
1092 APInt intValue;
1093 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
1094 intValue.isAllOnes())
1095 return getLhs();
1096 /// and(x, not(x)) -> 0
1097 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1098 m_ConstantInt(&intValue))) &&
1099 intValue.isAllOnes())
1100 return Builder(getContext()).getZeroAttr(getType());
1101 /// and(not(x), x) -> 0
1102 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1103 m_ConstantInt(&intValue))) &&
1104 intValue.isAllOnes())
1105 return Builder(getContext()).getZeroAttr(getType());
1106
1107 /// and(a, and(a, b)) -> and(a, b)
1108 if (Value result = foldAndIofAndI(*this))
1109 return result;
1110
1112 adaptor.getOperands(),
1113 [](APInt a, const APInt &b) { return std::move(a) & b; });
1114}
1115
1116//===----------------------------------------------------------------------===//
1117// OrIOp
1118//===----------------------------------------------------------------------===//
1119
1120OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1121 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1122 /// or(x, 0) -> x
1123 if (rhsVal.isZero())
1124 return getLhs();
1125 /// or(x, <all ones>) -> <all ones>
1126 if (rhsVal.isAllOnes())
1127 return adaptor.getRhs();
1128 }
1129
1130 APInt intValue;
1131 /// or(x, xor(x, 1)) -> 1
1132 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1133 m_ConstantInt(&intValue))) &&
1134 intValue.isAllOnes())
1135 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1136 /// or(xor(x, 1), x) -> 1
1137 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1138 m_ConstantInt(&intValue))) &&
1139 intValue.isAllOnes())
1140 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1141
1143 adaptor.getOperands(),
1144 [](APInt a, const APInt &b) { return std::move(a) | b; });
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// XOrIOp
1149//===----------------------------------------------------------------------===//
1150
1151OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1152 /// xor(x, 0) -> x
1153 if (matchPattern(adaptor.getRhs(), m_Zero()))
1154 return getLhs();
1155 /// xor(x, x) -> 0
1156 if (getLhs() == getRhs())
1157 return Builder(getContext()).getZeroAttr(getType());
1158 /// xor(xor(x, a), a) -> x
1159 /// xor(xor(a, x), a) -> x
1160 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1161 if (prev.getRhs() == getRhs())
1162 return prev.getLhs();
1163 if (prev.getLhs() == getRhs())
1164 return prev.getRhs();
1165 }
1166 /// xor(a, xor(x, a)) -> x
1167 /// xor(a, xor(a, x)) -> x
1168 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1169 if (prev.getRhs() == getLhs())
1170 return prev.getLhs();
1171 if (prev.getLhs() == getLhs())
1172 return prev.getRhs();
1173 }
1174
1176 adaptor.getOperands(),
1177 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1178}
1179
1180void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1181 MLIRContext *context) {
1182 patterns.add<XOrIXOrIConstant, XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(
1183 context);
1184}
1185
1186//===----------------------------------------------------------------------===//
1187// NegFOp
1188//===----------------------------------------------------------------------===//
1189
1190OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1191 /// negf(negf(x)) -> x
1192 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1193 return op.getOperand();
1194 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1195 [](const APFloat &a) { return -a; });
1196}
1197
1198//===----------------------------------------------------------------------===//
1199// FlushDenormalsOp
1200//===----------------------------------------------------------------------===//
1201
1202OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1203 // TODO: Fold flush_denormals if the floating-point type does not support
1204 // denormals. There is currently no API to query this information from
1205 // APFloat.
1206
1207 // flush_denormals(flush_denormals(x)) -> flush_denormals(x)
1208 if (auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1209 return op.getResult();
1210
1211 // Constant-fold flush_denormals if the operand is a constant.
1213 adaptor.getOperands(), [](const APFloat &a) {
1214 if (a.isDenormal())
1215 return APFloat::getZero(a.getSemantics(), a.isNegative());
1216 return a;
1217 });
1218}
1219
1220//===----------------------------------------------------------------------===//
1221// AddFOp
1222//===----------------------------------------------------------------------===//
1223
1224OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1225 // addf(x, -0) -> x
1226 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1227 return getLhs();
1228
1229 auto rm = getRoundingmode();
1231 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1232 APFloat result(a);
1233 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1234 return result;
1235 });
1236}
1237
1238//===----------------------------------------------------------------------===//
1239// SubFOp
1240//===----------------------------------------------------------------------===//
1241
1242OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1243 // subf(x, +0) -> x
1244 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1245 return getLhs();
1246
1247 auto rm = getRoundingmode();
1249 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1250 APFloat result(a);
1251 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1252 return result;
1253 });
1254}
1255
1256void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1257 MLIRContext *context) {
1258 patterns.add<SubFOfNegZero>(context);
1259}
1260
1261//===----------------------------------------------------------------------===//
1262// MaximumFOp
1263//===----------------------------------------------------------------------===//
1264
1265OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1266 // maximumf(x,x) -> x
1267 if (getLhs() == getRhs())
1268 return getRhs();
1269
1270 // maximumf(x, -inf) -> x
1271 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1272 return getLhs();
1273
1274 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maximum);
1275}
1276
1277//===----------------------------------------------------------------------===//
1278// MaxNumFOp
1279//===----------------------------------------------------------------------===//
1280
1281OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1282 // maxnumf(x,x) -> x
1283 if (getLhs() == getRhs())
1284 return getRhs();
1285
1286 // maxnumf(x, NaN) -> x
1287 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1288 return getLhs();
1289
1290 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1291}
1292
1293//===----------------------------------------------------------------------===//
1294// MaxSIOp
1295//===----------------------------------------------------------------------===//
1296
1297OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1298 // maxsi(x,x) -> x
1299 if (getLhs() == getRhs())
1300 return getRhs();
1301
1302 if (APInt intValue;
1303 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1304 // maxsi(x,MAX_INT) -> MAX_INT
1305 if (intValue.isMaxSignedValue())
1306 return getRhs();
1307 // maxsi(x, MIN_INT) -> x
1308 if (intValue.isMinSignedValue())
1309 return getLhs();
1310 }
1311
1312 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1313 llvm::APIntOps::smax);
1314}
1315
1316//===----------------------------------------------------------------------===//
1317// MaxUIOp
1318//===----------------------------------------------------------------------===//
1319
1320OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1321 // maxui(x,x) -> x
1322 if (getLhs() == getRhs())
1323 return getRhs();
1324
1325 if (APInt intValue;
1326 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1327 // maxui(x,MAX_INT) -> MAX_INT
1328 if (intValue.isMaxValue())
1329 return getRhs();
1330 // maxui(x, MIN_INT) -> x
1331 if (intValue.isMinValue())
1332 return getLhs();
1333 }
1334
1335 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1336 llvm::APIntOps::umax);
1337}
1338
1339//===----------------------------------------------------------------------===//
1340// MinimumFOp
1341//===----------------------------------------------------------------------===//
1342
1343OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1344 // minimumf(x,x) -> x
1345 if (getLhs() == getRhs())
1346 return getRhs();
1347
1348 // minimumf(x, +inf) -> x
1349 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1350 return getLhs();
1351
1352 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::minimum);
1353}
1354
1355//===----------------------------------------------------------------------===//
1356// MinNumFOp
1357//===----------------------------------------------------------------------===//
1358
1359OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1360 // minnumf(x,x) -> x
1361 if (getLhs() == getRhs())
1362 return getRhs();
1363
1364 // minnumf(x, NaN) -> x
1365 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1366 return getLhs();
1367
1368 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::minnum);
1369}
1370
1371//===----------------------------------------------------------------------===//
1372// MinSIOp
1373//===----------------------------------------------------------------------===//
1374
1375OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1376 // minsi(x,x) -> x
1377 if (getLhs() == getRhs())
1378 return getRhs();
1379
1380 if (APInt intValue;
1381 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1382 // minsi(x,MIN_INT) -> MIN_INT
1383 if (intValue.isMinSignedValue())
1384 return getRhs();
1385 // minsi(x, MAX_INT) -> x
1386 if (intValue.isMaxSignedValue())
1387 return getLhs();
1388 }
1389
1390 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1391 llvm::APIntOps::smin);
1392}
1393
1394//===----------------------------------------------------------------------===//
1395// MinUIOp
1396//===----------------------------------------------------------------------===//
1397
1398OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1399 // minui(x,x) -> x
1400 if (getLhs() == getRhs())
1401 return getRhs();
1402
1403 if (APInt intValue;
1404 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1405 // minui(x,MIN_INT) -> MIN_INT
1406 if (intValue.isMinValue())
1407 return getRhs();
1408 // minui(x, MAX_INT) -> x
1409 if (intValue.isMaxValue())
1410 return getLhs();
1411 }
1412
1413 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1414 llvm::APIntOps::umin);
1415}
1416
1417//===----------------------------------------------------------------------===//
1418// MulFOp
1419//===----------------------------------------------------------------------===//
1420
1421OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1422 // mulf(x, 1) -> x
1423 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1424 return getLhs();
1425
1426 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1427 arith::FastMathFlags::nsz)) {
1428 // mulf(x, 0) -> 0
1429 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1430 return getRhs();
1431 }
1432
1433 auto rm = getRoundingmode();
1435 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1436 APFloat result(a);
1437 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1438 return result;
1439 });
1440}
1441
1442void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1443 MLIRContext *context) {
1444 patterns.add<MulFOfNegF>(context);
1445}
1446
1447//===----------------------------------------------------------------------===//
1448// DivFOp
1449//===----------------------------------------------------------------------===//
1450
1451OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1452 // divf(x, 1) -> x
1453 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1454 return getLhs();
1455
1456 auto rm = getRoundingmode();
1458 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1459 APFloat result(a);
1460 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1461 return result;
1462 });
1463}
1464
1465void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1466 MLIRContext *context) {
1467 patterns.add<DivFOfNegF>(context);
1468}
1469
1470//===----------------------------------------------------------------------===//
1471// RemFOp
1472//===----------------------------------------------------------------------===//
1473
1474OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1475 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1476 [](const APFloat &a, const APFloat &b) {
1477 APFloat result(a);
1478 // APFloat::mod() offers the remainder
1479 // behavior we want, i.e. the result has
1480 // the sign of LHS operand.
1481 (void)result.mod(b);
1482 return result;
1483 });
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// Utility functions for verifying cast ops
1488//===----------------------------------------------------------------------===//
1489
1490template <typename... Types>
1491using type_list = std::tuple<Types...> *;
1492
1493/// Returns a non-null type only if the provided type is one of the allowed
1494/// types or one of the allowed shaped types of the allowed types. Returns the
1495/// element type if a valid shaped type is provided.
1496template <typename... ShapedTypes, typename... ElementTypes>
1499 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1500 return {};
1501
1502 auto underlyingType = getElementTypeOrSelf(type);
1503 if (!llvm::isa<ElementTypes...>(underlyingType))
1504 return {};
1505
1506 return underlyingType;
1507}
1508
1509/// Get allowed underlying types for vectors and tensors.
1510template <typename... ElementTypes>
1515
1516/// Get allowed underlying types for vectors, tensors, and memrefs.
1517template <typename... ElementTypes>
1523
1524/// Return false if both types are ranked tensor with mismatching encoding.
1525static bool hasSameEncoding(Type typeA, Type typeB) {
1526 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1527 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1528 if (!rankedTensorA || !rankedTensorB)
1529 return true;
1530 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1531}
1532
1534 if (inputs.size() != 1 || outputs.size() != 1)
1535 return false;
1536 if (!hasSameEncoding(inputs.front(), outputs.front()))
1537 return false;
1538 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1539}
1540
1541//===----------------------------------------------------------------------===//
1542// Verifiers for integer and floating point extension/truncation ops
1543//===----------------------------------------------------------------------===//
1544
1545// Extend ops can only extend to a wider type.
1546template <typename ValType, typename Op>
1547static LogicalResult verifyExtOp(Op op) {
1548 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1549 Type dstType = getElementTypeOrSelf(op.getType());
1550
1551 if (llvm::cast<ValType>(srcType).getWidth() >=
1552 llvm::cast<ValType>(dstType).getWidth())
1553 return op.emitError("result type ")
1554 << dstType << " must be wider than operand type " << srcType;
1555
1556 return success();
1557}
1558
1559// Truncate ops can only truncate to a shorter type.
1560template <typename ValType, typename Op>
1561static LogicalResult verifyTruncateOp(Op op) {
1562 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1563 Type dstType = getElementTypeOrSelf(op.getType());
1564
1565 if (llvm::cast<ValType>(srcType).getWidth() <=
1566 llvm::cast<ValType>(dstType).getWidth())
1567 return op.emitError("result type ")
1568 << dstType << " must be shorter than operand type " << srcType;
1569
1570 return success();
1571}
1572
1573/// Validate a cast that changes the width of a type.
1574template <template <typename> class WidthComparator, typename... ElementTypes>
1575static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1576 if (!areValidCastInputsAndOutputs(inputs, outputs))
1577 return false;
1578
1579 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1580 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1581 if (!srcType || !dstType)
1582 return false;
1583
1584 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1585 srcType.getIntOrFloatBitWidth());
1586}
1587
1588/// Attempts to convert `sourceValue` to an APFloat value with
1589/// `targetSemantics` and `roundingMode`, without any information loss.
1590static FailureOr<APFloat>
1591convertFloatValue(APFloat sourceValue,
1592 const llvm::fltSemantics &targetSemantics,
1593 llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
1594 // Reject special values that are not representable in the target type before
1595 // calling APFloat::convert, which would llvm_unreachable on them.
1596 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1597 if (sourceValue.isInfinity() &&
1598 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1599 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1600 return failure();
1601 if (sourceValue.isNaN() &&
1602 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1603 return failure();
1604
1605 bool losesInfo = false;
1606 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1607 if (losesInfo || status != APFloat::opOK)
1608 return failure();
1609
1610 return sourceValue;
1611}
1612
1613//===----------------------------------------------------------------------===//
1614// ExtUIOp
1615//===----------------------------------------------------------------------===//
1616
1617OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1618 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1619 getInMutable().assign(lhs.getIn());
1620 return getResult();
1621 }
1622
1623 Type resType = getElementTypeOrSelf(getType());
1624 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1626 adaptor.getOperands(), getType(),
1627 [bitWidth](const APInt &a, bool &castStatus) {
1628 return a.zext(bitWidth);
1629 });
1630}
1631
1632bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1634}
1635
1636LogicalResult arith::ExtUIOp::verify() {
1637 return verifyExtOp<IntegerType>(*this);
1638}
1639
1640//===----------------------------------------------------------------------===//
1641// ExtSIOp
1642//===----------------------------------------------------------------------===//
1643
1644OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1645 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1646 getInMutable().assign(lhs.getIn());
1647 return getResult();
1648 }
1649
1650 Type resType = getElementTypeOrSelf(getType());
1651 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1653 adaptor.getOperands(), getType(),
1654 [bitWidth](const APInt &a, bool &castStatus) {
1655 return a.sext(bitWidth);
1656 });
1657}
1658
1659bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1661}
1662
1663void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1664 MLIRContext *context) {
1665 patterns.add<ExtSIOfExtUI>(context);
1666}
1667
1668LogicalResult arith::ExtSIOp::verify() {
1669 return verifyExtOp<IntegerType>(*this);
1670}
1671
1672//===----------------------------------------------------------------------===//
1673// ExtFOp
1674//===----------------------------------------------------------------------===//
1675
1676/// Fold extension of float constants when there is no information loss due the
1677/// difference in fp semantics.
1678OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1679 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1680 if (truncFOp.getOperand().getType() == getType()) {
1681 arith::FastMathFlags truncFMF =
1682 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1683 bool isTruncContract =
1684 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1685 arith::FastMathFlags extFMF =
1686 getFastmath().value_or(arith::FastMathFlags::none);
1687 bool isExtContract =
1688 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1689 if (isTruncContract && isExtContract) {
1690 return truncFOp.getOperand();
1691 }
1692 }
1693 }
1694
1695 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1696 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1698 adaptor.getOperands(), getType(),
1699 [&targetSemantics](const APFloat &a, bool &castStatus) {
1700 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1701 if (failed(result)) {
1702 castStatus = false;
1703 return a;
1704 }
1705 return *result;
1706 });
1707}
1708
1709bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1710 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1711}
1712
1713LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1714
1715//===----------------------------------------------------------------------===//
1716// ScalingExtFOp
1717//===----------------------------------------------------------------------===//
1718
1719bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1720 TypeRange outputs) {
1721 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1722}
1723
1724LogicalResult arith::ScalingExtFOp::verify() {
1725 return verifyExtOp<FloatType>(*this);
1726}
1727
1728//===----------------------------------------------------------------------===//
1729// TruncIOp
1730//===----------------------------------------------------------------------===//
1731
1732OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1733 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1734 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1735 Value src = getOperand().getDefiningOp()->getOperand(0);
1736 Type srcType = getElementTypeOrSelf(src.getType());
1737 Type dstType = getElementTypeOrSelf(getType());
1738 // trunci(zexti(a)) -> trunci(a)
1739 // trunci(sexti(a)) -> trunci(a)
1740 if (llvm::cast<IntegerType>(srcType).getWidth() >
1741 llvm::cast<IntegerType>(dstType).getWidth()) {
1742 setOperand(src);
1743 return getResult();
1744 }
1745
1746 // trunci(zexti(a)) -> a
1747 // trunci(sexti(a)) -> a
1748 if (srcType == dstType)
1749 return src;
1750 }
1751
1752 // trunci(trunci(a)) -> trunci(a))
1753 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1754 setOperand(getOperand().getDefiningOp()->getOperand(0));
1755 return getResult();
1756 }
1757
1758 Type resType = getElementTypeOrSelf(getType());
1759 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1761 adaptor.getOperands(), getType(),
1762 [bitWidth](const APInt &a, bool &castStatus) {
1763 return a.trunc(bitWidth);
1764 });
1765}
1766
1767bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1768 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1769}
1770
1771void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1772 MLIRContext *context) {
1773 patterns
1774 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1775 context);
1776}
1777
1778LogicalResult arith::TruncIOp::verify() {
1779 return verifyTruncateOp<IntegerType>(*this);
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// TruncFOp
1784//===----------------------------------------------------------------------===//
1785
1786/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1787/// can be represented without precision loss.
1788OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1789 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1790 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1791 Value src = extOp.getIn();
1792 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1793 auto intermediateType =
1794 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1795 // Check if the srcType is representable in the intermediateType.
1796 if (llvm::APFloatBase::isRepresentableBy(
1797 srcType.getFloatSemantics(),
1798 intermediateType.getFloatSemantics())) {
1799 // truncf(extf(a)) -> truncf(a)
1800 if (srcType.getWidth() > resElemType.getWidth()) {
1801 setOperand(src);
1802 return getResult();
1803 }
1804
1805 // truncf(extf(a)) -> a
1806 if (srcType == resElemType)
1807 return src;
1808 }
1809 }
1810
1811 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1813 adaptor.getOperands(), getType(),
1814 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1815 llvm::RoundingMode llvmRoundingMode =
1816 convertArithRoundingModeToLLVMIR(getRoundingmode());
1817 FailureOr<APFloat> result =
1818 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1819 if (failed(result)) {
1820 castStatus = false;
1821 return a;
1822 }
1823 return *result;
1824 });
1825}
1826
1827void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1828 MLIRContext *context) {
1829 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1830}
1831
1832bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1833 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1834}
1835
1836LogicalResult arith::TruncFOp::verify() {
1837 return verifyTruncateOp<FloatType>(*this);
1838}
1839
1840//===----------------------------------------------------------------------===//
1841// ConvertFOp
1842//===----------------------------------------------------------------------===//
1843
1844OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1845 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1846 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1848 adaptor.getOperands(), getType(),
1849 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1850 llvm::RoundingMode llvmRoundingMode =
1851 convertArithRoundingModeToLLVMIR(getRoundingmode());
1852 FailureOr<APFloat> result =
1853 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1854 if (failed(result)) {
1855 castStatus = false;
1856 return a;
1857 }
1858 return *result;
1859 });
1860}
1861
1862bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1863 if (!areValidCastInputsAndOutputs(inputs, outputs))
1864 return false;
1865 auto srcType = getTypeIfLike<FloatType>(inputs.front());
1866 auto dstType = getTypeIfLike<FloatType>(outputs.front());
1867 if (!srcType || !dstType)
1868 return false;
1869 return srcType != dstType &&
1870 srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1871}
1872
1873LogicalResult arith::ConvertFOp::verify() {
1874 auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
1875 auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
1876 if (srcType == dstType)
1877 return emitError("result element type ")
1878 << dstType << " must be different from operand element type "
1879 << srcType;
1880 if (srcType.getWidth() != dstType.getWidth())
1881 return emitError("result element type ")
1882 << dstType << " must have the same bitwidth as operand element type "
1883 << srcType;
1884 return success();
1885}
1886
1887//===----------------------------------------------------------------------===//
1888// ScalingTruncFOp
1889//===----------------------------------------------------------------------===//
1890
1891bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1892 TypeRange outputs) {
1893 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1894}
1895
1896LogicalResult arith::ScalingTruncFOp::verify() {
1897 return verifyTruncateOp<FloatType>(*this);
1898}
1899
1900//===----------------------------------------------------------------------===//
1901// AndIOp
1902//===----------------------------------------------------------------------===//
1903
1904void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1905 MLIRContext *context) {
1906 patterns.add<AndIAndIConstant, AndOfExtUI, AndOfExtSI>(context);
1907}
1908
1909//===----------------------------------------------------------------------===//
1910// OrIOp
1911//===----------------------------------------------------------------------===//
1912
1913void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1914 MLIRContext *context) {
1915 patterns.add<OrIOrIConstant, OrOfExtUI, OrOfExtSI>(context);
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// Verifiers for casts between integers and floats.
1920//===----------------------------------------------------------------------===//
1921
1922template <typename From, typename To>
1923static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1924 if (!areValidCastInputsAndOutputs(inputs, outputs))
1925 return false;
1926
1927 auto srcType = getTypeIfLike<From>(inputs.front());
1928 auto dstType = getTypeIfLike<To>(outputs.back());
1929
1930 return srcType && dstType;
1931}
1932
1933//===----------------------------------------------------------------------===//
1934// UIToFPOp
1935//===----------------------------------------------------------------------===//
1936
1937bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1938 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1939}
1940
1941OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1942 Type resEleType = getElementTypeOrSelf(getType());
1944 adaptor.getOperands(), getType(),
1945 [&resEleType](const APInt &a, bool &castStatus) {
1946 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1947 APFloat apf(floatTy.getFloatSemantics(),
1948 APInt::getZero(floatTy.getWidth()));
1949 apf.convertFromAPInt(a, /*IsSigned=*/false,
1950 APFloat::rmNearestTiesToEven);
1951 return apf;
1952 });
1953}
1954
1955void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1956 MLIRContext *context) {
1957 patterns.add<UIToFPOfExtUI>(context);
1958}
1959
1960//===----------------------------------------------------------------------===//
1961// SIToFPOp
1962//===----------------------------------------------------------------------===//
1963
1964bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1965 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1966}
1967
1968OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1969 Type resEleType = getElementTypeOrSelf(getType());
1971 adaptor.getOperands(), getType(),
1972 [&resEleType](const APInt &a, bool &castStatus) {
1973 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1974 APFloat apf(floatTy.getFloatSemantics(),
1975 APInt::getZero(floatTy.getWidth()));
1976 apf.convertFromAPInt(a, /*IsSigned=*/true,
1977 APFloat::rmNearestTiesToEven);
1978 return apf;
1979 });
1980}
1981
1982void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1983 MLIRContext *context) {
1984 patterns.add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1985}
1986
1987//===----------------------------------------------------------------------===//
1988// FPToUIOp
1989//===----------------------------------------------------------------------===//
1990
1991bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1992 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1993}
1994
1995OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1996 Type resType = getElementTypeOrSelf(getType());
1997 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1999 adaptor.getOperands(), getType(),
2000 [&bitWidth](const APFloat &a, bool &castStatus) {
2001 bool ignored;
2002 APSInt api(bitWidth, /*isUnsigned=*/true);
2003 castStatus = APFloat::opInvalidOp !=
2004 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2005 return api;
2006 });
2007}
2008
2009//===----------------------------------------------------------------------===//
2010// FPToSIOp
2011//===----------------------------------------------------------------------===//
2012
2013bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2014 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
2015}
2016
2017OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
2018 Type resType = getElementTypeOrSelf(getType());
2019 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
2021 adaptor.getOperands(), getType(),
2022 [&bitWidth](const APFloat &a, bool &castStatus) {
2023 bool ignored;
2024 APSInt api(bitWidth, /*isUnsigned=*/false);
2025 castStatus = APFloat::opInvalidOp !=
2026 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
2027 return api;
2028 });
2029}
2030
2031//===----------------------------------------------------------------------===//
2032// IndexCastOp
2033//===----------------------------------------------------------------------===//
2034
2035/// Return the bit-width of \p t for the purpose of index_cast width checks.
2036/// For vector types use the element type; index maps to its internal storage
2037/// width (64 on all current targets).
2038static unsigned getIndexCastWidth(Type t) {
2039 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(t)))
2040 return intTy.getWidth();
2041 return IndexType::kInternalStorageBitWidth;
2042}
2043
2044static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
2045 if (!areValidCastInputsAndOutputs(inputs, outputs))
2046 return false;
2047
2048 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
2049 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
2050 if (!srcType || !dstType)
2051 return false;
2052
2053 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
2054 (srcType.isSignlessInteger() && dstType.isIndex());
2055}
2056
2057bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
2058 TypeRange outputs) {
2059 return areIndexCastCompatible(inputs, outputs);
2060}
2061
2062OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
2063 // index_cast(constant) -> constant
2064 unsigned resultBitwidth = 64; // Default for index integer attributes.
2065 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
2066 resultBitwidth = intTy.getWidth();
2067
2068 if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
2069 adaptor.getOperands(), getType(),
2070 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
2071 return a.sextOrTrunc(resultBitwidth);
2072 }))
2073 return foldResult;
2074
2075 // index_cast(index_cast(x : A) : B) : A -> x, but only when B is at least
2076 // as wide as A. If B is narrower, the inner cast truncates and the outer
2077 // cast sign-extends, so the round-trip is lossy.
2078 if (auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
2079 Value x = inner.getOperand();
2080 if (x.getType() == getType()) {
2081 if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
2082 return x;
2083 }
2084 }
2085 return {};
2086}
2087
2088void arith::IndexCastOp::getCanonicalizationPatterns(
2089 RewritePatternSet &patterns, MLIRContext *context) {
2090 patterns.add<IndexCastOfExtSI>(context);
2091}
2092
2093//===----------------------------------------------------------------------===//
2094// IndexCastUIOp
2095//===----------------------------------------------------------------------===//
2096
2097bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
2098 TypeRange outputs) {
2099 return areIndexCastCompatible(inputs, outputs);
2100}
2101
2102OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2103 // index_castui(constant) -> constant
2104 unsigned resultBitwidth = 64; // Default for index integer attributes.
2105 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
2106 resultBitwidth = intTy.getWidth();
2107
2108 if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
2109 adaptor.getOperands(), getType(),
2110 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
2111 return a.zextOrTrunc(resultBitwidth);
2112 }))
2113 return foldResult;
2114
2115 // index_castui(index_castui(x : A) : B) : A -> x, but only when B is at
2116 // least as wide as A. If B is narrower, the inner cast truncates and the
2117 // outer cast zero-extends, so the round-trip is lossy.
2118 if (auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2119 Value x = inner.getOperand();
2120 if (x.getType() == getType()) {
2121 if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
2122 return x;
2123 }
2124 }
2125 return {};
2126}
2127
2128void arith::IndexCastUIOp::getCanonicalizationPatterns(
2129 RewritePatternSet &patterns, MLIRContext *context) {
2130 patterns.add<IndexCastUIOfExtUI>(context);
2131}
2132
2133//===----------------------------------------------------------------------===//
2134// BitcastOp
2135//===----------------------------------------------------------------------===//
2136
2137bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2138 if (!areValidCastInputsAndOutputs(inputs, outputs))
2139 return false;
2140
2141 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
2142 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
2143 if (!srcType || !dstType)
2144 return false;
2145
2146 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
2147}
2148
2149OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
2150 auto resType = getType();
2151 auto operand = adaptor.getIn();
2152 if (!operand)
2153 return {};
2154
2155 /// Bitcast dense elements.
2156 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
2157 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
2158 /// Other shaped types unhandled.
2159 if (llvm::isa<ShapedType>(resType))
2160 return {};
2161
2162 /// Bitcast poison.
2163 if (matchPattern(operand, ub::m_Poison()))
2164 return ub::PoisonAttr::get(getContext());
2165
2166 /// Bitcast integer or float to integer or float.
2167 APInt bits = llvm::isa<FloatAttr>(operand)
2168 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2169 : llvm::cast<IntegerAttr>(operand).getValue();
2170 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
2171 "trying to fold on broken IR: operands have incompatible types");
2172
2173 if (auto resFloatType = dyn_cast<FloatType>(resType))
2174 return FloatAttr::get(resType,
2175 APFloat(resFloatType.getFloatSemantics(), bits));
2176 return IntegerAttr::get(resType, bits);
2177}
2178
2179void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2180 MLIRContext *context) {
2181 patterns.add<BitcastOfBitcast>(context);
2182}
2183
2184//===----------------------------------------------------------------------===//
2185// CmpIOp
2186//===----------------------------------------------------------------------===//
2187
2188/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
2189/// comparison predicates.
2190bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
2191 const APInt &lhs, const APInt &rhs) {
2192 switch (predicate) {
2193 case arith::CmpIPredicate::eq:
2194 return lhs.eq(rhs);
2195 case arith::CmpIPredicate::ne:
2196 return lhs.ne(rhs);
2197 case arith::CmpIPredicate::slt:
2198 return lhs.slt(rhs);
2199 case arith::CmpIPredicate::sle:
2200 return lhs.sle(rhs);
2201 case arith::CmpIPredicate::sgt:
2202 return lhs.sgt(rhs);
2203 case arith::CmpIPredicate::sge:
2204 return lhs.sge(rhs);
2205 case arith::CmpIPredicate::ult:
2206 return lhs.ult(rhs);
2207 case arith::CmpIPredicate::ule:
2208 return lhs.ule(rhs);
2209 case arith::CmpIPredicate::ugt:
2210 return lhs.ugt(rhs);
2211 case arith::CmpIPredicate::uge:
2212 return lhs.uge(rhs);
2213 }
2214 llvm_unreachable("unknown cmpi predicate kind");
2215}
2216
2217/// Returns true if the predicate is true for two equal operands.
2218static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
2219 switch (predicate) {
2220 case arith::CmpIPredicate::eq:
2221 case arith::CmpIPredicate::sle:
2222 case arith::CmpIPredicate::sge:
2223 case arith::CmpIPredicate::ule:
2224 case arith::CmpIPredicate::uge:
2225 return true;
2226 case arith::CmpIPredicate::ne:
2227 case arith::CmpIPredicate::slt:
2228 case arith::CmpIPredicate::sgt:
2229 case arith::CmpIPredicate::ult:
2230 case arith::CmpIPredicate::ugt:
2231 return false;
2232 }
2233 llvm_unreachable("unknown cmpi predicate kind");
2234}
2235
2236static std::optional<int64_t> getIntegerWidth(Type t) {
2237 if (auto intType = dyn_cast<IntegerType>(t)) {
2238 return intType.getWidth();
2239 }
2240 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
2241 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2242 }
2243 return std::nullopt;
2244}
2245
2246OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2247 // cmpi(pred, x, x)
2248 if (getLhs() == getRhs()) {
2249 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2250 return getBoolAttribute(getType(), val);
2251 }
2252
2253 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2254 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2255 // extsi(%x : i1 -> iN) != 0 -> %x
2256 std::optional<int64_t> integerWidth =
2257 getIntegerWidth(extOp.getOperand().getType());
2258 if (integerWidth && integerWidth.value() == 1 &&
2259 getPredicate() == arith::CmpIPredicate::ne)
2260 return extOp.getOperand();
2261 }
2262 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2263 // extui(%x : i1 -> iN) != 0 -> %x
2264 std::optional<int64_t> integerWidth =
2265 getIntegerWidth(extOp.getOperand().getType());
2266 if (integerWidth && integerWidth.value() == 1 &&
2267 getPredicate() == arith::CmpIPredicate::ne)
2268 return extOp.getOperand();
2269 }
2270
2271 // arith.cmpi ne, %val, %zero : i1 -> %val
2272 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2273 getPredicate() == arith::CmpIPredicate::ne)
2274 return getLhs();
2275 }
2276
2277 if (matchPattern(adaptor.getRhs(), m_One())) {
2278 // arith.cmpi eq, %val, %one : i1 -> %val
2279 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2280 getPredicate() == arith::CmpIPredicate::eq)
2281 return getLhs();
2282 }
2283
2284 // Move constant to the right side.
2285 if (adaptor.getLhs() && !adaptor.getRhs()) {
2286 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2287 using Pred = CmpIPredicate;
2288 const std::pair<Pred, Pred> invPreds[] = {
2289 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2290 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2291 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2292 {Pred::ne, Pred::ne},
2293 };
2294 Pred origPred = getPredicate();
2295 for (auto pred : invPreds) {
2296 if (origPred == pred.first) {
2297 setPredicate(pred.second);
2298 Value lhs = getLhs();
2299 Value rhs = getRhs();
2300 getLhsMutable().assign(rhs);
2301 getRhsMutable().assign(lhs);
2302 return getResult();
2303 }
2304 }
2305 llvm_unreachable("unknown cmpi predicate kind");
2306 }
2307
2308 // We are moving constants to the right side; So if lhs is constant rhs is
2309 // guaranteed to be a constant.
2310 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2312 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2313 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2314 return APInt(1,
2315 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2316 });
2317 }
2318
2319 return {};
2320}
2321
2322void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2323 MLIRContext *context) {
2324 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2325}
2326
2327//===----------------------------------------------------------------------===//
2328// CmpFOp
2329//===----------------------------------------------------------------------===//
2330
2331/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2332/// comparison predicates.
2333bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2334 const APFloat &lhs, const APFloat &rhs) {
2335 auto cmpResult = lhs.compare(rhs);
2336 switch (predicate) {
2337 case arith::CmpFPredicate::AlwaysFalse:
2338 return false;
2339 case arith::CmpFPredicate::OEQ:
2340 return cmpResult == APFloat::cmpEqual;
2341 case arith::CmpFPredicate::OGT:
2342 return cmpResult == APFloat::cmpGreaterThan;
2343 case arith::CmpFPredicate::OGE:
2344 return cmpResult == APFloat::cmpGreaterThan ||
2345 cmpResult == APFloat::cmpEqual;
2346 case arith::CmpFPredicate::OLT:
2347 return cmpResult == APFloat::cmpLessThan;
2348 case arith::CmpFPredicate::OLE:
2349 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2350 case arith::CmpFPredicate::ONE:
2351 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2352 case arith::CmpFPredicate::ORD:
2353 return cmpResult != APFloat::cmpUnordered;
2354 case arith::CmpFPredicate::UEQ:
2355 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2356 case arith::CmpFPredicate::UGT:
2357 return cmpResult == APFloat::cmpUnordered ||
2358 cmpResult == APFloat::cmpGreaterThan;
2359 case arith::CmpFPredicate::UGE:
2360 return cmpResult == APFloat::cmpUnordered ||
2361 cmpResult == APFloat::cmpGreaterThan ||
2362 cmpResult == APFloat::cmpEqual;
2363 case arith::CmpFPredicate::ULT:
2364 return cmpResult == APFloat::cmpUnordered ||
2365 cmpResult == APFloat::cmpLessThan;
2366 case arith::CmpFPredicate::ULE:
2367 return cmpResult == APFloat::cmpUnordered ||
2368 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2369 case arith::CmpFPredicate::UNE:
2370 return cmpResult != APFloat::cmpEqual;
2371 case arith::CmpFPredicate::UNO:
2372 return cmpResult == APFloat::cmpUnordered;
2373 case arith::CmpFPredicate::AlwaysTrue:
2374 return true;
2375 }
2376 llvm_unreachable("unknown cmpf predicate kind");
2377}
2378
2379OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2380 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2381 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2382
2383 // If one operand is NaN, making them both NaN does not change the result.
2384 if (lhs && lhs.getValue().isNaN())
2385 rhs = lhs;
2386 if (rhs && rhs.getValue().isNaN())
2387 lhs = rhs;
2388
2389 if (!lhs || !rhs)
2390 return {};
2391
2392 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2393 return BoolAttr::get(getContext(), val);
2394}
2395
2396class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2397public:
2398 using Base::Base;
2399
2400 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2401 bool isUnsigned) {
2402 using namespace arith;
2403 switch (pred) {
2404 case CmpFPredicate::UEQ:
2405 case CmpFPredicate::OEQ:
2406 return CmpIPredicate::eq;
2407 case CmpFPredicate::UGT:
2408 case CmpFPredicate::OGT:
2409 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2410 case CmpFPredicate::UGE:
2411 case CmpFPredicate::OGE:
2412 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2413 case CmpFPredicate::ULT:
2414 case CmpFPredicate::OLT:
2415 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2416 case CmpFPredicate::ULE:
2417 case CmpFPredicate::OLE:
2418 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2419 case CmpFPredicate::UNE:
2420 case CmpFPredicate::ONE:
2421 return CmpIPredicate::ne;
2422 default:
2423 llvm_unreachable("Unexpected predicate!");
2424 }
2425 }
2426
2427 LogicalResult matchAndRewrite(CmpFOp op,
2428 PatternRewriter &rewriter) const override {
2429 FloatAttr flt;
2430 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2431 return failure();
2432
2433 const APFloat &rhs = flt.getValue();
2434
2435 // Don't attempt to fold a nan.
2436 if (rhs.isNaN())
2437 return failure();
2438
2439 // Get the width of the mantissa. We don't want to hack on conversions that
2440 // might lose information from the integer, e.g. "i64 -> float"
2441 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2442 int mantissaWidth = floatTy.getFPMantissaWidth();
2443 if (mantissaWidth <= 0)
2444 return failure();
2445
2446 bool isUnsigned;
2447 Value intVal;
2448
2449 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2450 isUnsigned = false;
2451 intVal = si.getIn();
2452 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2453 isUnsigned = true;
2454 intVal = ui.getIn();
2455 } else {
2456 return failure();
2457 }
2458
2459 // Check to see that the input is converted from an integer type that is
2460 // small enough that preserves all bits.
2461 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2462 auto intWidth = intTy.getWidth();
2463
2464 // Number of bits representing values, as opposed to the sign
2465 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2466
2467 // Following test does NOT adjust intWidth downwards for signed inputs,
2468 // because the most negative value still requires all the mantissa bits
2469 // to distinguish it from one less than that value.
2470 if ((int)intWidth > mantissaWidth) {
2471 // Conversion would lose accuracy. Check if loss can impact comparison.
2472 int exponent = ilogb(rhs);
2473 if (exponent == APFloat::IEK_Inf) {
2474 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2475 if (maxExponent < (int)valueBits) {
2476 // Conversion could create infinity.
2477 return failure();
2478 }
2479 } else {
2480 // Note that if rhs is zero or NaN, then Exp is negative
2481 // and first condition is trivially false.
2482 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2483 // Conversion could affect comparison.
2484 return failure();
2485 }
2486 }
2487 }
2488
2489 // Convert to equivalent cmpi predicate
2490 CmpIPredicate pred;
2491 switch (op.getPredicate()) {
2492 case CmpFPredicate::ORD:
2493 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2494 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2495 /*width=*/1);
2496 return success();
2497 case CmpFPredicate::UNO:
2498 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2499 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2500 /*width=*/1);
2501 return success();
2502 default:
2503 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2504 break;
2505 }
2506
2507 if (!isUnsigned) {
2508 // If the rhs value is > SignedMax, fold the comparison. This handles
2509 // +INF and large values.
2510 APFloat signedMax(rhs.getSemantics());
2511 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2512 APFloat::rmNearestTiesToEven);
2513 if (signedMax < rhs) { // smax < 13123.0
2514 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2515 pred == CmpIPredicate::sle)
2516 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2517 /*width=*/1);
2518 else
2519 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2520 /*width=*/1);
2521 return success();
2522 }
2523 } else {
2524 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2525 // +INF and large values.
2526 APFloat unsignedMax(rhs.getSemantics());
2527 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2528 APFloat::rmNearestTiesToEven);
2529 if (unsignedMax < rhs) { // umax < 13123.0
2530 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2531 pred == CmpIPredicate::ule)
2532 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2533 /*width=*/1);
2534 else
2535 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2536 /*width=*/1);
2537 return success();
2538 }
2539 }
2540
2541 if (!isUnsigned) {
2542 // See if the rhs value is < SignedMin.
2543 APFloat signedMin(rhs.getSemantics());
2544 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2545 APFloat::rmNearestTiesToEven);
2546 if (signedMin > rhs) { // smin > 12312.0
2547 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2548 pred == CmpIPredicate::sge)
2549 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2550 /*width=*/1);
2551 else
2552 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2553 /*width=*/1);
2554 return success();
2555 }
2556 } else {
2557 // See if the rhs value is < UnsignedMin.
2558 APFloat unsignedMin(rhs.getSemantics());
2559 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2560 APFloat::rmNearestTiesToEven);
2561 if (unsignedMin > rhs) { // umin > 12312.0
2562 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2563 pred == CmpIPredicate::uge)
2564 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2565 /*width=*/1);
2566 else
2567 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2568 /*width=*/1);
2569 return success();
2570 }
2571 }
2572
2573 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2574 // [0, UMAX], but it may still be fractional. See if it is fractional by
2575 // casting the FP value to the integer value and back, checking for
2576 // equality. Don't do this for zero, because -0.0 is not fractional.
2577 bool ignored;
2578 APSInt rhsInt(intWidth, isUnsigned);
2579 if (APFloat::opInvalidOp ==
2580 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2581 // Undefined behavior invoked - the destination type can't represent
2582 // the input constant.
2583 return failure();
2584 }
2585
2586 if (!rhs.isZero()) {
2587 APFloat apf(floatTy.getFloatSemantics(),
2588 APInt::getZero(floatTy.getWidth()));
2589 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2590
2591 bool equal = apf == rhs;
2592 if (!equal) {
2593 // If we had a comparison against a fractional value, we have to adjust
2594 // the compare predicate and sometimes the value. rhsInt is rounded
2595 // towards zero at this point.
2596 switch (pred) {
2597 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2598 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2599 /*width=*/1);
2600 return success();
2601 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2602 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2603 /*width=*/1);
2604 return success();
2605 case CmpIPredicate::ule:
2606 // (float)int <= 4.4 --> int <= 4
2607 // (float)int <= -4.4 --> false
2608 if (rhs.isNegative()) {
2609 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2610 /*width=*/1);
2611 return success();
2612 }
2613 break;
2614 case CmpIPredicate::sle:
2615 // (float)int <= 4.4 --> int <= 4
2616 // (float)int <= -4.4 --> int < -4
2617 if (rhs.isNegative())
2618 pred = CmpIPredicate::slt;
2619 break;
2620 case CmpIPredicate::ult:
2621 // (float)int < -4.4 --> false
2622 // (float)int < 4.4 --> int <= 4
2623 if (rhs.isNegative()) {
2624 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2625 /*width=*/1);
2626 return success();
2627 }
2628 pred = CmpIPredicate::ule;
2629 break;
2630 case CmpIPredicate::slt:
2631 // (float)int < -4.4 --> int < -4
2632 // (float)int < 4.4 --> int <= 4
2633 if (!rhs.isNegative())
2634 pred = CmpIPredicate::sle;
2635 break;
2636 case CmpIPredicate::ugt:
2637 // (float)int > 4.4 --> int > 4
2638 // (float)int > -4.4 --> true
2639 if (rhs.isNegative()) {
2640 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2641 /*width=*/1);
2642 return success();
2643 }
2644 break;
2645 case CmpIPredicate::sgt:
2646 // (float)int > 4.4 --> int > 4
2647 // (float)int > -4.4 --> int >= -4
2648 if (rhs.isNegative())
2649 pred = CmpIPredicate::sge;
2650 break;
2651 case CmpIPredicate::uge:
2652 // (float)int >= -4.4 --> true
2653 // (float)int >= 4.4 --> int > 4
2654 if (rhs.isNegative()) {
2655 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2656 /*width=*/1);
2657 return success();
2658 }
2659 pred = CmpIPredicate::ugt;
2660 break;
2661 case CmpIPredicate::sge:
2662 // (float)int >= -4.4 --> int >= -4
2663 // (float)int >= 4.4 --> int > 4
2664 if (!rhs.isNegative())
2665 pred = CmpIPredicate::sgt;
2666 break;
2667 }
2668 }
2669 }
2670
2671 // Lower this FP comparison into an appropriate integer version of the
2672 // comparison.
2673 rewriter.replaceOpWithNewOp<CmpIOp>(
2674 op, pred, intVal,
2675 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2676 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2677 return success();
2678 }
2679};
2680
2681void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2682 MLIRContext *context) {
2683 patterns.insert<CmpFIntToFPConst>(context);
2684}
2685
2686//===----------------------------------------------------------------------===//
2687// SelectOp
2688//===----------------------------------------------------------------------===//
2689
2690// select %arg, %c1, %c0 => extui %arg
2691struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2692 using Base::Base;
2693
2694 LogicalResult matchAndRewrite(arith::SelectOp op,
2695 PatternRewriter &rewriter) const override {
2696 // Cannot extui i1 to i1, or i1 to f32
2697 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2698 return failure();
2699
2700 // select %x, c1, %c0 => extui %arg
2701 if (matchPattern(op.getTrueValue(), m_One()) &&
2702 matchPattern(op.getFalseValue(), m_Zero())) {
2703 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2704 op.getCondition());
2705 return success();
2706 }
2707
2708 // select %x, c0, %c1 => extui (xor %arg, true)
2709 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2710 matchPattern(op.getFalseValue(), m_One())) {
2711 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2712 op, op.getType(),
2713 arith::XOrIOp::create(
2714 rewriter, op.getLoc(), op.getCondition(),
2715 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2716 op.getCondition().getType(), 1)));
2717 return success();
2718 }
2719
2720 return failure();
2721 }
2722};
2723
2724void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2725 MLIRContext *context) {
2726 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2727 SelectI1ToNot, SelectToExtUI>(context);
2728}
2729
2730OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2731 Value trueVal = getTrueValue();
2732 Value falseVal = getFalseValue();
2733 if (trueVal == falseVal)
2734 return trueVal;
2735
2736 Value condition = getCondition();
2737
2738 // select true, %0, %1 => %0
2739 if (matchPattern(adaptor.getCondition(), m_One()))
2740 return trueVal;
2741
2742 // select false, %0, %1 => %1
2743 if (matchPattern(adaptor.getCondition(), m_Zero()))
2744 return falseVal;
2745
2746 // If either operand is fully poisoned, return the other.
2747 if (matchPattern(adaptor.getTrueValue(), ub::m_Poison()))
2748 return falseVal;
2749
2750 if (matchPattern(adaptor.getFalseValue(), ub::m_Poison()))
2751 return trueVal;
2752
2753 // select %x, true, false => %x
2754 if (getType().isSignlessInteger(1) &&
2755 matchPattern(adaptor.getTrueValue(), m_One()) &&
2756 matchPattern(adaptor.getFalseValue(), m_Zero()))
2757 return condition;
2758
2759 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2760 auto pred = cmp.getPredicate();
2761 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2762 auto cmpLhs = cmp.getLhs();
2763 auto cmpRhs = cmp.getRhs();
2764
2765 // %0 = arith.cmpi eq, %arg0, %arg1
2766 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2767
2768 // %0 = arith.cmpi ne, %arg0, %arg1
2769 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2770
2771 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2772 (cmpRhs == trueVal && cmpLhs == falseVal))
2773 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2774 }
2775 }
2776
2777 // Constant-fold constant operands over non-splat constant condition.
2778 // select %cst_vec, %cst0, %cst1 => %cst2
2779 if (auto cond =
2780 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2781 // DenseElementsAttr by construction always has a static shape.
2782 assert(cond.getType().hasStaticShape() &&
2783 "DenseElementsAttr must have static shape");
2784 if (auto lhs =
2785 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2786 if (auto rhs =
2787 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2788 SmallVector<Attribute> results;
2789 results.reserve(static_cast<size_t>(cond.getNumElements()));
2790 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2791 cond.value_end<BoolAttr>());
2792 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2793 lhs.value_end<Attribute>());
2794 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2795 rhs.value_end<Attribute>());
2796
2797 for (auto [condVal, lhsVal, rhsVal] :
2798 llvm::zip_equal(condVals, lhsVals, rhsVals))
2799 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2800
2801 return DenseElementsAttr::get(lhs.getType(), results);
2802 }
2803 }
2804 }
2805
2806 return nullptr;
2807}
2808
2809ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2810 Type conditionType, resultType;
2811 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2812 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2813 parser.parseOptionalAttrDict(result.attributes) ||
2814 parser.parseColonType(resultType))
2815 return failure();
2816
2817 // Check for the explicit condition type if this is a masked tensor or vector.
2818 if (succeeded(parser.parseOptionalComma())) {
2819 conditionType = resultType;
2820 if (parser.parseType(resultType))
2821 return failure();
2822 } else {
2823 conditionType = parser.getBuilder().getI1Type();
2824 }
2825
2826 result.addTypes(resultType);
2827 return parser.resolveOperands(operands,
2828 {conditionType, resultType, resultType},
2829 parser.getNameLoc(), result.operands);
2830}
2831
2832void arith::SelectOp::print(OpAsmPrinter &p) {
2833 p << " " << getOperands();
2834 p.printOptionalAttrDict((*this)->getAttrs());
2835 p << " : ";
2836 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2837 p << condType << ", ";
2838 p << getType();
2839}
2840
2841LogicalResult arith::SelectOp::verify() {
2842 Type conditionType = getCondition().getType();
2843 if (conditionType.isSignlessInteger(1))
2844 return success();
2845
2846 // If the result type is a vector or tensor, the type can be a mask with the
2847 // same elements.
2848 Type resultType = getType();
2849 if (!llvm::isa<TensorType, VectorType>(resultType))
2850 return emitOpError() << "expected condition to be a signless i1, but got "
2851 << conditionType;
2852 Type shapedConditionType = getI1SameShape(resultType);
2853 if (conditionType != shapedConditionType) {
2854 return emitOpError() << "expected condition type to have the same shape "
2855 "as the result type, expected "
2856 << shapedConditionType << ", but got "
2857 << conditionType;
2858 }
2859 return success();
2860}
2861//===----------------------------------------------------------------------===//
2862// ShLIOp
2863//===----------------------------------------------------------------------===//
2864
2865OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2866 // shli(x, 0) -> x
2867 if (matchPattern(adaptor.getRhs(), m_Zero()))
2868 return getLhs();
2869 // Don't fold if shifting more or equal than the bit width.
2870 bool bounded = false;
2872 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2873 bounded = b.ult(b.getBitWidth());
2874 return a.shl(b);
2875 });
2876 return bounded ? result : Attribute();
2877}
2878
2879//===----------------------------------------------------------------------===//
2880// ShRUIOp
2881//===----------------------------------------------------------------------===//
2882
2883OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2884 // shrui(x, 0) -> x
2885 if (matchPattern(adaptor.getRhs(), m_Zero()))
2886 return getLhs();
2887 // Don't fold if shifting more or equal than the bit width.
2888 bool bounded = false;
2890 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2891 bounded = b.ult(b.getBitWidth());
2892 return a.lshr(b);
2893 });
2894 return bounded ? result : Attribute();
2895}
2896
2897//===----------------------------------------------------------------------===//
2898// ShRSIOp
2899//===----------------------------------------------------------------------===//
2900
2901OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2902 // shrsi(x, 0) -> x
2903 if (matchPattern(adaptor.getRhs(), m_Zero()))
2904 return getLhs();
2905 // Don't fold if shifting more or equal than the bit width.
2906 bool bounded = false;
2908 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2909 bounded = b.ult(b.getBitWidth());
2910 return a.ashr(b);
2911 });
2912 return bounded ? result : Attribute();
2913}
2914
2915//===----------------------------------------------------------------------===//
2916// Atomic Enum
2917//===----------------------------------------------------------------------===//
2918
2919/// Returns the identity value attribute associated with an AtomicRMWKind op.
2920TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2921 OpBuilder &builder, Location loc,
2922 bool useOnlyFiniteValue) {
2923 switch (kind) {
2924 case AtomicRMWKind::maximumf: {
2925 const llvm::fltSemantics &semantic =
2926 llvm::cast<FloatType>(resultType).getFloatSemantics();
2927 APFloat identity = useOnlyFiniteValue
2928 ? APFloat::getLargest(semantic, /*Negative=*/true)
2929 : APFloat::getInf(semantic, /*Negative=*/true);
2930 return builder.getFloatAttr(resultType, identity);
2931 }
2932 case AtomicRMWKind::maxnumf: {
2933 const llvm::fltSemantics &semantic =
2934 llvm::cast<FloatType>(resultType).getFloatSemantics();
2935 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2936 return builder.getFloatAttr(resultType, identity);
2937 }
2938 case AtomicRMWKind::addf:
2939 case AtomicRMWKind::addi:
2940 case AtomicRMWKind::maxu:
2941 case AtomicRMWKind::ori:
2942 case AtomicRMWKind::xori:
2943 return builder.getZeroAttr(resultType);
2944 case AtomicRMWKind::andi:
2945 return builder.getIntegerAttr(
2946 resultType,
2947 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2948 case AtomicRMWKind::maxs:
2949 return builder.getIntegerAttr(
2950 resultType, APInt::getSignedMinValue(
2951 llvm::cast<IntegerType>(resultType).getWidth()));
2952 case AtomicRMWKind::minimumf: {
2953 const llvm::fltSemantics &semantic =
2954 llvm::cast<FloatType>(resultType).getFloatSemantics();
2955 APFloat identity = useOnlyFiniteValue
2956 ? APFloat::getLargest(semantic, /*Negative=*/false)
2957 : APFloat::getInf(semantic, /*Negative=*/false);
2958
2959 return builder.getFloatAttr(resultType, identity);
2960 }
2961 case AtomicRMWKind::minnumf: {
2962 const llvm::fltSemantics &semantic =
2963 llvm::cast<FloatType>(resultType).getFloatSemantics();
2964 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2965 return builder.getFloatAttr(resultType, identity);
2966 }
2967 case AtomicRMWKind::mins:
2968 return builder.getIntegerAttr(
2969 resultType, APInt::getSignedMaxValue(
2970 llvm::cast<IntegerType>(resultType).getWidth()));
2971 case AtomicRMWKind::minu:
2972 return builder.getIntegerAttr(
2973 resultType,
2974 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2975 case AtomicRMWKind::muli:
2976 return builder.getIntegerAttr(resultType, 1);
2977 case AtomicRMWKind::mulf:
2978 return builder.getFloatAttr(resultType, 1);
2979 // TODO: Add remaining reduction operations.
2980 default:
2981 (void)emitOptionalError(loc, "Reduction operation type not supported");
2982 break;
2983 }
2984 return nullptr;
2985}
2986
2987/// Returns the identity numeric value of the given op.
2988std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2989 std::optional<AtomicRMWKind> maybeKind =
2991 // Floating-point operations.
2992 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2993 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2994 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2995 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2996 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2997 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2998 // Integer operations.
2999 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
3000 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
3001 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
3002 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
3003 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
3004 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
3005 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
3006 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
3007 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
3008 .Default(std::nullopt);
3009 if (!maybeKind) {
3010 return std::nullopt;
3011 }
3012
3013 bool useOnlyFiniteValue = false;
3014 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
3015 if (fmfOpInterface) {
3016 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
3017 useOnlyFiniteValue =
3018 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
3019 }
3020
3021 // Builder only used as helper for attribute creation.
3022 OpBuilder b(op->getContext());
3023 Type resultType = op->getResult(0).getType();
3024
3025 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
3026 useOnlyFiniteValue);
3027}
3028
3029/// Returns the identity value associated with an AtomicRMWKind op.
3030Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
3031 OpBuilder &builder, Location loc,
3032 bool useOnlyFiniteValue) {
3033 if (auto attr = getIdentityValueAttr(op, resultType, builder, loc,
3034 useOnlyFiniteValue))
3035 return arith::ConstantOp::create(builder, loc, attr);
3036 return {};
3037}
3038
3039/// Return the value obtained by applying the reduction operation kind
3040/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
3042 Location loc, Value lhs, Value rhs) {
3043 switch (op) {
3044 case AtomicRMWKind::addf:
3045 return arith::AddFOp::create(builder, loc, lhs, rhs);
3046 case AtomicRMWKind::addi:
3047 return arith::AddIOp::create(builder, loc, lhs, rhs);
3048 case AtomicRMWKind::mulf:
3049 return arith::MulFOp::create(builder, loc, lhs, rhs);
3050 case AtomicRMWKind::muli:
3051 return arith::MulIOp::create(builder, loc, lhs, rhs);
3052 case AtomicRMWKind::maximumf:
3053 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
3054 case AtomicRMWKind::minimumf:
3055 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
3056 case AtomicRMWKind::maxnumf:
3057 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
3058 case AtomicRMWKind::minnumf:
3059 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
3060 case AtomicRMWKind::maxs:
3061 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
3062 case AtomicRMWKind::mins:
3063 return arith::MinSIOp::create(builder, loc, lhs, rhs);
3064 case AtomicRMWKind::maxu:
3065 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
3066 case AtomicRMWKind::minu:
3067 return arith::MinUIOp::create(builder, loc, lhs, rhs);
3068 case AtomicRMWKind::ori:
3069 return arith::OrIOp::create(builder, loc, lhs, rhs);
3070 case AtomicRMWKind::andi:
3071 return arith::AndIOp::create(builder, loc, lhs, rhs);
3072 case AtomicRMWKind::xori:
3073 return arith::XOrIOp::create(builder, loc, lhs, rhs);
3074 // TODO: Add remaining reduction operations.
3075 default:
3076 (void)emitOptionalError(loc, "Reduction operation type not supported");
3077 break;
3078 }
3079 return nullptr;
3080}
3081
3082//===----------------------------------------------------------------------===//
3083// TableGen'd op method definitions
3084//===----------------------------------------------------------------------===//
3085
3086#define GET_OP_CLASSES
3087#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3088
3089//===----------------------------------------------------------------------===//
3090// TableGen'd enum attribute definitions
3091//===----------------------------------------------------------------------===//
3092
3093#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:817
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:65
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:87
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
Definition ArithOps.cpp:38
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:778
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:127
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:877
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:859
static IntegerAttr orIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:75
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:55
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:170
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:60
static IntegerAttr andIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:70
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:150
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:194
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
static IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:80
static APInt calculateUnsignedBorrow(const APInt &lhs, const APInt &rhs)
Definition ArithOps.cpp:522
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:46
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:458
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:162
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:462
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:93
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:355
static bool classof(Operation *op)
Definition ArithOps.cpp:372
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:349
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:378
static bool classof(Operation *op)
Definition ArithOps.cpp:399
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:55
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:283
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:276
static bool classof(Operation *op)
Definition ArithOps.cpp:343
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:94
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:405
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.