MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37//===----------------------------------------------------------------------===//
38// Pattern helpers
39//===----------------------------------------------------------------------===//
40
41static IntegerAttr
44 function_ref<APInt(const APInt &, const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.getType(), value);
49}
50
51static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54}
55
56static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59}
60
61static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64}
65
66// Merge overflow flags from 2 ops, selecting the most conservative combination.
67static IntegerOverflowFlagsAttr
68mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
72}
73
74/// Invert an integer comparison predicate.
75arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76 switch (pred) {
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
97 }
98 llvm_unreachable("unknown cmpi predicate kind");
99}
100
101/// Equivalent to
102/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103///
104/// Not possible to implement as chain of calls as this would introduce a
105/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106/// on the LLVM dialect and on translation to LLVM.
107static llvm::RoundingMode
108convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
120 }
121 llvm_unreachable("Unhandled rounding mode");
122}
123
124static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
126 invertPredicate(pred.getValue()));
127}
128
130 Type elemTy = getElementTypeOrSelf(type);
131 if (elemTy.isIntOrFloat())
132 return elemTy.getIntOrFloatBitWidth();
133
134 return -1;
135}
136
138 return getScalarOrElementWidth(value.getType());
139}
140
141static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142 APInt value;
143 if (matchPattern(attr, m_ConstantInt(&value)))
144 return value;
145
146 return failure();
147}
148
149static Attribute getBoolAttribute(Type type, bool value) {
150 auto boolAttr = BoolAttr::get(type.getContext(), value);
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152 if (!shapedType)
153 return boolAttr;
154 // DenseElementsAttr requires a static shape.
155 if (!shapedType.hasStaticShape())
156 return {};
157 return DenseElementsAttr::get(shapedType, boolAttr);
158}
159
160//===----------------------------------------------------------------------===//
161// TableGen'd canonicalization patterns
162//===----------------------------------------------------------------------===//
163
164namespace {
165#include "ArithCanonicalization.inc"
166} // namespace
167
168//===----------------------------------------------------------------------===//
169// Common helpers
170//===----------------------------------------------------------------------===//
171
172/// Return the type of the same shape (scalar, vector or tensor) containing i1.
174 auto i1Type = IntegerType::get(type.getContext(), 1);
175 if (auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
179 return i1Type;
180}
181
182//===----------------------------------------------------------------------===//
183// ConstantOp
184//===----------------------------------------------------------------------===//
185
186void arith::ConstantOp::getAsmResultNames(
187 function_ref<void(Value, StringRef)> setNameFn) {
188 auto type = getType();
189 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
191
192 // Sugar i1 constants with 'true' and 'false'.
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
195
196 // Otherwise, build a complex name with the value and type.
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName << 'c' << intCst.getValue();
200 if (intType)
201 specialName << '_' << type;
202 setNameFn(getResult(), specialName.str());
203 } else {
204 setNameFn(getResult(), "cst");
205 }
206}
207
208/// TODO: disallow arith.constant to return anything other than signless integer
209/// or float like.
210LogicalResult arith::ConstantOp::verify() {
211 auto type = getType();
212 // Integer values must be signless.
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError("integer return type must be signless");
216 // Any float or elements attribute are acceptable.
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
218 return emitOpError(
219 "value must be an integer, float, or elements attribute");
220 }
221
222 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
223 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
224 // However, this would most likely require updating the lowerings to LLVM.
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
226 return emitOpError(
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
229 return success();
230}
231
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
233 // The value's type must be the same as the provided type.
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
236 return false;
237 // Integer values must be signless.
238 if (llvm::isa<IntegerType>(type) &&
239 !llvm::cast<IntegerType>(type).isSignless())
240 return false;
241 // Integer, float, and element attributes are buildable.
242 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
243}
244
245ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
246 Type type, Location loc) {
247 if (isBuildableWith(value, type))
248 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
249 return nullptr;
250}
251
252OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
253
255 int64_t value, unsigned width) {
256 auto type = builder.getIntegerType(width);
257 arith::ConstantOp::build(builder, result, type,
258 builder.getIntegerAttr(type, value));
259}
260
262 Location location,
264 unsigned width) {
265 mlir::OperationState state(location, getOperationName());
266 build(builder, state, value, width);
267 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
268 assert(result && "builder didn't return the right type");
269 return result;
270}
271
274 unsigned width) {
275 return create(builder, builder.getLoc(), value, width);
276}
277
279 Type type, int64_t value) {
280 arith::ConstantOp::build(builder, result, type,
281 builder.getIntegerAttr(type, value));
282}
283
285 Location location, Type type,
286 int64_t value) {
287 mlir::OperationState state(location, getOperationName());
288 build(builder, state, type, value);
289 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
290 assert(result && "builder didn't return the right type");
291 return result;
292}
293
295 Type type, int64_t value) {
296 return create(builder, builder.getLoc(), type, value);
297}
298
300 Type type, const APInt &value) {
301 arith::ConstantOp::build(builder, result, type,
302 builder.getIntegerAttr(type, value));
303}
304
306 Location location, Type type,
307 const APInt &value) {
308 mlir::OperationState state(location, getOperationName());
309 build(builder, state, type, value);
310 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
311 assert(result && "builder didn't return the right type");
312 return result;
313}
314
316 Type type,
317 const APInt &value) {
318 return create(builder, builder.getLoc(), type, value);
319}
320
322 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
323 return constOp.getType().isSignlessInteger();
324 return false;
325}
326
328 FloatType type, const APFloat &value) {
329 arith::ConstantOp::build(builder, result, type,
330 builder.getFloatAttr(type, value));
331}
332
334 Location location,
335 FloatType type,
336 const APFloat &value) {
337 mlir::OperationState state(location, getOperationName());
338 build(builder, state, type, value);
339 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
340 assert(result && "builder didn't return the right type");
341 return result;
342}
343
346 const APFloat &value) {
347 return create(builder, builder.getLoc(), type, value);
348}
349
351 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
352 return llvm::isa<FloatType>(constOp.getType());
353 return false;
354}
355
357 int64_t value) {
358 arith::ConstantOp::build(builder, result, builder.getIndexType(),
359 builder.getIndexAttr(value));
360}
361
363 Location location,
364 int64_t value) {
365 mlir::OperationState state(location, getOperationName());
366 build(builder, state, value);
367 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
368 assert(result && "builder didn't return the right type");
369 return result;
370}
371
374 return create(builder, builder.getLoc(), value);
375}
376
378 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
379 return constOp.getType().isIndex();
380 return false;
381}
382
384 Type type) {
385 // TODO: Incorporate this check to `FloatAttr::get*`.
386 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
387 "type doesn't have a zero representation");
388 TypedAttr zeroAttr = builder.getZeroAttr(type);
389 assert(zeroAttr && "unsupported type for zero attribute");
390 return arith::ConstantOp::create(builder, loc, zeroAttr);
391}
392
393//===----------------------------------------------------------------------===//
394// AddIOp
395//===----------------------------------------------------------------------===//
396
397OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
398 // addi(x, 0) -> x
399 if (matchPattern(adaptor.getRhs(), m_Zero()))
400 return getLhs();
401
402 // addi(subi(a, b), b) -> a
403 if (auto sub = getLhs().getDefiningOp<SubIOp>())
404 if (getRhs() == sub.getRhs())
405 return sub.getLhs();
406
407 // addi(b, subi(a, b)) -> a
408 if (auto sub = getRhs().getDefiningOp<SubIOp>())
409 if (getLhs() == sub.getRhs())
410 return sub.getLhs();
411
413 adaptor.getOperands(),
414 [](APInt a, const APInt &b) { return std::move(a) + b; });
415}
416
417void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
418 MLIRContext *context) {
419 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
420 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
421}
422
423//===----------------------------------------------------------------------===//
424// AddUIExtendedOp
425//===----------------------------------------------------------------------===//
426
427std::optional<SmallVector<int64_t, 4>>
428arith::AddUIExtendedOp::getShapeForUnroll() {
429 if (auto vt = dyn_cast<VectorType>(getType(0)))
430 return llvm::to_vector<4>(vt.getShape());
431 return std::nullopt;
432}
433
434// Returns the overflow bit, assuming that `sum` is the result of unsigned
435// addition of `operand` and another number.
436static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
437 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
438}
439
440LogicalResult
441arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
442 SmallVectorImpl<OpFoldResult> &results) {
443 Type overflowTy = getOverflow().getType();
444 // addui_extended(x, 0) -> x, false
445 if (matchPattern(getRhs(), m_Zero())) {
446 Builder builder(getContext());
447 auto falseValue = builder.getZeroAttr(overflowTy);
448
449 results.push_back(getLhs());
450 results.push_back(falseValue);
451 return success();
452 }
453
454 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
455 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
456 // operands. If that succeeds, calculate the overflow bit based on the sum
457 // and the first (constant) operand, `lhs`.
458 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
459 adaptor.getOperands(),
460 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
461 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
462 ArrayRef({sumAttr, adaptor.getLhs()}),
463 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
465 if (!overflowAttr)
466 return failure();
467
468 results.push_back(sumAttr);
469 results.push_back(overflowAttr);
470 return success();
471 }
472
473 return failure();
474}
475
476void arith::AddUIExtendedOp::getCanonicalizationPatterns(
477 RewritePatternSet &patterns, MLIRContext *context) {
478 patterns.add<AddUIExtendedToAddI>(context);
479}
480
481//===----------------------------------------------------------------------===//
482// SubIOp
483//===----------------------------------------------------------------------===//
484
485OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
486 // subi(x,x) -> 0
487 if (getOperand(0) == getOperand(1)) {
488 auto shapedType = dyn_cast<ShapedType>(getType());
489 // We can't generate a constant with a dynamic shaped tensor.
490 if (!shapedType || shapedType.hasStaticShape())
491 return Builder(getContext()).getZeroAttr(getType());
492 }
493 // subi(x,0) -> x
494 if (matchPattern(adaptor.getRhs(), m_Zero()))
495 return getLhs();
496
497 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
498 // subi(addi(a, b), b) -> a
499 if (getRhs() == add.getRhs())
500 return add.getLhs();
501 // subi(addi(a, b), a) -> b
502 if (getRhs() == add.getLhs())
503 return add.getRhs();
504 }
505
507 adaptor.getOperands(),
508 [](APInt a, const APInt &b) { return std::move(a) - b; });
509}
510
511void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
512 MLIRContext *context) {
513 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
514 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
515 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
516}
517
518//===----------------------------------------------------------------------===//
519// MulIOp
520//===----------------------------------------------------------------------===//
521
522OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
523 // muli(x, 0) -> 0
524 if (matchPattern(adaptor.getRhs(), m_Zero()))
525 return getRhs();
526 // muli(x, 1) -> x
527 if (matchPattern(adaptor.getRhs(), m_One()))
528 return getLhs();
529 // TODO: Handle the overflow case.
530
531 // default folder
533 adaptor.getOperands(),
534 [](const APInt &a, const APInt &b) { return a * b; });
535}
536
537void arith::MulIOp::getAsmResultNames(
538 function_ref<void(Value, StringRef)> setNameFn) {
539 if (!isa<IndexType>(getType()))
540 return;
541
542 // Match vector.vscale by name to avoid depending on the vector dialect (which
543 // is a circular dependency).
544 auto isVscale = [](Operation *op) {
545 return op && op->getName().getStringRef() == "vector.vscale";
546 };
547
548 IntegerAttr baseValue;
549 auto isVscaleExpr = [&](Value a, Value b) {
550 return matchPattern(a, m_Constant(&baseValue)) &&
551 isVscale(b.getDefiningOp());
552 };
553
554 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
555 return;
556
557 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
558 SmallString<32> specialNameBuffer;
559 llvm::raw_svector_ostream specialName(specialNameBuffer);
560 specialName << 'c' << baseValue.getInt() << "_vscale";
561 setNameFn(getResult(), specialName.str());
562}
563
564void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
565 MLIRContext *context) {
566 patterns.add<MulIMulIConstant>(context);
567}
568
569//===----------------------------------------------------------------------===//
570// MulSIExtendedOp
571//===----------------------------------------------------------------------===//
572
573std::optional<SmallVector<int64_t, 4>>
574arith::MulSIExtendedOp::getShapeForUnroll() {
575 if (auto vt = dyn_cast<VectorType>(getType(0)))
576 return llvm::to_vector<4>(vt.getShape());
577 return std::nullopt;
578}
579
580LogicalResult
581arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
582 SmallVectorImpl<OpFoldResult> &results) {
583 // mulsi_extended(x, 0) -> 0, 0
584 if (matchPattern(adaptor.getRhs(), m_Zero())) {
585 Attribute zero = adaptor.getRhs();
586 results.push_back(zero);
587 results.push_back(zero);
588 return success();
589 }
590
591 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
592 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
593 adaptor.getOperands(),
594 [](const APInt &a, const APInt &b) { return a * b; })) {
595 // Invoke the constant fold helper again to calculate the 'high' result.
596 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
597 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
598 return llvm::APIntOps::mulhs(a, b);
599 });
600 assert(highAttr && "Unexpected constant-folding failure");
601
602 results.push_back(lowAttr);
603 results.push_back(highAttr);
604 return success();
605 }
606
607 return failure();
608}
609
610void arith::MulSIExtendedOp::getCanonicalizationPatterns(
611 RewritePatternSet &patterns, MLIRContext *context) {
612 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
613}
614
615//===----------------------------------------------------------------------===//
616// MulUIExtendedOp
617//===----------------------------------------------------------------------===//
618
619std::optional<SmallVector<int64_t, 4>>
620arith::MulUIExtendedOp::getShapeForUnroll() {
621 if (auto vt = dyn_cast<VectorType>(getType(0)))
622 return llvm::to_vector<4>(vt.getShape());
623 return std::nullopt;
624}
625
626LogicalResult
627arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
628 SmallVectorImpl<OpFoldResult> &results) {
629 // mului_extended(x, 0) -> 0, 0
630 if (matchPattern(adaptor.getRhs(), m_Zero())) {
631 Attribute zero = adaptor.getRhs();
632 results.push_back(zero);
633 results.push_back(zero);
634 return success();
635 }
636
637 // mului_extended(x, 1) -> x, 0
638 if (matchPattern(adaptor.getRhs(), m_One())) {
639 Builder builder(getContext());
640 Attribute zero = builder.getZeroAttr(getLhs().getType());
641 results.push_back(getLhs());
642 results.push_back(zero);
643 return success();
644 }
645
646 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
647 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
648 adaptor.getOperands(),
649 [](const APInt &a, const APInt &b) { return a * b; })) {
650 // Invoke the constant fold helper again to calculate the 'high' result.
651 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
652 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
653 return llvm::APIntOps::mulhu(a, b);
654 });
655 assert(highAttr && "Unexpected constant-folding failure");
656
657 results.push_back(lowAttr);
658 results.push_back(highAttr);
659 return success();
660 }
661
662 return failure();
663}
664
665void arith::MulUIExtendedOp::getCanonicalizationPatterns(
666 RewritePatternSet &patterns, MLIRContext *context) {
667 patterns.add<MulUIExtendedToMulI>(context);
668}
669
670//===----------------------------------------------------------------------===//
671// DivUIOp
672//===----------------------------------------------------------------------===//
673
674/// Fold `(a * b) / b -> a`
676 arith::IntegerOverflowFlags ovfFlags) {
677 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
678 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
679 return {};
680
681 if (mul.getLhs() == rhs)
682 return mul.getRhs();
683
684 if (mul.getRhs() == rhs)
685 return mul.getLhs();
686
687 return {};
688}
689
690OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
691 // divui (x, 1) -> x.
692 if (matchPattern(adaptor.getRhs(), m_One()))
693 return getLhs();
694
695 // (a * b) / b -> a
696 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
697 return val;
698
699 // Don't fold if it would require a division by zero.
700 bool div0 = false;
701 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
702 [&](APInt a, const APInt &b) {
703 if (div0 || !b) {
704 div0 = true;
705 return a;
706 }
707 return a.udiv(b);
708 });
709
710 return div0 ? Attribute() : result;
711}
712
713/// Returns whether an unsigned division by `divisor` is speculatable.
715 // X / 0 => UB
716 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
718
720}
721
722Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
723 return getDivUISpeculatability(getRhs());
724}
725
726//===----------------------------------------------------------------------===//
727// DivSIOp
728//===----------------------------------------------------------------------===//
729
730OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
731 // divsi (x, 1) -> x.
732 if (matchPattern(adaptor.getRhs(), m_One()))
733 return getLhs();
734
735 // (a * b) / b -> a
736 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
737 return val;
738
739 // Don't fold if it would overflow or if it requires a division by zero.
740 bool overflowOrDiv0 = false;
742 adaptor.getOperands(), [&](APInt a, const APInt &b) {
743 if (overflowOrDiv0 || !b) {
744 overflowOrDiv0 = true;
745 return a;
746 }
747 return a.sdiv_ov(b, overflowOrDiv0);
748 });
749
750 return overflowOrDiv0 ? Attribute() : result;
751}
752
753/// Returns whether a signed division by `divisor` is speculatable. This
754/// function conservatively assumes that all signed division by -1 are not
755/// speculatable.
757 // X / 0 => UB
758 // INT_MIN / -1 => UB
759 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
762
764}
765
766Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
767 return getDivSISpeculatability(getRhs());
768}
769
770//===----------------------------------------------------------------------===//
771// Ceil and floor division folding helpers
772//===----------------------------------------------------------------------===//
773
774static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
775 bool &overflow) {
776 // Returns (a-1)/b + 1
777 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
778 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
779 return val.sadd_ov(one, overflow);
780}
781
782//===----------------------------------------------------------------------===//
783// CeilDivUIOp
784//===----------------------------------------------------------------------===//
785
786OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
787 // ceildivui (x, 1) -> x.
788 if (matchPattern(adaptor.getRhs(), m_One()))
789 return getLhs();
790
791 bool overflowOrDiv0 = false;
793 adaptor.getOperands(), [&](APInt a, const APInt &b) {
794 if (overflowOrDiv0 || !b) {
795 overflowOrDiv0 = true;
796 return a;
797 }
798 APInt quotient = a.udiv(b);
799 if (!a.urem(b))
800 return quotient;
801 APInt one(a.getBitWidth(), 1, true);
802 return quotient.uadd_ov(one, overflowOrDiv0);
803 });
804
805 return overflowOrDiv0 ? Attribute() : result;
806}
807
808Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
809 return getDivUISpeculatability(getRhs());
810}
811
812//===----------------------------------------------------------------------===//
813// CeilDivSIOp
814//===----------------------------------------------------------------------===//
815
816OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
817 // ceildivsi (x, 1) -> x.
818 if (matchPattern(adaptor.getRhs(), m_One()))
819 return getLhs();
820
821 // Don't fold if it would overflow or if it requires a division by zero.
822 // TODO: This hook won't fold operations where a = MININT, because
823 // negating MININT overflows. This can be improved.
824 bool overflowOrDiv0 = false;
826 adaptor.getOperands(), [&](APInt a, const APInt &b) {
827 if (overflowOrDiv0 || !b) {
828 overflowOrDiv0 = true;
829 return a;
830 }
831 if (!a)
832 return a;
833 // After this point we know that neither a or b are zero.
834 unsigned bits = a.getBitWidth();
835 APInt zero = APInt::getZero(bits);
836 bool aGtZero = a.sgt(zero);
837 bool bGtZero = b.sgt(zero);
838 if (aGtZero && bGtZero) {
839 // Both positive, return ceil(a, b).
840 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
841 }
842
843 // No folding happens if any of the intermediate arithmetic operations
844 // overflows.
845 bool overflowNegA = false;
846 bool overflowNegB = false;
847 bool overflowDiv = false;
848 bool overflowNegRes = false;
849 if (!aGtZero && !bGtZero) {
850 // Both negative, return ceil(-a, -b).
851 APInt posA = zero.ssub_ov(a, overflowNegA);
852 APInt posB = zero.ssub_ov(b, overflowNegB);
853 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
854 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
855 return res;
856 }
857 if (!aGtZero && bGtZero) {
858 // A is negative, b is positive, return - ( -a / b).
859 APInt posA = zero.ssub_ov(a, overflowNegA);
860 APInt div = posA.sdiv_ov(b, overflowDiv);
861 APInt res = zero.ssub_ov(div, overflowNegRes);
862 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
863 return res;
864 }
865 // A is positive, b is negative, return - (a / -b).
866 APInt posB = zero.ssub_ov(b, overflowNegB);
867 APInt div = a.sdiv_ov(posB, overflowDiv);
868 APInt res = zero.ssub_ov(div, overflowNegRes);
869
870 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
871 return res;
872 });
873
874 return overflowOrDiv0 ? Attribute() : result;
875}
876
877Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
878 return getDivSISpeculatability(getRhs());
879}
880
881//===----------------------------------------------------------------------===//
882// FloorDivSIOp
883//===----------------------------------------------------------------------===//
884
885OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
886 // floordivsi (x, 1) -> x.
887 if (matchPattern(adaptor.getRhs(), m_One()))
888 return getLhs();
889
890 // Don't fold if it would overflow or if it requires a division by zero.
891 bool overflowOrDiv = false;
893 adaptor.getOperands(), [&](APInt a, const APInt &b) {
894 if (b.isZero()) {
895 overflowOrDiv = true;
896 return a;
897 }
898 return a.sfloordiv_ov(b, overflowOrDiv);
899 });
900
901 return overflowOrDiv ? Attribute() : result;
902}
903
904//===----------------------------------------------------------------------===//
905// RemUIOp
906//===----------------------------------------------------------------------===//
907
908OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
909 // remui (x, 1) -> 0.
910 if (matchPattern(adaptor.getRhs(), m_One()))
911 return Builder(getContext()).getZeroAttr(getType());
912
913 // Don't fold if it would require a division by zero.
914 bool div0 = false;
915 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
916 [&](APInt a, const APInt &b) {
917 if (div0 || b.isZero()) {
918 div0 = true;
919 return a;
920 }
921 return a.urem(b);
922 });
923
924 return div0 ? Attribute() : result;
925}
926
927//===----------------------------------------------------------------------===//
928// RemSIOp
929//===----------------------------------------------------------------------===//
930
931OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
932 // remsi (x, 1) -> 0.
933 if (matchPattern(adaptor.getRhs(), m_One()))
934 return Builder(getContext()).getZeroAttr(getType());
935
936 // Don't fold if it would require a division by zero.
937 bool div0 = false;
938 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
939 [&](APInt a, const APInt &b) {
940 if (div0 || b.isZero()) {
941 div0 = true;
942 return a;
943 }
944 return a.srem(b);
945 });
946
947 return div0 ? Attribute() : result;
948}
949
950//===----------------------------------------------------------------------===//
951// AndIOp
952//===----------------------------------------------------------------------===//
953
954/// Fold `and(a, and(a, b))` to `and(a, b)`
955static Value foldAndIofAndI(arith::AndIOp op) {
956 for (bool reversePrev : {false, true}) {
957 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
958 .getDefiningOp<arith::AndIOp>();
959 if (!prev)
960 continue;
961
962 Value other = (reversePrev ? op.getLhs() : op.getRhs());
963 if (other != prev.getLhs() && other != prev.getRhs())
964 continue;
965
966 return prev.getResult();
967 }
968 return {};
969}
970
971OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
972 /// and(x, 0) -> 0
973 if (matchPattern(adaptor.getRhs(), m_Zero()))
974 return getRhs();
975 /// and(x, allOnes) -> x
976 APInt intValue;
977 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
978 intValue.isAllOnes())
979 return getLhs();
980 /// and(x, not(x)) -> 0
981 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
982 m_ConstantInt(&intValue))) &&
983 intValue.isAllOnes())
984 return Builder(getContext()).getZeroAttr(getType());
985 /// and(not(x), x) -> 0
986 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
987 m_ConstantInt(&intValue))) &&
988 intValue.isAllOnes())
989 return Builder(getContext()).getZeroAttr(getType());
990
991 /// and(a, and(a, b)) -> and(a, b)
992 if (Value result = foldAndIofAndI(*this))
993 return result;
994
996 adaptor.getOperands(),
997 [](APInt a, const APInt &b) { return std::move(a) & b; });
998}
999
1000//===----------------------------------------------------------------------===//
1001// OrIOp
1002//===----------------------------------------------------------------------===//
1003
1004OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1005 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1006 /// or(x, 0) -> x
1007 if (rhsVal.isZero())
1008 return getLhs();
1009 /// or(x, <all ones>) -> <all ones>
1010 if (rhsVal.isAllOnes())
1011 return adaptor.getRhs();
1012 }
1013
1014 APInt intValue;
1015 /// or(x, xor(x, 1)) -> 1
1016 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1017 m_ConstantInt(&intValue))) &&
1018 intValue.isAllOnes())
1019 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1020 /// or(xor(x, 1), x) -> 1
1021 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1022 m_ConstantInt(&intValue))) &&
1023 intValue.isAllOnes())
1024 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1025
1027 adaptor.getOperands(),
1028 [](APInt a, const APInt &b) { return std::move(a) | b; });
1029}
1030
1031//===----------------------------------------------------------------------===//
1032// XOrIOp
1033//===----------------------------------------------------------------------===//
1034
1035OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1036 /// xor(x, 0) -> x
1037 if (matchPattern(adaptor.getRhs(), m_Zero()))
1038 return getLhs();
1039 /// xor(x, x) -> 0
1040 if (getLhs() == getRhs())
1041 return Builder(getContext()).getZeroAttr(getType());
1042 /// xor(xor(x, a), a) -> x
1043 /// xor(xor(a, x), a) -> x
1044 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1045 if (prev.getRhs() == getRhs())
1046 return prev.getLhs();
1047 if (prev.getLhs() == getRhs())
1048 return prev.getRhs();
1049 }
1050 /// xor(a, xor(x, a)) -> x
1051 /// xor(a, xor(a, x)) -> x
1052 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1053 if (prev.getRhs() == getLhs())
1054 return prev.getLhs();
1055 if (prev.getLhs() == getLhs())
1056 return prev.getRhs();
1057 }
1058
1060 adaptor.getOperands(),
1061 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1062}
1063
1064void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1065 MLIRContext *context) {
1066 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// NegFOp
1071//===----------------------------------------------------------------------===//
1072
1073OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1074 /// negf(negf(x)) -> x
1075 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1076 return op.getOperand();
1077 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1078 [](const APFloat &a) { return -a; });
1079}
1080
1081//===----------------------------------------------------------------------===//
1082// AddFOp
1083//===----------------------------------------------------------------------===//
1084
1085OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1086 // addf(x, -0) -> x
1087 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1088 return getLhs();
1089
1091 adaptor.getOperands(),
1092 [](const APFloat &a, const APFloat &b) { return a + b; });
1093}
1094
1095//===----------------------------------------------------------------------===//
1096// SubFOp
1097//===----------------------------------------------------------------------===//
1098
1099OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1100 // subf(x, +0) -> x
1101 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1102 return getLhs();
1103
1105 adaptor.getOperands(),
1106 [](const APFloat &a, const APFloat &b) { return a - b; });
1107}
1108
1109//===----------------------------------------------------------------------===//
1110// MaximumFOp
1111//===----------------------------------------------------------------------===//
1112
1113OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1114 // maximumf(x,x) -> x
1115 if (getLhs() == getRhs())
1116 return getRhs();
1117
1118 // maximumf(x, -inf) -> x
1119 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1120 return getLhs();
1121
1123 adaptor.getOperands(),
1124 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1125}
1126
1127//===----------------------------------------------------------------------===//
1128// MaxNumFOp
1129//===----------------------------------------------------------------------===//
1130
1131OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1132 // maxnumf(x,x) -> x
1133 if (getLhs() == getRhs())
1134 return getRhs();
1135
1136 // maxnumf(x, NaN) -> x
1137 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1138 return getLhs();
1139
1140 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1141}
1142
1143//===----------------------------------------------------------------------===//
1144// MaxSIOp
1145//===----------------------------------------------------------------------===//
1146
1147OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1148 // maxsi(x,x) -> x
1149 if (getLhs() == getRhs())
1150 return getRhs();
1151
1152 if (APInt intValue;
1153 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1154 // maxsi(x,MAX_INT) -> MAX_INT
1155 if (intValue.isMaxSignedValue())
1156 return getRhs();
1157 // maxsi(x, MIN_INT) -> x
1158 if (intValue.isMinSignedValue())
1159 return getLhs();
1160 }
1161
1162 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1163 [](const APInt &a, const APInt &b) {
1164 return llvm::APIntOps::smax(a, b);
1165 });
1166}
1167
1168//===----------------------------------------------------------------------===//
1169// MaxUIOp
1170//===----------------------------------------------------------------------===//
1171
1172OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1173 // maxui(x,x) -> x
1174 if (getLhs() == getRhs())
1175 return getRhs();
1176
1177 if (APInt intValue;
1178 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1179 // maxui(x,MAX_INT) -> MAX_INT
1180 if (intValue.isMaxValue())
1181 return getRhs();
1182 // maxui(x, MIN_INT) -> x
1183 if (intValue.isMinValue())
1184 return getLhs();
1185 }
1186
1187 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1188 [](const APInt &a, const APInt &b) {
1189 return llvm::APIntOps::umax(a, b);
1190 });
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// MinimumFOp
1195//===----------------------------------------------------------------------===//
1196
1197OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1198 // minimumf(x,x) -> x
1199 if (getLhs() == getRhs())
1200 return getRhs();
1201
1202 // minimumf(x, +inf) -> x
1203 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1204 return getLhs();
1205
1207 adaptor.getOperands(),
1208 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1209}
1210
1211//===----------------------------------------------------------------------===//
1212// MinNumFOp
1213//===----------------------------------------------------------------------===//
1214
1215OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1216 // minnumf(x,x) -> x
1217 if (getLhs() == getRhs())
1218 return getRhs();
1219
1220 // minnumf(x, NaN) -> x
1221 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1222 return getLhs();
1223
1225 adaptor.getOperands(),
1226 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1227}
1228
1229//===----------------------------------------------------------------------===//
1230// MinSIOp
1231//===----------------------------------------------------------------------===//
1232
1233OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1234 // minsi(x,x) -> x
1235 if (getLhs() == getRhs())
1236 return getRhs();
1237
1238 if (APInt intValue;
1239 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1240 // minsi(x,MIN_INT) -> MIN_INT
1241 if (intValue.isMinSignedValue())
1242 return getRhs();
1243 // minsi(x, MAX_INT) -> x
1244 if (intValue.isMaxSignedValue())
1245 return getLhs();
1246 }
1247
1248 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1249 [](const APInt &a, const APInt &b) {
1250 return llvm::APIntOps::smin(a, b);
1251 });
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// MinUIOp
1256//===----------------------------------------------------------------------===//
1257
1258OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1259 // minui(x,x) -> x
1260 if (getLhs() == getRhs())
1261 return getRhs();
1262
1263 if (APInt intValue;
1264 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1265 // minui(x,MIN_INT) -> MIN_INT
1266 if (intValue.isMinValue())
1267 return getRhs();
1268 // minui(x, MAX_INT) -> x
1269 if (intValue.isMaxValue())
1270 return getLhs();
1271 }
1272
1273 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1274 [](const APInt &a, const APInt &b) {
1275 return llvm::APIntOps::umin(a, b);
1276 });
1277}
1278
1279//===----------------------------------------------------------------------===//
1280// MulFOp
1281//===----------------------------------------------------------------------===//
1282
1283OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1284 // mulf(x, 1) -> x
1285 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1286 return getLhs();
1287
1288 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1289 arith::FastMathFlags::nsz)) {
1290 // mulf(x, 0) -> 0
1291 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1292 return getRhs();
1293 }
1294
1296 adaptor.getOperands(),
1297 [](const APFloat &a, const APFloat &b) { return a * b; });
1298}
1299
1300void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1301 MLIRContext *context) {
1302 patterns.add<MulFOfNegF>(context);
1303}
1304
1305//===----------------------------------------------------------------------===//
1306// DivFOp
1307//===----------------------------------------------------------------------===//
1308
1309OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1310 // divf(x, 1) -> x
1311 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1312 return getLhs();
1313
1315 adaptor.getOperands(),
1316 [](const APFloat &a, const APFloat &b) { return a / b; });
1317}
1318
1319void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1320 MLIRContext *context) {
1321 patterns.add<DivFOfNegF>(context);
1322}
1323
1324//===----------------------------------------------------------------------===//
1325// RemFOp
1326//===----------------------------------------------------------------------===//
1327
1328OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1329 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1330 [](const APFloat &a, const APFloat &b) {
1331 APFloat result(a);
1332 // APFloat::mod() offers the remainder
1333 // behavior we want, i.e. the result has
1334 // the sign of LHS operand.
1335 (void)result.mod(b);
1336 return result;
1337 });
1338}
1339
1340//===----------------------------------------------------------------------===//
1341// Utility functions for verifying cast ops
1342//===----------------------------------------------------------------------===//
1343
/// Tag type that carries a pack of types as an ordinary function argument.
/// Being a pointer to a tuple, it is trivially value-constructible as
/// `type_list<Ts...>()`, and a function template can deduce a separate type
/// pack from each such argument.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
1346
1347/// Returns a non-null type only if the provided type is one of the allowed
1348/// types or one of the allowed shaped types of the allowed types. Returns the
1349/// element type if a valid shaped type is provided.
1350template <typename... ShapedTypes, typename... ElementTypes>
1353 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1354 return {};
1355
1356 auto underlyingType = getElementTypeOrSelf(type);
1357 if (!llvm::isa<ElementTypes...>(underlyingType))
1358 return {};
1359
1360 return underlyingType;
1361}
1362
1363/// Get allowed underlying types for vectors and tensors.
1364template <typename... ElementTypes>
1369
1370/// Get allowed underlying types for vectors, tensors, and memrefs.
1371template <typename... ElementTypes>
1377
1378/// Return false if both types are ranked tensor with mismatching encoding.
1379static bool hasSameEncoding(Type typeA, Type typeB) {
1380 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1381 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1382 if (!rankedTensorA || !rankedTensorB)
1383 return true;
1384 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1385}
1386
1388 if (inputs.size() != 1 || outputs.size() != 1)
1389 return false;
1390 if (!hasSameEncoding(inputs.front(), outputs.front()))
1391 return false;
1392 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1393}
1394
1395//===----------------------------------------------------------------------===//
1396// Verifiers for integer and floating point extension/truncation ops
1397//===----------------------------------------------------------------------===//
1398
1399// Extend ops can only extend to a wider type.
1400template <typename ValType, typename Op>
1401static LogicalResult verifyExtOp(Op op) {
1402 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1403 Type dstType = getElementTypeOrSelf(op.getType());
1404
1405 if (llvm::cast<ValType>(srcType).getWidth() >=
1406 llvm::cast<ValType>(dstType).getWidth())
1407 return op.emitError("result type ")
1408 << dstType << " must be wider than operand type " << srcType;
1409
1410 return success();
1411}
1412
1413// Truncate ops can only truncate to a shorter type.
1414template <typename ValType, typename Op>
1415static LogicalResult verifyTruncateOp(Op op) {
1416 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1417 Type dstType = getElementTypeOrSelf(op.getType());
1418
1419 if (llvm::cast<ValType>(srcType).getWidth() <=
1420 llvm::cast<ValType>(dstType).getWidth())
1421 return op.emitError("result type ")
1422 << dstType << " must be shorter than operand type " << srcType;
1423
1424 return success();
1425}
1426
1427/// Validate a cast that changes the width of a type.
1428template <template <typename> class WidthComparator, typename... ElementTypes>
1429static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1430 if (!areValidCastInputsAndOutputs(inputs, outputs))
1431 return false;
1432
1433 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1434 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1435 if (!srcType || !dstType)
1436 return false;
1437
1438 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1439 srcType.getIntOrFloatBitWidth());
1440}
1441
1442/// Attempts to convert `sourceValue` to an APFloat value with
1443/// `targetSemantics` and `roundingMode`, without any information loss.
1444static FailureOr<APFloat> convertFloatValue(
1445 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1446 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1447 bool losesInfo = false;
1448 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1449 if (losesInfo || status != APFloat::opOK)
1450 return failure();
1451
1452 return sourceValue;
1453}
1454
1455//===----------------------------------------------------------------------===//
1456// ExtUIOp
1457//===----------------------------------------------------------------------===//
1458
1459OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1460 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1461 getInMutable().assign(lhs.getIn());
1462 return getResult();
1463 }
1464
1465 Type resType = getElementTypeOrSelf(getType());
1466 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1468 adaptor.getOperands(), getType(),
1469 [bitWidth](const APInt &a, bool &castStatus) {
1470 return a.zext(bitWidth);
1471 });
1472}
1473
1474bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1476}
1477
1478LogicalResult arith::ExtUIOp::verify() {
1479 return verifyExtOp<IntegerType>(*this);
1480}
1481
1482//===----------------------------------------------------------------------===//
1483// ExtSIOp
1484//===----------------------------------------------------------------------===//
1485
1486OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1487 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1488 getInMutable().assign(lhs.getIn());
1489 return getResult();
1490 }
1491
1492 Type resType = getElementTypeOrSelf(getType());
1493 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1495 adaptor.getOperands(), getType(),
1496 [bitWidth](const APInt &a, bool &castStatus) {
1497 return a.sext(bitWidth);
1498 });
1499}
1500
1501bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1503}
1504
1505void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1506 MLIRContext *context) {
1507 patterns.add<ExtSIOfExtUI>(context);
1508}
1509
1510LogicalResult arith::ExtSIOp::verify() {
1511 return verifyExtOp<IntegerType>(*this);
1512}
1513
1514//===----------------------------------------------------------------------===//
1515// ExtFOp
1516//===----------------------------------------------------------------------===//
1517
1518/// Fold extension of float constants when there is no information loss due the
1519/// difference in fp semantics.
1520OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1521 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1522 if (truncFOp.getOperand().getType() == getType()) {
1523 arith::FastMathFlags truncFMF =
1524 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1525 bool isTruncContract =
1526 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1527 arith::FastMathFlags extFMF =
1528 getFastmath().value_or(arith::FastMathFlags::none);
1529 bool isExtContract =
1530 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1531 if (isTruncContract && isExtContract) {
1532 return truncFOp.getOperand();
1533 }
1534 }
1535 }
1536
1537 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1538 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1540 adaptor.getOperands(), getType(),
1541 [&targetSemantics](const APFloat &a, bool &castStatus) {
1542 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1543 if (failed(result)) {
1544 castStatus = false;
1545 return a;
1546 }
1547 return *result;
1548 });
1549}
1550
1551bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1552 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1553}
1554
1555LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1556
1557//===----------------------------------------------------------------------===//
1558// ScalingExtFOp
1559//===----------------------------------------------------------------------===//
1560
1561bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1562 TypeRange outputs) {
1563 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1564}
1565
1566LogicalResult arith::ScalingExtFOp::verify() {
1567 return verifyExtOp<FloatType>(*this);
1568}
1569
1570//===----------------------------------------------------------------------===//
1571// TruncIOp
1572//===----------------------------------------------------------------------===//
1573
1574OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1575 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1576 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1577 Value src = getOperand().getDefiningOp()->getOperand(0);
1578 Type srcType = getElementTypeOrSelf(src.getType());
1579 Type dstType = getElementTypeOrSelf(getType());
1580 // trunci(zexti(a)) -> trunci(a)
1581 // trunci(sexti(a)) -> trunci(a)
1582 if (llvm::cast<IntegerType>(srcType).getWidth() >
1583 llvm::cast<IntegerType>(dstType).getWidth()) {
1584 setOperand(src);
1585 return getResult();
1586 }
1587
1588 // trunci(zexti(a)) -> a
1589 // trunci(sexti(a)) -> a
1590 if (srcType == dstType)
1591 return src;
1592 }
1593
1594 // trunci(trunci(a)) -> trunci(a))
1595 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1596 setOperand(getOperand().getDefiningOp()->getOperand(0));
1597 return getResult();
1598 }
1599
1600 Type resType = getElementTypeOrSelf(getType());
1601 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1603 adaptor.getOperands(), getType(),
1604 [bitWidth](const APInt &a, bool &castStatus) {
1605 return a.trunc(bitWidth);
1606 });
1607}
1608
1609bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1610 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1611}
1612
1613void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1614 MLIRContext *context) {
1615 patterns
1616 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1617 context);
1618}
1619
1620LogicalResult arith::TruncIOp::verify() {
1621 return verifyTruncateOp<IntegerType>(*this);
1622}
1623
1624//===----------------------------------------------------------------------===//
1625// TruncFOp
1626//===----------------------------------------------------------------------===//
1627
1628/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1629/// can be represented without precision loss.
1630OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1631 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1632 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1633 Value src = extOp.getIn();
1634 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1635 auto intermediateType =
1636 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1637 // Check if the srcType is representable in the intermediateType.
1638 if (llvm::APFloatBase::isRepresentableBy(
1639 srcType.getFloatSemantics(),
1640 intermediateType.getFloatSemantics())) {
1641 // truncf(extf(a)) -> truncf(a)
1642 if (srcType.getWidth() > resElemType.getWidth()) {
1643 setOperand(src);
1644 return getResult();
1645 }
1646
1647 // truncf(extf(a)) -> a
1648 if (srcType == resElemType)
1649 return src;
1650 }
1651 }
1652
1653 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1655 adaptor.getOperands(), getType(),
1656 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1657 RoundingMode roundingMode =
1658 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1659 llvm::RoundingMode llvmRoundingMode =
1661 FailureOr<APFloat> result =
1662 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1663 if (failed(result)) {
1664 castStatus = false;
1665 return a;
1666 }
1667 return *result;
1668 });
1669}
1670
1671void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1672 MLIRContext *context) {
1673 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1674}
1675
1676bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1677 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1678}
1679
1680LogicalResult arith::TruncFOp::verify() {
1681 return verifyTruncateOp<FloatType>(*this);
1682}
1683
1684//===----------------------------------------------------------------------===//
1685// ScalingTruncFOp
1686//===----------------------------------------------------------------------===//
1687
1688bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1689 TypeRange outputs) {
1690 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1691}
1692
1693LogicalResult arith::ScalingTruncFOp::verify() {
1694 return verifyTruncateOp<FloatType>(*this);
1695}
1696
1697//===----------------------------------------------------------------------===//
1698// AndIOp
1699//===----------------------------------------------------------------------===//
1700
1701void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1702 MLIRContext *context) {
1703 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1704}
1705
1706//===----------------------------------------------------------------------===//
1707// OrIOp
1708//===----------------------------------------------------------------------===//
1709
1710void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1711 MLIRContext *context) {
1712 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1713}
1714
1715//===----------------------------------------------------------------------===//
1716// Verifiers for casts between integers and floats.
1717//===----------------------------------------------------------------------===//
1718
1719template <typename From, typename To>
1720static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1721 if (!areValidCastInputsAndOutputs(inputs, outputs))
1722 return false;
1723
1724 auto srcType = getTypeIfLike<From>(inputs.front());
1725 auto dstType = getTypeIfLike<To>(outputs.back());
1726
1727 return srcType && dstType;
1728}
1729
1730//===----------------------------------------------------------------------===//
1731// UIToFPOp
1732//===----------------------------------------------------------------------===//
1733
1734bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1735 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1736}
1737
1738OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1739 Type resEleType = getElementTypeOrSelf(getType());
1741 adaptor.getOperands(), getType(),
1742 [&resEleType](const APInt &a, bool &castStatus) {
1743 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1744 APFloat apf(floatTy.getFloatSemantics(),
1745 APInt::getZero(floatTy.getWidth()));
1746 apf.convertFromAPInt(a, /*IsSigned=*/false,
1747 APFloat::rmNearestTiesToEven);
1748 return apf;
1749 });
1750}
1751
1752//===----------------------------------------------------------------------===//
1753// SIToFPOp
1754//===----------------------------------------------------------------------===//
1755
1756bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1757 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1758}
1759
1760OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1761 Type resEleType = getElementTypeOrSelf(getType());
1763 adaptor.getOperands(), getType(),
1764 [&resEleType](const APInt &a, bool &castStatus) {
1765 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1766 APFloat apf(floatTy.getFloatSemantics(),
1767 APInt::getZero(floatTy.getWidth()));
1768 apf.convertFromAPInt(a, /*IsSigned=*/true,
1769 APFloat::rmNearestTiesToEven);
1770 return apf;
1771 });
1772}
1773
1774//===----------------------------------------------------------------------===//
1775// FPToUIOp
1776//===----------------------------------------------------------------------===//
1777
1778bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1779 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1780}
1781
1782OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1783 Type resType = getElementTypeOrSelf(getType());
1784 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1786 adaptor.getOperands(), getType(),
1787 [&bitWidth](const APFloat &a, bool &castStatus) {
1788 bool ignored;
1789 APSInt api(bitWidth, /*isUnsigned=*/true);
1790 castStatus = APFloat::opInvalidOp !=
1791 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1792 return api;
1793 });
1794}
1795
1796//===----------------------------------------------------------------------===//
1797// FPToSIOp
1798//===----------------------------------------------------------------------===//
1799
1800bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1801 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1802}
1803
1804OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1805 Type resType = getElementTypeOrSelf(getType());
1806 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1808 adaptor.getOperands(), getType(),
1809 [&bitWidth](const APFloat &a, bool &castStatus) {
1810 bool ignored;
1811 APSInt api(bitWidth, /*isUnsigned=*/false);
1812 castStatus = APFloat::opInvalidOp !=
1813 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1814 return api;
1815 });
1816}
1817
1818//===----------------------------------------------------------------------===//
1819// IndexCastOp
1820//===----------------------------------------------------------------------===//
1821
1822static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1823 if (!areValidCastInputsAndOutputs(inputs, outputs))
1824 return false;
1825
1826 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1827 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1828 if (!srcType || !dstType)
1829 return false;
1830
1831 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1832 (srcType.isSignlessInteger() && dstType.isIndex());
1833}
1834
1835bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1836 TypeRange outputs) {
1837 return areIndexCastCompatible(inputs, outputs);
1838}
1839
1840OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1841 // index_cast(constant) -> constant
1842 unsigned resultBitwidth = 64; // Default for index integer attributes.
1843 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1844 resultBitwidth = intTy.getWidth();
1845
1847 adaptor.getOperands(), getType(),
1848 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1849 return a.sextOrTrunc(resultBitwidth);
1850 });
1851}
1852
1853void arith::IndexCastOp::getCanonicalizationPatterns(
1854 RewritePatternSet &patterns, MLIRContext *context) {
1855 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1856}
1857
1858//===----------------------------------------------------------------------===//
1859// IndexCastUIOp
1860//===----------------------------------------------------------------------===//
1861
1862bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1863 TypeRange outputs) {
1864 return areIndexCastCompatible(inputs, outputs);
1865}
1866
1867OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1868 // index_castui(constant) -> constant
1869 unsigned resultBitwidth = 64; // Default for index integer attributes.
1870 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1871 resultBitwidth = intTy.getWidth();
1872
1874 adaptor.getOperands(), getType(),
1875 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1876 return a.zextOrTrunc(resultBitwidth);
1877 });
1878}
1879
1880void arith::IndexCastUIOp::getCanonicalizationPatterns(
1881 RewritePatternSet &patterns, MLIRContext *context) {
1882 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1883}
1884
1885//===----------------------------------------------------------------------===//
1886// BitcastOp
1887//===----------------------------------------------------------------------===//
1888
1889bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1890 if (!areValidCastInputsAndOutputs(inputs, outputs))
1891 return false;
1892
1893 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1894 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1895 if (!srcType || !dstType)
1896 return false;
1897
1898 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1899}
1900
1901OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1902 auto resType = getType();
1903 auto operand = adaptor.getIn();
1904 if (!operand)
1905 return {};
1906
1907 /// Bitcast dense elements.
1908 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1909 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1910 /// Other shaped types unhandled.
1911 if (llvm::isa<ShapedType>(resType))
1912 return {};
1913
1914 /// Bitcast poison.
1915 if (llvm::isa<ub::PoisonAttr>(operand))
1916 return ub::PoisonAttr::get(getContext());
1917
1918 /// Bitcast integer or float to integer or float.
1919 APInt bits = llvm::isa<FloatAttr>(operand)
1920 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1921 : llvm::cast<IntegerAttr>(operand).getValue();
1922 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1923 "trying to fold on broken IR: operands have incompatible types");
1924
1925 if (auto resFloatType = dyn_cast<FloatType>(resType))
1926 return FloatAttr::get(resType,
1927 APFloat(resFloatType.getFloatSemantics(), bits));
1928 return IntegerAttr::get(resType, bits);
1929}
1930
1931void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1932 MLIRContext *context) {
1933 patterns.add<BitcastOfBitcast>(context);
1934}
1935
1936//===----------------------------------------------------------------------===//
1937// CmpIOp
1938//===----------------------------------------------------------------------===//
1939
1940/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1941/// comparison predicates.
1942bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1943 const APInt &lhs, const APInt &rhs) {
1944 switch (predicate) {
1945 case arith::CmpIPredicate::eq:
1946 return lhs.eq(rhs);
1947 case arith::CmpIPredicate::ne:
1948 return lhs.ne(rhs);
1949 case arith::CmpIPredicate::slt:
1950 return lhs.slt(rhs);
1951 case arith::CmpIPredicate::sle:
1952 return lhs.sle(rhs);
1953 case arith::CmpIPredicate::sgt:
1954 return lhs.sgt(rhs);
1955 case arith::CmpIPredicate::sge:
1956 return lhs.sge(rhs);
1957 case arith::CmpIPredicate::ult:
1958 return lhs.ult(rhs);
1959 case arith::CmpIPredicate::ule:
1960 return lhs.ule(rhs);
1961 case arith::CmpIPredicate::ugt:
1962 return lhs.ugt(rhs);
1963 case arith::CmpIPredicate::uge:
1964 return lhs.uge(rhs);
1965 }
1966 llvm_unreachable("unknown cmpi predicate kind");
1967}
1968
1969/// Returns true if the predicate is true for two equal operands.
1970static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1971 switch (predicate) {
1972 case arith::CmpIPredicate::eq:
1973 case arith::CmpIPredicate::sle:
1974 case arith::CmpIPredicate::sge:
1975 case arith::CmpIPredicate::ule:
1976 case arith::CmpIPredicate::uge:
1977 return true;
1978 case arith::CmpIPredicate::ne:
1979 case arith::CmpIPredicate::slt:
1980 case arith::CmpIPredicate::sgt:
1981 case arith::CmpIPredicate::ult:
1982 case arith::CmpIPredicate::ugt:
1983 return false;
1984 }
1985 llvm_unreachable("unknown cmpi predicate kind");
1986}
1987
1988static std::optional<int64_t> getIntegerWidth(Type t) {
1989 if (auto intType = dyn_cast<IntegerType>(t)) {
1990 return intType.getWidth();
1991 }
1992 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
1993 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1994 }
1995 return std::nullopt;
1996}
1997
1998OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1999 // cmpi(pred, x, x)
2000 if (getLhs() == getRhs()) {
2001 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2002 return getBoolAttribute(getType(), val);
2003 }
2004
2005 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2006 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2007 // extsi(%x : i1 -> iN) != 0 -> %x
2008 std::optional<int64_t> integerWidth =
2009 getIntegerWidth(extOp.getOperand().getType());
2010 if (integerWidth && integerWidth.value() == 1 &&
2011 getPredicate() == arith::CmpIPredicate::ne)
2012 return extOp.getOperand();
2013 }
2014 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2015 // extui(%x : i1 -> iN) != 0 -> %x
2016 std::optional<int64_t> integerWidth =
2017 getIntegerWidth(extOp.getOperand().getType());
2018 if (integerWidth && integerWidth.value() == 1 &&
2019 getPredicate() == arith::CmpIPredicate::ne)
2020 return extOp.getOperand();
2021 }
2022
2023 // arith.cmpi ne, %val, %zero : i1 -> %val
2024 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2025 getPredicate() == arith::CmpIPredicate::ne)
2026 return getLhs();
2027 }
2028
2029 if (matchPattern(adaptor.getRhs(), m_One())) {
2030 // arith.cmpi eq, %val, %one : i1 -> %val
2031 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2032 getPredicate() == arith::CmpIPredicate::eq)
2033 return getLhs();
2034 }
2035
2036 // Move constant to the right side.
2037 if (adaptor.getLhs() && !adaptor.getRhs()) {
2038 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2039 using Pred = CmpIPredicate;
2040 const std::pair<Pred, Pred> invPreds[] = {
2041 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2042 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2043 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2044 {Pred::ne, Pred::ne},
2045 };
2046 Pred origPred = getPredicate();
2047 for (auto pred : invPreds) {
2048 if (origPred == pred.first) {
2049 setPredicate(pred.second);
2050 Value lhs = getLhs();
2051 Value rhs = getRhs();
2052 getLhsMutable().assign(rhs);
2053 getRhsMutable().assign(lhs);
2054 return getResult();
2055 }
2056 }
2057 llvm_unreachable("unknown cmpi predicate kind");
2058 }
2059
2060 // We are moving constants to the right side; So if lhs is constant rhs is
2061 // guaranteed to be a constant.
2062 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2064 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2065 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2066 return APInt(1,
2067 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2068 });
2069 }
2070
2071 return {};
2072}
2073
2074void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2075 MLIRContext *context) {
2076 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2077}
2078
2079//===----------------------------------------------------------------------===//
2080// CmpFOp
2081//===----------------------------------------------------------------------===//
2082
2083/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2084/// comparison predicates.
2085bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2086 const APFloat &lhs, const APFloat &rhs) {
2087 auto cmpResult = lhs.compare(rhs);
2088 switch (predicate) {
2089 case arith::CmpFPredicate::AlwaysFalse:
2090 return false;
2091 case arith::CmpFPredicate::OEQ:
2092 return cmpResult == APFloat::cmpEqual;
2093 case arith::CmpFPredicate::OGT:
2094 return cmpResult == APFloat::cmpGreaterThan;
2095 case arith::CmpFPredicate::OGE:
2096 return cmpResult == APFloat::cmpGreaterThan ||
2097 cmpResult == APFloat::cmpEqual;
2098 case arith::CmpFPredicate::OLT:
2099 return cmpResult == APFloat::cmpLessThan;
2100 case arith::CmpFPredicate::OLE:
2101 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2102 case arith::CmpFPredicate::ONE:
2103 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2104 case arith::CmpFPredicate::ORD:
2105 return cmpResult != APFloat::cmpUnordered;
2106 case arith::CmpFPredicate::UEQ:
2107 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2108 case arith::CmpFPredicate::UGT:
2109 return cmpResult == APFloat::cmpUnordered ||
2110 cmpResult == APFloat::cmpGreaterThan;
2111 case arith::CmpFPredicate::UGE:
2112 return cmpResult == APFloat::cmpUnordered ||
2113 cmpResult == APFloat::cmpGreaterThan ||
2114 cmpResult == APFloat::cmpEqual;
2115 case arith::CmpFPredicate::ULT:
2116 return cmpResult == APFloat::cmpUnordered ||
2117 cmpResult == APFloat::cmpLessThan;
2118 case arith::CmpFPredicate::ULE:
2119 return cmpResult == APFloat::cmpUnordered ||
2120 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2121 case arith::CmpFPredicate::UNE:
2122 return cmpResult != APFloat::cmpEqual;
2123 case arith::CmpFPredicate::UNO:
2124 return cmpResult == APFloat::cmpUnordered;
2125 case arith::CmpFPredicate::AlwaysTrue:
2126 return true;
2127 }
2128 llvm_unreachable("unknown cmpf predicate kind");
2129}
2130
2131OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2132 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2133 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2134
2135 // If one operand is NaN, making them both NaN does not change the result.
2136 if (lhs && lhs.getValue().isNaN())
2137 rhs = lhs;
2138 if (rhs && rhs.getValue().isNaN())
2139 lhs = rhs;
2140
2141 if (!lhs || !rhs)
2142 return {};
2143
2144 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2145 return BoolAttr::get(getContext(), val);
2146}
2147
2148class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2149public:
2150 using Base::Base;
2151
2152 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2153 bool isUnsigned) {
2154 using namespace arith;
2155 switch (pred) {
2156 case CmpFPredicate::UEQ:
2157 case CmpFPredicate::OEQ:
2158 return CmpIPredicate::eq;
2159 case CmpFPredicate::UGT:
2160 case CmpFPredicate::OGT:
2161 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2162 case CmpFPredicate::UGE:
2163 case CmpFPredicate::OGE:
2164 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2165 case CmpFPredicate::ULT:
2166 case CmpFPredicate::OLT:
2167 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2168 case CmpFPredicate::ULE:
2169 case CmpFPredicate::OLE:
2170 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2171 case CmpFPredicate::UNE:
2172 case CmpFPredicate::ONE:
2173 return CmpIPredicate::ne;
2174 default:
2175 llvm_unreachable("Unexpected predicate!");
2176 }
2177 }
2178
2179 LogicalResult matchAndRewrite(CmpFOp op,
2180 PatternRewriter &rewriter) const override {
2181 FloatAttr flt;
2182 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2183 return failure();
2184
2185 const APFloat &rhs = flt.getValue();
2186
2187 // Don't attempt to fold a nan.
2188 if (rhs.isNaN())
2189 return failure();
2190
2191 // Get the width of the mantissa. We don't want to hack on conversions that
2192 // might lose information from the integer, e.g. "i64 -> float"
2193 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2194 int mantissaWidth = floatTy.getFPMantissaWidth();
2195 if (mantissaWidth <= 0)
2196 return failure();
2197
2198 bool isUnsigned;
2199 Value intVal;
2200
2201 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2202 isUnsigned = false;
2203 intVal = si.getIn();
2204 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2205 isUnsigned = true;
2206 intVal = ui.getIn();
2207 } else {
2208 return failure();
2209 }
2210
2211 // Check to see that the input is converted from an integer type that is
2212 // small enough that preserves all bits.
2213 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2214 auto intWidth = intTy.getWidth();
2215
2216 // Number of bits representing values, as opposed to the sign
2217 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2218
2219 // Following test does NOT adjust intWidth downwards for signed inputs,
2220 // because the most negative value still requires all the mantissa bits
2221 // to distinguish it from one less than that value.
2222 if ((int)intWidth > mantissaWidth) {
2223 // Conversion would lose accuracy. Check if loss can impact comparison.
2224 int exponent = ilogb(rhs);
2225 if (exponent == APFloat::IEK_Inf) {
2226 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2227 if (maxExponent < (int)valueBits) {
2228 // Conversion could create infinity.
2229 return failure();
2230 }
2231 } else {
2232 // Note that if rhs is zero or NaN, then Exp is negative
2233 // and first condition is trivially false.
2234 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2235 // Conversion could affect comparison.
2236 return failure();
2237 }
2238 }
2239 }
2240
2241 // Convert to equivalent cmpi predicate
2242 CmpIPredicate pred;
2243 switch (op.getPredicate()) {
2244 case CmpFPredicate::ORD:
2245 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2246 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2247 /*width=*/1);
2248 return success();
2249 case CmpFPredicate::UNO:
2250 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2251 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2252 /*width=*/1);
2253 return success();
2254 default:
2255 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2256 break;
2257 }
2258
2259 if (!isUnsigned) {
2260 // If the rhs value is > SignedMax, fold the comparison. This handles
2261 // +INF and large values.
2262 APFloat signedMax(rhs.getSemantics());
2263 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2264 APFloat::rmNearestTiesToEven);
2265 if (signedMax < rhs) { // smax < 13123.0
2266 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2267 pred == CmpIPredicate::sle)
2268 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2269 /*width=*/1);
2270 else
2271 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2272 /*width=*/1);
2273 return success();
2274 }
2275 } else {
2276 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2277 // +INF and large values.
2278 APFloat unsignedMax(rhs.getSemantics());
2279 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2280 APFloat::rmNearestTiesToEven);
2281 if (unsignedMax < rhs) { // umax < 13123.0
2282 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2283 pred == CmpIPredicate::ule)
2284 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2285 /*width=*/1);
2286 else
2287 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2288 /*width=*/1);
2289 return success();
2290 }
2291 }
2292
2293 if (!isUnsigned) {
2294 // See if the rhs value is < SignedMin.
2295 APFloat signedMin(rhs.getSemantics());
2296 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2297 APFloat::rmNearestTiesToEven);
2298 if (signedMin > rhs) { // smin > 12312.0
2299 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2300 pred == CmpIPredicate::sge)
2301 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2302 /*width=*/1);
2303 else
2304 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2305 /*width=*/1);
2306 return success();
2307 }
2308 } else {
2309 // See if the rhs value is < UnsignedMin.
2310 APFloat unsignedMin(rhs.getSemantics());
2311 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2312 APFloat::rmNearestTiesToEven);
2313 if (unsignedMin > rhs) { // umin > 12312.0
2314 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2315 pred == CmpIPredicate::uge)
2316 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2317 /*width=*/1);
2318 else
2319 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2320 /*width=*/1);
2321 return success();
2322 }
2323 }
2324
2325 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2326 // [0, UMAX], but it may still be fractional. See if it is fractional by
2327 // casting the FP value to the integer value and back, checking for
2328 // equality. Don't do this for zero, because -0.0 is not fractional.
2329 bool ignored;
2330 APSInt rhsInt(intWidth, isUnsigned);
2331 if (APFloat::opInvalidOp ==
2332 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2333 // Undefined behavior invoked - the destination type can't represent
2334 // the input constant.
2335 return failure();
2336 }
2337
2338 if (!rhs.isZero()) {
2339 APFloat apf(floatTy.getFloatSemantics(),
2340 APInt::getZero(floatTy.getWidth()));
2341 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2342
2343 bool equal = apf == rhs;
2344 if (!equal) {
2345 // If we had a comparison against a fractional value, we have to adjust
2346 // the compare predicate and sometimes the value. rhsInt is rounded
2347 // towards zero at this point.
2348 switch (pred) {
2349 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2350 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2351 /*width=*/1);
2352 return success();
2353 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2354 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2355 /*width=*/1);
2356 return success();
2357 case CmpIPredicate::ule:
2358 // (float)int <= 4.4 --> int <= 4
2359 // (float)int <= -4.4 --> false
2360 if (rhs.isNegative()) {
2361 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2362 /*width=*/1);
2363 return success();
2364 }
2365 break;
2366 case CmpIPredicate::sle:
2367 // (float)int <= 4.4 --> int <= 4
2368 // (float)int <= -4.4 --> int < -4
2369 if (rhs.isNegative())
2370 pred = CmpIPredicate::slt;
2371 break;
2372 case CmpIPredicate::ult:
2373 // (float)int < -4.4 --> false
2374 // (float)int < 4.4 --> int <= 4
2375 if (rhs.isNegative()) {
2376 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2377 /*width=*/1);
2378 return success();
2379 }
2380 pred = CmpIPredicate::ule;
2381 break;
2382 case CmpIPredicate::slt:
2383 // (float)int < -4.4 --> int < -4
2384 // (float)int < 4.4 --> int <= 4
2385 if (!rhs.isNegative())
2386 pred = CmpIPredicate::sle;
2387 break;
2388 case CmpIPredicate::ugt:
2389 // (float)int > 4.4 --> int > 4
2390 // (float)int > -4.4 --> true
2391 if (rhs.isNegative()) {
2392 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2393 /*width=*/1);
2394 return success();
2395 }
2396 break;
2397 case CmpIPredicate::sgt:
2398 // (float)int > 4.4 --> int > 4
2399 // (float)int > -4.4 --> int >= -4
2400 if (rhs.isNegative())
2401 pred = CmpIPredicate::sge;
2402 break;
2403 case CmpIPredicate::uge:
2404 // (float)int >= -4.4 --> true
2405 // (float)int >= 4.4 --> int > 4
2406 if (rhs.isNegative()) {
2407 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2408 /*width=*/1);
2409 return success();
2410 }
2411 pred = CmpIPredicate::ugt;
2412 break;
2413 case CmpIPredicate::sge:
2414 // (float)int >= -4.4 --> int >= -4
2415 // (float)int >= 4.4 --> int > 4
2416 if (!rhs.isNegative())
2417 pred = CmpIPredicate::sgt;
2418 break;
2419 }
2420 }
2421 }
2422
2423 // Lower this FP comparison into an appropriate integer version of the
2424 // comparison.
2425 rewriter.replaceOpWithNewOp<CmpIOp>(
2426 op, pred, intVal,
2427 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2428 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2429 return success();
2430 }
2431};
2432
2433void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2434 MLIRContext *context) {
2435 patterns.insert<CmpFIntToFPConst>(context);
2436}
2437
2438//===----------------------------------------------------------------------===//
2439// SelectOp
2440//===----------------------------------------------------------------------===//
2441
2442// select %arg, %c1, %c0 => extui %arg
2443struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2444 using Base::Base;
2445
2446 LogicalResult matchAndRewrite(arith::SelectOp op,
2447 PatternRewriter &rewriter) const override {
2448 // Cannot extui i1 to i1, or i1 to f32
2449 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2450 return failure();
2451
2452 // select %x, c1, %c0 => extui %arg
2453 if (matchPattern(op.getTrueValue(), m_One()) &&
2454 matchPattern(op.getFalseValue(), m_Zero())) {
2455 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2456 op.getCondition());
2457 return success();
2458 }
2459
2460 // select %x, c0, %c1 => extui (xor %arg, true)
2461 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2462 matchPattern(op.getFalseValue(), m_One())) {
2463 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2464 op, op.getType(),
2465 arith::XOrIOp::create(
2466 rewriter, op.getLoc(), op.getCondition(),
2467 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2468 op.getCondition().getType(), 1)));
2469 return success();
2470 }
2471
2472 return failure();
2473 }
2474};
2475
2476void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2477 MLIRContext *context) {
2478 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2479 SelectI1ToNot, SelectToExtUI>(context);
2480}
2481
2482OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2483 Value trueVal = getTrueValue();
2484 Value falseVal = getFalseValue();
2485 if (trueVal == falseVal)
2486 return trueVal;
2487
2488 Value condition = getCondition();
2489
2490 // select true, %0, %1 => %0
2491 if (matchPattern(adaptor.getCondition(), m_One()))
2492 return trueVal;
2493
2494 // select false, %0, %1 => %1
2495 if (matchPattern(adaptor.getCondition(), m_Zero()))
2496 return falseVal;
2497
2498 // If either operand is fully poisoned, return the other.
2499 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2500 return falseVal;
2501
2502 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2503 return trueVal;
2504
2505 // select %x, true, false => %x
2506 if (getType().isSignlessInteger(1) &&
2507 matchPattern(adaptor.getTrueValue(), m_One()) &&
2508 matchPattern(adaptor.getFalseValue(), m_Zero()))
2509 return condition;
2510
2511 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2512 auto pred = cmp.getPredicate();
2513 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2514 auto cmpLhs = cmp.getLhs();
2515 auto cmpRhs = cmp.getRhs();
2516
2517 // %0 = arith.cmpi eq, %arg0, %arg1
2518 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2519
2520 // %0 = arith.cmpi ne, %arg0, %arg1
2521 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2522
2523 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2524 (cmpRhs == trueVal && cmpLhs == falseVal))
2525 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2526 }
2527 }
2528
2529 // Constant-fold constant operands over non-splat constant condition.
2530 // select %cst_vec, %cst0, %cst1 => %cst2
2531 if (auto cond =
2532 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2533 // DenseElementsAttr by construction always has a static shape.
2534 assert(cond.getType().hasStaticShape() &&
2535 "DenseElementsAttr must have static shape");
2536 if (auto lhs =
2537 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2538 if (auto rhs =
2539 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2540 SmallVector<Attribute> results;
2541 results.reserve(static_cast<size_t>(cond.getNumElements()));
2542 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2543 cond.value_end<BoolAttr>());
2544 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2545 lhs.value_end<Attribute>());
2546 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2547 rhs.value_end<Attribute>());
2548
2549 for (auto [condVal, lhsVal, rhsVal] :
2550 llvm::zip_equal(condVals, lhsVals, rhsVals))
2551 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2552
2553 return DenseElementsAttr::get(lhs.getType(), results);
2554 }
2555 }
2556 }
2557
2558 return nullptr;
2559}
2560
2561ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2562 Type conditionType, resultType;
2563 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2564 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2565 parser.parseOptionalAttrDict(result.attributes) ||
2566 parser.parseColonType(resultType))
2567 return failure();
2568
2569 // Check for the explicit condition type if this is a masked tensor or vector.
2570 if (succeeded(parser.parseOptionalComma())) {
2571 conditionType = resultType;
2572 if (parser.parseType(resultType))
2573 return failure();
2574 } else {
2575 conditionType = parser.getBuilder().getI1Type();
2576 }
2577
2578 result.addTypes(resultType);
2579 return parser.resolveOperands(operands,
2580 {conditionType, resultType, resultType},
2581 parser.getNameLoc(), result.operands);
2582}
2583
2584void arith::SelectOp::print(OpAsmPrinter &p) {
2585 p << " " << getOperands();
2586 p.printOptionalAttrDict((*this)->getAttrs());
2587 p << " : ";
2588 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2589 p << condType << ", ";
2590 p << getType();
2591}
2592
2593LogicalResult arith::SelectOp::verify() {
2594 Type conditionType = getCondition().getType();
2595 if (conditionType.isSignlessInteger(1))
2596 return success();
2597
2598 // If the result type is a vector or tensor, the type can be a mask with the
2599 // same elements.
2600 Type resultType = getType();
2601 if (!llvm::isa<TensorType, VectorType>(resultType))
2602 return emitOpError() << "expected condition to be a signless i1, but got "
2603 << conditionType;
2604 Type shapedConditionType = getI1SameShape(resultType);
2605 if (conditionType != shapedConditionType) {
2606 return emitOpError() << "expected condition type to have the same shape "
2607 "as the result type, expected "
2608 << shapedConditionType << ", but got "
2609 << conditionType;
2610 }
2611 return success();
2612}
2613//===----------------------------------------------------------------------===//
2614// ShLIOp
2615//===----------------------------------------------------------------------===//
2616
2617OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2618 // shli(x, 0) -> x
2619 if (matchPattern(adaptor.getRhs(), m_Zero()))
2620 return getLhs();
2621 // Don't fold if shifting more or equal than the bit width.
2622 bool bounded = false;
2624 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2625 bounded = b.ult(b.getBitWidth());
2626 return a.shl(b);
2627 });
2628 return bounded ? result : Attribute();
2629}
2630
2631//===----------------------------------------------------------------------===//
2632// ShRUIOp
2633//===----------------------------------------------------------------------===//
2634
2635OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2636 // shrui(x, 0) -> x
2637 if (matchPattern(adaptor.getRhs(), m_Zero()))
2638 return getLhs();
2639 // Don't fold if shifting more or equal than the bit width.
2640 bool bounded = false;
2642 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2643 bounded = b.ult(b.getBitWidth());
2644 return a.lshr(b);
2645 });
2646 return bounded ? result : Attribute();
2647}
2648
2649//===----------------------------------------------------------------------===//
2650// ShRSIOp
2651//===----------------------------------------------------------------------===//
2652
2653OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2654 // shrsi(x, 0) -> x
2655 if (matchPattern(adaptor.getRhs(), m_Zero()))
2656 return getLhs();
2657 // Don't fold if shifting more or equal than the bit width.
2658 bool bounded = false;
2660 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2661 bounded = b.ult(b.getBitWidth());
2662 return a.ashr(b);
2663 });
2664 return bounded ? result : Attribute();
2665}
2666
2667//===----------------------------------------------------------------------===//
2668// Atomic Enum
2669//===----------------------------------------------------------------------===//
2670
2671/// Returns the identity value attribute associated with an AtomicRMWKind op.
2672TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2673 OpBuilder &builder, Location loc,
2674 bool useOnlyFiniteValue) {
2675 switch (kind) {
2676 case AtomicRMWKind::maximumf: {
2677 const llvm::fltSemantics &semantic =
2678 llvm::cast<FloatType>(resultType).getFloatSemantics();
2679 APFloat identity = useOnlyFiniteValue
2680 ? APFloat::getLargest(semantic, /*Negative=*/true)
2681 : APFloat::getInf(semantic, /*Negative=*/true);
2682 return builder.getFloatAttr(resultType, identity);
2683 }
2684 case AtomicRMWKind::maxnumf: {
2685 const llvm::fltSemantics &semantic =
2686 llvm::cast<FloatType>(resultType).getFloatSemantics();
2687 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2688 return builder.getFloatAttr(resultType, identity);
2689 }
2690 case AtomicRMWKind::addf:
2691 case AtomicRMWKind::addi:
2692 case AtomicRMWKind::maxu:
2693 case AtomicRMWKind::ori:
2694 case AtomicRMWKind::xori:
2695 return builder.getZeroAttr(resultType);
2696 case AtomicRMWKind::andi:
2697 return builder.getIntegerAttr(
2698 resultType,
2699 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2700 case AtomicRMWKind::maxs:
2701 return builder.getIntegerAttr(
2702 resultType, APInt::getSignedMinValue(
2703 llvm::cast<IntegerType>(resultType).getWidth()));
2704 case AtomicRMWKind::minimumf: {
2705 const llvm::fltSemantics &semantic =
2706 llvm::cast<FloatType>(resultType).getFloatSemantics();
2707 APFloat identity = useOnlyFiniteValue
2708 ? APFloat::getLargest(semantic, /*Negative=*/false)
2709 : APFloat::getInf(semantic, /*Negative=*/false);
2710
2711 return builder.getFloatAttr(resultType, identity);
2712 }
2713 case AtomicRMWKind::minnumf: {
2714 const llvm::fltSemantics &semantic =
2715 llvm::cast<FloatType>(resultType).getFloatSemantics();
2716 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2717 return builder.getFloatAttr(resultType, identity);
2718 }
2719 case AtomicRMWKind::mins:
2720 return builder.getIntegerAttr(
2721 resultType, APInt::getSignedMaxValue(
2722 llvm::cast<IntegerType>(resultType).getWidth()));
2723 case AtomicRMWKind::minu:
2724 return builder.getIntegerAttr(
2725 resultType,
2726 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2727 case AtomicRMWKind::muli:
2728 return builder.getIntegerAttr(resultType, 1);
2729 case AtomicRMWKind::mulf:
2730 return builder.getFloatAttr(resultType, 1);
2731 // TODO: Add remaining reduction operations.
2732 default:
2733 (void)emitOptionalError(loc, "Reduction operation type not supported");
2734 break;
2735 }
2736 return nullptr;
2737}
2738
2739/// Returns the identity numeric value of the given op.
2740std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2741 std::optional<AtomicRMWKind> maybeKind =
2743 // Floating-point operations.
2744 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2745 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2746 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2747 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2748 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2749 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2750 // Integer operations.
2751 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2752 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2753 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2754 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2755 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2756 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2757 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2758 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2759 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2760 .Default(std::nullopt);
2761 if (!maybeKind) {
2762 return std::nullopt;
2763 }
2764
2765 bool useOnlyFiniteValue = false;
2766 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2767 if (fmfOpInterface) {
2768 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2769 useOnlyFiniteValue =
2770 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2771 }
2772
2773 // Builder only used as helper for attribute creation.
2774 OpBuilder b(op->getContext());
2775 Type resultType = op->getResult(0).getType();
2776
2777 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2778 useOnlyFiniteValue);
2779}
2780
2781/// Returns the identity value associated with an AtomicRMWKind op.
2782Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2783 OpBuilder &builder, Location loc,
2784 bool useOnlyFiniteValue) {
2785 auto attr =
2786 getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2787 return arith::ConstantOp::create(builder, loc, attr);
2788}
2789
2790/// Return the value obtained by applying the reduction operation kind
2791/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2793 Location loc, Value lhs, Value rhs) {
2794 switch (op) {
2795 case AtomicRMWKind::addf:
2796 return arith::AddFOp::create(builder, loc, lhs, rhs);
2797 case AtomicRMWKind::addi:
2798 return arith::AddIOp::create(builder, loc, lhs, rhs);
2799 case AtomicRMWKind::mulf:
2800 return arith::MulFOp::create(builder, loc, lhs, rhs);
2801 case AtomicRMWKind::muli:
2802 return arith::MulIOp::create(builder, loc, lhs, rhs);
2803 case AtomicRMWKind::maximumf:
2804 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2805 case AtomicRMWKind::minimumf:
2806 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2807 case AtomicRMWKind::maxnumf:
2808 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2809 case AtomicRMWKind::minnumf:
2810 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2811 case AtomicRMWKind::maxs:
2812 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2813 case AtomicRMWKind::mins:
2814 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2815 case AtomicRMWKind::maxu:
2816 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2817 case AtomicRMWKind::minu:
2818 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2819 case AtomicRMWKind::ori:
2820 return arith::OrIOp::create(builder, loc, lhs, rhs);
2821 case AtomicRMWKind::andi:
2822 return arith::AndIOp::create(builder, loc, lhs, rhs);
2823 case AtomicRMWKind::xori:
2824 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2825 // TODO: Add remaining reduction operations.
2826 default:
2827 (void)emitOptionalError(loc, "Reduction operation type not supported");
2828 break;
2829 }
2830 return nullptr;
2831}
2832
2833//===----------------------------------------------------------------------===//
2834// TableGen'd op method definitions
2835//===----------------------------------------------------------------------===//
2836
2837#define GET_OP_CLASSES
2838#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2839
2840//===----------------------------------------------------------------------===//
2841// TableGen'd enum attribute definitions
2842//===----------------------------------------------------------------------===//
2843
2844#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:714
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:675
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensors with mismatching encodings.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:774
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:756
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:149
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:129
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:955
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:173
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:436
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:141
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
IntegerType getI1Type()
Definition Builders.cpp:53
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (of any width).
Definition Types.cpp:64
bool isIndex() const
Definition Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:333
static bool classof(Operation *op)
Definition ArithOps.cpp:350
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:327
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:356
static bool classof(Operation *op)
Definition ArithOps.cpp:377
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:362
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:261
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:254
static bool classof(Operation *op)
Definition ArithOps.cpp:321
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:75
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:383
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.