MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37//===----------------------------------------------------------------------===//
38// Pattern helpers
39//===----------------------------------------------------------------------===//
40
// Constant-folds a binary integer operation over two integer attributes.
// Unwraps the APInt payloads of `lhs` and `rhs` (llvm::cast asserts that both
// are IntegerAttr), applies `binFn`, and wraps the result in an IntegerAttr
// whose type is the type of the value `res`.
// NOTE(review): the parameter list is only partially visible here; callers
// pass (builder, res, lhs, rhs, binFn), and `builder` is not used in the body.
41static IntegerAttr
44 function_ref<APInt(const APInt &, const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.getType(), value);
49}
50
// Thin wrappers over applyToIntegerAttrs that combine two IntegerAttrs with
// APInt addition, subtraction and multiplication respectively; the result
// attribute is typed like `res`. These take a PatternRewriter, so they are
// presumably called from the generated canonicalization patterns — confirm.
51static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54}
55
56static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59}
60
61static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64}
65
66// Merge overflow flags from 2 ops, selecting the most conservative combination.
67static IntegerOverflowFlagsAttr
68mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
72}
73
74/// Invert an integer comparison predicate.
75arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76 switch (pred) {
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
97 }
98 llvm_unreachable("unknown cmpi predicate kind");
99}
100
101/// Equivalent to
102/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103///
104/// Not possible to implement as chain of calls as this would introduce a
105/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106/// on the LLVM dialect and on translation to LLVM.
107static llvm::RoundingMode
108convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
120 }
121 llvm_unreachable("Unhandled rounding mode");
122}
123
124static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
126 invertPredicate(pred.getValue()));
127}
128
// Type overload: returns the bit width of `type` itself, or of its element
// type as returned by getElementTypeOrSelf, when that is an integer or float
// type; returns -1 otherwise.
// NOTE(review): the signature line of this overload is not visible here.
130 Type elemTy = getElementTypeOrSelf(type);
131 if (elemTy.isIntOrFloat())
132 return elemTy.getIntOrFloatBitWidth();
133
134 return -1;
135}
136
// Value overload: forwards to the Type overload using the value's type.
138 return getScalarOrElementWidth(value.getType());
139}
140
141static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142 APInt value;
143 if (matchPattern(attr, m_ConstantInt(&value)))
144 return value;
145
146 return failure();
147}
148
149static Attribute getBoolAttribute(Type type, bool value) {
150 auto boolAttr = BoolAttr::get(type.getContext(), value);
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152 if (!shapedType)
153 return boolAttr;
154 // DenseElementsAttr requires a static shape.
155 if (!shapedType.hasStaticShape())
156 return {};
157 return DenseElementsAttr::get(shapedType, boolAttr);
158}
159
160//===----------------------------------------------------------------------===//
161// TableGen'd canonicalization patterns
162//===----------------------------------------------------------------------===//
163
// Pull in the TableGen-generated canonicalization patterns. Wrapping the
// include in an anonymous namespace gives the generated pattern classes
// internal linkage in this translation unit.
164namespace {
165#include "ArithCanonicalization.inc"
166} // namespace
167
168//===----------------------------------------------------------------------===//
169// Common helpers
170//===----------------------------------------------------------------------===//
171
172/// Return the type of the same shape (scalar, vector or tensor) containing i1.
// Shaped types are cloned with element type i1 and their shape kept; unranked
// tensors map to an unranked tensor of i1; any other type becomes scalar i1.
// NOTE(review): UnrankedTensorType is presumably itself a ShapedType, in which
// case the dyn_cast<ShapedType> branch already handles it and the
// UnrankedTensorType branch is unreachable — confirm. The signature line is
// not visible here.
174 auto i1Type = IntegerType::get(type.getContext(), 1);
175 if (auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
179 return i1Type;
180}
181
182//===----------------------------------------------------------------------===//
183// ConstantOp
184//===----------------------------------------------------------------------===//
185
186void arith::ConstantOp::getAsmResultNames(
187 function_ref<void(Value, StringRef)> setNameFn) {
188 auto type = getType();
189 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
191
192 // Sugar i1 constants with 'true' and 'false'.
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
195
196 // Otherwise, build a complex name with the value and type.
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName << 'c' << intCst.getValue();
200 if (intType)
201 specialName << '_' << type;
202 setNameFn(getResult(), specialName.str());
203 } else {
204 setNameFn(getResult(), "cst");
205 }
206}
207
208/// TODO: disallow arith.constant to return anything other than signless integer
209/// or float like.
210LogicalResult arith::ConstantOp::verify() {
211 auto type = getType();
212 // Integer values must be signless.
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError("integer return type must be signless");
216 // Any float or elements attribute are acceptable.
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
218 return emitOpError(
219 "value must be an integer, float, or elements attribute");
220 }
221
222 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
223 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
224 // However, this would most likely require updating the lowerings to LLVM.
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
226 return emitOpError(
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
229 return success();
230}
231
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
233 // The value's type must be the same as the provided type.
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
236 return false;
237 // Integer values must be signless.
238 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
239 if (!intType.isSignless())
240 return false;
241 }
242 // Integer, float, and element attributes are buildable.
243 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
244}
245
246ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
247 Type type, Location loc) {
248 if (isBuildableWith(value, type))
249 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
250 return nullptr;
251}
252
253OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
254
// ConstantIntOp build()/create() overloads and the op-kind predicate.
// NOTE(review): the leading signature line(s) of each definition below are not
// visible here; descriptions follow the visible parameters and bodies.
//
// build(value, width): builds an arith.constant whose type is
// builder.getIntegerType(width) and whose value is the IntegerAttr of `value`.
256 int64_t value, unsigned width) {
257 auto type = builder.getIntegerType(width);
258 arith::ConstantOp::build(builder, result, type,
259 builder.getIntegerAttr(type, value));
260}
261
// create(location, value, width): runs build() into a fresh OperationState,
// creates the op with builder.create, and asserts it is a ConstantIntOp.
263 Location location,
265 unsigned width) {
266 mlir::OperationState state(location, getOperationName());
267 build(builder, state, value, width);
268 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
269 assert(result && "builder didn't return the right type");
270 return result;
271}
272
// Same as above, using builder.getLoc() as the location.
275 unsigned width) {
276 return create(builder, builder.getLoc(), value, width);
277}
278
// build(type, int64_t value): IntegerAttr of `value` with the given `type`.
280 Type type, int64_t value) {
281 arith::ConstantOp::build(builder, result, type,
282 builder.getIntegerAttr(type, value));
283}
284
// create(location, type, int64_t value): build + create + type assertion.
286 Location location, Type type,
287 int64_t value) {
288 mlir::OperationState state(location, getOperationName());
289 build(builder, state, type, value);
290 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
291 assert(result && "builder didn't return the right type");
292 return result;
293}
294
// Same as above, using builder.getLoc() as the location.
296 Type type, int64_t value) {
297 return create(builder, builder.getLoc(), type, value);
298}
299
// build(type, APInt value): IntegerAttr of `value` with the given `type`.
301 Type type, const APInt &value) {
302 arith::ConstantOp::build(builder, result, type,
303 builder.getIntegerAttr(type, value));
304}
305
// create(location, type, APInt value): build + create + type assertion.
307 Location location, Type type,
308 const APInt &value) {
309 mlir::OperationState state(location, getOperationName());
310 build(builder, state, type, value);
311 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
312 assert(result && "builder didn't return the right type");
313 return result;
314}
315
// Same as above, using builder.getLoc() as the location.
317 Type type,
318 const APInt &value) {
319 return create(builder, builder.getLoc(), type, value);
320}
321
// Returns true iff `op` is a non-null arith.constant whose result type is a
// signless integer (presumably ConstantIntOp::classof — confirm).
323 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
324 return constOp.getType().isSignlessInteger();
325 return false;
326}
327
// ConstantFloatOp build()/create() overloads and the op-kind predicate.
// NOTE(review): the leading signature line(s) are not visible here.
//
// build(type, value): builds an arith.constant holding the FloatAttr of
// `value` with float type `type`.
329 FloatType type, const APFloat &value) {
330 arith::ConstantOp::build(builder, result, type,
331 builder.getFloatAttr(type, value));
332}
333
// create(location, type, value): build + builder.create, then asserts that
// the created op is a ConstantFloatOp.
335 Location location,
336 FloatType type,
337 const APFloat &value) {
338 mlir::OperationState state(location, getOperationName());
339 build(builder, state, type, value);
340 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
341 assert(result && "builder didn't return the right type");
342 return result;
343}
344
// Same as above, using builder.getLoc() as the location.
347 const APFloat &value) {
348 return create(builder, builder.getLoc(), type, value);
349}
350
// Returns true iff `op` is a non-null arith.constant whose result type is a
// FloatType (presumably ConstantFloatOp::classof — confirm).
352 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
353 return llvm::isa<FloatType>(constOp.getType());
354 return false;
355}
356
// ConstantIndexOp build()/create() overloads and the op-kind predicate.
// NOTE(review): the leading signature line(s) are not visible here.
//
// build(value): builds an arith.constant of index type holding
// builder.getIndexAttr(value).
358 int64_t value) {
359 arith::ConstantOp::build(builder, result, builder.getIndexType(),
360 builder.getIndexAttr(value));
361}
362
// create(location, value): build + builder.create, then asserts that the
// created op is a ConstantIndexOp.
364 Location location,
365 int64_t value) {
366 mlir::OperationState state(location, getOperationName());
367 build(builder, state, value);
368 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
369 assert(result && "builder didn't return the right type");
370 return result;
371}
372
// Same as above, using builder.getLoc() as the location.
375 return create(builder, builder.getLoc(), value);
376}
377
// Returns true iff `op` is a non-null arith.constant of index type
// (presumably ConstantIndexOp::classof — confirm).
379 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
380 return constOp.getType().isIndex();
381 return false;
382}
383
// Creates an arith.constant at `loc` holding builder.getZeroAttr(type).
// Asserts that the element type is not Float8E8M0FNU (which has no zero
// representation) and that getZeroAttr supports `type`.
// NOTE(review): the start of the signature is not visible; the body uses
// `builder`, `loc` and `type`.
385 Type type) {
386 // TODO: Incorporate this check to `FloatAttr::get*`.
387 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
388 "type doesn't have a zero representation");
389 TypedAttr zeroAttr = builder.getZeroAttr(type);
390 assert(zeroAttr && "unsupported type for zero attribute");
391 return arith::ConstantOp::create(builder, loc, zeroAttr);
392}
393
394//===----------------------------------------------------------------------===//
395// AddIOp
396//===----------------------------------------------------------------------===//
397
// Folds arith.addi:
//   addi(x, 0)          -> x          (only a constant RHS is checked)
//   addi(subi(a, b), b) -> a
//   addi(b, subi(a, b)) -> a
//   addi(cst, cst)      -> elementwise APInt sum
// None of these rewrites consults the op's overflow flags.
398OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
399 // addi(x, 0) -> x
400 if (matchPattern(adaptor.getRhs(), m_Zero()))
401 return getLhs();
402
403 // addi(subi(a, b), b) -> a
404 if (auto sub = getLhs().getDefiningOp<SubIOp>())
405 if (getRhs() == sub.getRhs())
406 return sub.getLhs();
407
408 // addi(b, subi(a, b)) -> a
409 if (auto sub = getRhs().getDefiningOp<SubIOp>())
410 if (getLhs() == sub.getRhs())
411 return sub.getLhs();
412
// Constant case: fold both constant operands elementwise with APInt addition.
// NOTE(review): the call head of this fold is not visible here.
414 adaptor.getOperands(),
415 [](APInt a, const APInt &b) { return std::move(a) + b; });
416}
417
418void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
419 MLIRContext *context) {
420 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
421 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
422}
423
424//===----------------------------------------------------------------------===//
425// AddUIExtendedOp
426//===----------------------------------------------------------------------===//
427
428std::optional<SmallVector<int64_t, 4>>
429arith::AddUIExtendedOp::getShapeForUnroll() {
430 if (auto vt = dyn_cast<VectorType>(getType(0)))
431 return llvm::to_vector<4>(vt.getShape());
432 return std::nullopt;
433}
434
435// Returns the overflow bit, assuming that `sum` is the result of unsigned
436// addition of `operand` and another number.
437static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
438 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
439}
440
// Folds arith.addui_extended, which yields (sum, overflow):
//   addui_extended(x, 0)     -> x, 0
//   addui_extended(cst, cst) -> folded sum, folded overflow bit
// Returns failure() when neither case applies.
441LogicalResult
442arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
443 SmallVectorImpl<OpFoldResult> &results) {
444 Type overflowTy = getOverflow().getType();
445 // addui_extended(x, 0) -> x, false
446 if (matchPattern(getRhs(), m_Zero())) {
447 Builder builder(getContext());
448 auto falseValue = builder.getZeroAttr(overflowTy);
449
450 results.push_back(getLhs());
451 results.push_back(falseValue);
452 return success();
453 }
454
455 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
456 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
457 // operands. If that succeeds, calculate the overflow bit based on the sum
458 // and the first (constant) operand, `lhs`.
459 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
460 adaptor.getOperands(),
461 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
// The overflow result has the i1 counterpart of the sum's shape. The
// per-element callback is not visible here; it is presumably
// calculateUnsignedOverflow(sum, lhs) — confirm.
462 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
463 ArrayRef({sumAttr, adaptor.getLhs()}),
464 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
466 if (!overflowAttr)
467 return failure();
468
469 results.push_back(sumAttr);
470 results.push_back(overflowAttr);
471 return success();
472 }
473
474 return failure();
475}
476
477void arith::AddUIExtendedOp::getCanonicalizationPatterns(
478 RewritePatternSet &patterns, MLIRContext *context) {
479 patterns.add<AddUIExtendedToAddI>(context);
480}
481
482//===----------------------------------------------------------------------===//
483// SubIOp
484//===----------------------------------------------------------------------===//
485
// Folds arith.subi:
//   subi(x, x)          -> 0   (skipped for dynamically shaped types)
//   subi(x, 0)          -> x
//   subi(addi(a, b), b) -> a
//   subi(addi(a, b), a) -> b
//   subi(cst, cst)      -> elementwise APInt difference
486OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
487 // subi(x,x) -> 0
488 if (getOperand(0) == getOperand(1)) {
489 auto shapedType = dyn_cast<ShapedType>(getType());
490 // We can't generate a constant with a dynamic shaped tensor.
491 if (!shapedType || shapedType.hasStaticShape())
492 return Builder(getContext()).getZeroAttr(getType());
493 }
494 // subi(x,0) -> x
495 if (matchPattern(adaptor.getRhs(), m_Zero()))
496 return getLhs();
497
498 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
499 // subi(addi(a, b), b) -> a
500 if (getRhs() == add.getRhs())
501 return add.getLhs();
502 // subi(addi(a, b), a) -> b
503 if (getRhs() == add.getLhs())
504 return add.getRhs();
505 }
506
// Constant case: fold both constant operands elementwise with APInt
// subtraction. NOTE(review): the call head is not visible here.
508 adaptor.getOperands(),
509 [](APInt a, const APInt &b) { return std::move(a) - b; });
510}
511
512void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
513 MLIRContext *context) {
514 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
515 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
516 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
517}
518
519//===----------------------------------------------------------------------===//
520// MulIOp
521//===----------------------------------------------------------------------===//
522
// Folds arith.muli:
//   muli(x, 0)     -> 0   (returns the zero RHS value itself)
//   muli(x, 1)     -> x
//   muli(cst, cst) -> elementwise APInt product
523OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
524 // muli(x, 0) -> 0
525 if (matchPattern(adaptor.getRhs(), m_Zero()))
526 return getRhs();
527 // muli(x, 1) -> x
528 if (matchPattern(adaptor.getRhs(), m_One()))
529 return getLhs();
530 // TODO: Handle the overflow case.
531
532 // default folder
// NOTE(review): the call head of the constant fold is not visible here.
534 adaptor.getOperands(),
535 [](const APInt &a, const APInt &b) { return a * b; });
536}
537
538void arith::MulIOp::getAsmResultNames(
539 function_ref<void(Value, StringRef)> setNameFn) {
540 if (!isa<IndexType>(getType()))
541 return;
542
543 // Match vector.vscale by name to avoid depending on the vector dialect (which
544 // is a circular dependency).
545 auto isVscale = [](Operation *op) {
546 return op && op->getName().getStringRef() == "vector.vscale";
547 };
548
549 IntegerAttr baseValue;
550 auto isVscaleExpr = [&](Value a, Value b) {
551 return matchPattern(a, m_Constant(&baseValue)) &&
552 isVscale(b.getDefiningOp());
553 };
554
555 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
556 return;
557
558 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
559 SmallString<32> specialNameBuffer;
560 llvm::raw_svector_ostream specialName(specialNameBuffer);
561 specialName << 'c' << baseValue.getInt() << "_vscale";
562 setNameFn(getResult(), specialName.str());
563}
564
565void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
566 MLIRContext *context) {
567 patterns.add<MulIMulIConstant>(context);
568}
569
570//===----------------------------------------------------------------------===//
571// MulSIExtendedOp
572//===----------------------------------------------------------------------===//
573
574std::optional<SmallVector<int64_t, 4>>
575arith::MulSIExtendedOp::getShapeForUnroll() {
576 if (auto vt = dyn_cast<VectorType>(getType(0)))
577 return llvm::to_vector<4>(vt.getShape());
578 return std::nullopt;
579}
580
581LogicalResult
582arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
583 SmallVectorImpl<OpFoldResult> &results) {
584 // mulsi_extended(x, 0) -> 0, 0
585 if (matchPattern(adaptor.getRhs(), m_Zero())) {
586 Attribute zero = adaptor.getRhs();
587 results.push_back(zero);
588 results.push_back(zero);
589 return success();
590 }
591
592 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
593 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
594 adaptor.getOperands(),
595 [](const APInt &a, const APInt &b) { return a * b; })) {
596 // Invoke the constant fold helper again to calculate the 'high' result.
597 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
598 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
599 return llvm::APIntOps::mulhs(a, b);
600 });
601 assert(highAttr && "Unexpected constant-folding failure");
602
603 results.push_back(lowAttr);
604 results.push_back(highAttr);
605 return success();
606 }
607
608 return failure();
609}
610
611void arith::MulSIExtendedOp::getCanonicalizationPatterns(
612 RewritePatternSet &patterns, MLIRContext *context) {
613 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
614}
615
616//===----------------------------------------------------------------------===//
617// MulUIExtendedOp
618//===----------------------------------------------------------------------===//
619
620std::optional<SmallVector<int64_t, 4>>
621arith::MulUIExtendedOp::getShapeForUnroll() {
622 if (auto vt = dyn_cast<VectorType>(getType(0)))
623 return llvm::to_vector<4>(vt.getShape());
624 return std::nullopt;
625}
626
627LogicalResult
628arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
629 SmallVectorImpl<OpFoldResult> &results) {
630 // mului_extended(x, 0) -> 0, 0
631 if (matchPattern(adaptor.getRhs(), m_Zero())) {
632 Attribute zero = adaptor.getRhs();
633 results.push_back(zero);
634 results.push_back(zero);
635 return success();
636 }
637
638 // mului_extended(x, 1) -> x, 0
639 if (matchPattern(adaptor.getRhs(), m_One())) {
640 Builder builder(getContext());
641 Attribute zero = builder.getZeroAttr(getLhs().getType());
642 results.push_back(getLhs());
643 results.push_back(zero);
644 return success();
645 }
646
647 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
648 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
649 adaptor.getOperands(),
650 [](const APInt &a, const APInt &b) { return a * b; })) {
651 // Invoke the constant fold helper again to calculate the 'high' result.
652 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
653 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
654 return llvm::APIntOps::mulhu(a, b);
655 });
656 assert(highAttr && "Unexpected constant-folding failure");
657
658 results.push_back(lowAttr);
659 results.push_back(highAttr);
660 return success();
661 }
662
663 return failure();
664}
665
666void arith::MulUIExtendedOp::getCanonicalizationPatterns(
667 RewritePatternSet &patterns, MLIRContext *context) {
668 patterns.add<MulUIExtendedToMulI>(context);
669}
670
671//===----------------------------------------------------------------------===//
672// DivUIOp
673//===----------------------------------------------------------------------===//
674
675/// Fold `(a * b) / b -> a`
// Also folds `(b * a) / b -> a`: whichever multiplication operand equals the
// divisor `rhs` is dropped and the other operand is returned. The fold only
// fires when `lhs` is defined by an arith.muli carrying every flag in
// `ovfFlags` (DivUIOp passes nuw, DivSIOp passes nsw). Without the no-wrap
// guarantee the product may have wrapped, and dividing it would not recover
// the other factor. Returns a null Value when no fold applies.
// NOTE(review): the signature line is not visible here.
677 arith::IntegerOverflowFlags ovfFlags) {
678 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
679 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
680 return {};
681
682 if (mul.getLhs() == rhs)
683 return mul.getRhs();
684
685 if (mul.getRhs() == rhs)
686 return mul.getLhs();
687
688 return {};
689}
690
691OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
692 // divui (x, 1) -> x.
693 if (matchPattern(adaptor.getRhs(), m_One()))
694 return getLhs();
695
696 // (a * b) / b -> a
697 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
698 return val;
699
700 // Don't fold if it would require a division by zero.
701 bool div0 = false;
702 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
703 [&](APInt a, const APInt &b) {
704 if (div0 || !b) {
705 div0 = true;
706 return a;
707 }
708 return a.udiv(b);
709 });
710
711 return div0 ? Attribute() : result;
712}
713
714/// Returns whether an unsigned division by `divisor` is speculatable.
// The decision depends on whether `divisor` matches m_IntRangeWithoutZeroU,
// i.e. whether its unsigned range is known to exclude zero.
// NOTE(review): the signature and both return statements are not visible
// here; confirm which Speculatability value each branch returns.
716 // X / 0 => UB
717 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
719
721}
722
723Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
724 return getDivUISpeculatability(getRhs());
725}
726
727//===----------------------------------------------------------------------===//
728// DivSIOp
729//===----------------------------------------------------------------------===//
730
// Folds arith.divsi:
//   divsi(x, 1)          -> x
//   divsi(muli(a, b), b) -> a   (requires `nsw` on the multiplication)
//   divsi(cst_a, cst_b)  -> elementwise signed quotient, unless any element
//                           divides by zero or overflows (INT_MIN / -1)
731OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
732 // divsi (x, 1) -> x.
733 if (matchPattern(adaptor.getRhs(), m_One()))
734 return getLhs();
735
736 // (a * b) / b -> a
737 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
738 return val;
739
740 // Don't fold if it would overflow or if it requires a division by zero.
741 bool overflowOrDiv0 = false;
// NOTE(review): the call head of the constant fold is not visible here.
// The flag is checked first, so once set it stays set across elements;
// sdiv_ov assigns it, which is safe because it is false on that path.
743 adaptor.getOperands(), [&](APInt a, const APInt &b) {
744 if (overflowOrDiv0 || !b) {
745 overflowOrDiv0 = true;
746 return a;
747 }
748 return a.sdiv_ov(b, overflowOrDiv0);
749 });
750
751 return overflowOrDiv0 ? Attribute() : result;
752}
753
754/// Returns whether a signed division by `divisor` is speculatable. This
755/// function conservatively assumes that all signed divisions by -1 are not
756/// speculatable.
// The visible condition requires `divisor` to match m_IntRangeWithoutZeroS
// (its signed range excludes zero).
// NOTE(review): the signature, the second half of the condition (presumably
// excluding -1) and the return statements are not visible here — confirm.
758 // X / 0 => UB
759 // INT_MIN / -1 => UB
760 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
763
765}
766
767Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
768 return getDivSISpeculatability(getRhs());
769}
770
771//===----------------------------------------------------------------------===//
772// Ceil and floor division folding helpers
773//===----------------------------------------------------------------------===//
774
775static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
776 bool &overflow) {
777 // Returns (a-1)/b + 1
778 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
779 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
780 return val.sadd_ov(one, overflow);
781}
782
783//===----------------------------------------------------------------------===//
784// CeilDivUIOp
785//===----------------------------------------------------------------------===//
786
// Folds arith.ceildivui:
//   ceildivui(x, 1)         -> x
//   ceildivui(cst_a, cst_b) -> elementwise ceil(a / b) = a / b + (a % b != 0),
//                              unless any element divides by zero or the
//                              +1 overflows
787OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
788 // ceildivui (x, 1) -> x.
789 if (matchPattern(adaptor.getRhs(), m_One()))
790 return getLhs();
791
// Sticky flag: once set, remaining element pairs return early.
792 bool overflowOrDiv0 = false;
// NOTE(review): the call head of the constant fold is not visible here.
794 adaptor.getOperands(), [&](APInt a, const APInt &b) {
795 if (overflowOrDiv0 || !b) {
796 overflowOrDiv0 = true;
797 return a;
798 }
799 APInt quotient = a.udiv(b);
800 if (!a.urem(b))
801 return quotient;
802 APInt one(a.getBitWidth(), 1, true);
803 return quotient.uadd_ov(one, overflowOrDiv0);
804 });
805
806 return overflowOrDiv0 ? Attribute() : result;
807}
808
809Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
810 return getDivUISpeculatability(getRhs());
811}
812
813//===----------------------------------------------------------------------===//
814// CeilDivSIOp
815//===----------------------------------------------------------------------===//
816
// Folds arith.ceildivsi:
//   ceildivsi(x, 1)         -> x
//   ceildivsi(cst_a, cst_b) -> elementwise ceil(a / b), computed by sign case
// The fold is abandoned if any element divides by zero or if any intermediate
// negation/division overflows.
817OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
818 // ceildivsi (x, 1) -> x.
819 if (matchPattern(adaptor.getRhs(), m_One()))
820 return getLhs();
821
822 // Don't fold if it would overflow or if it requires a division by zero.
823 // TODO: This hook won't fold operations where a = MININT, because
824 // negating MININT overflows. This can be improved.
825 bool overflowOrDiv0 = false;
// NOTE(review): the call head of the constant fold is not visible here.
827 adaptor.getOperands(), [&](APInt a, const APInt &b) {
828 if (overflowOrDiv0 || !b) {
829 overflowOrDiv0 = true;
830 return a;
831 }
// ceil(0 / b) == 0.
832 if (!a)
833 return a;
834 // After this point we know that neither a nor b is zero.
835 unsigned bits = a.getBitWidth();
836 APInt zero = APInt::getZero(bits);
837 bool aGtZero = a.sgt(zero);
838 bool bGtZero = b.sgt(zero);
839 if (aGtZero && bGtZero) {
840 // Both positive, return ceil(a, b).
// overflowOrDiv0 is false here (else the early return fired), so letting
// the helper assign it is safe.
841 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
842 }
843
844 // No folding happens if any of the intermediate arithmetic operations
845 // overflows.
846 bool overflowNegA = false;
847 bool overflowNegB = false;
848 bool overflowDiv = false;
849 bool overflowNegRes = false;
850 if (!aGtZero && !bGtZero) {
851 // Both negative, return ceil(-a, -b).
852 APInt posA = zero.ssub_ov(a, overflowNegA);
853 APInt posB = zero.ssub_ov(b, overflowNegB);
854 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
855 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
856 return res;
857 }
// Mixed signs: the exact quotient is negative, and truncating division rounds
// its magnitude down, so -(|a| / |b|) already equals the ceiling.
858 if (!aGtZero && bGtZero) {
859 // A is negative, b is positive, return - ( -a / b).
860 APInt posA = zero.ssub_ov(a, overflowNegA);
861 APInt div = posA.sdiv_ov(b, overflowDiv);
862 APInt res = zero.ssub_ov(div, overflowNegRes);
863 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
864 return res;
865 }
866 // A is positive, b is negative, return - (a / -b).
867 APInt posB = zero.ssub_ov(b, overflowNegB);
868 APInt div = a.sdiv_ov(posB, overflowDiv);
869 APInt res = zero.ssub_ov(div, overflowNegRes);
870
871 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
872 return res;
873 });
874
875 return overflowOrDiv0 ? Attribute() : result;
876}
877
878Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
879 return getDivSISpeculatability(getRhs());
880}
881
882//===----------------------------------------------------------------------===//
883// FloorDivSIOp
884//===----------------------------------------------------------------------===//
885
886OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
887 // floordivsi (x, 1) -> x.
888 if (matchPattern(adaptor.getRhs(), m_One()))
889 return getLhs();
890
891 // Don't fold if it would overflow or if it requires a division by zero.
892 bool overflowOrDiv = false;
894 adaptor.getOperands(), [&](APInt a, const APInt &b) {
895 if (b.isZero()) {
896 overflowOrDiv = true;
897 return a;
898 }
899 return a.sfloordiv_ov(b, overflowOrDiv);
900 });
901
902 return overflowOrDiv ? Attribute() : result;
903}
904
905//===----------------------------------------------------------------------===//
906// RemUIOp
907//===----------------------------------------------------------------------===//
908
909OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
910 // remui (x, 1) -> 0.
911 if (matchPattern(adaptor.getRhs(), m_One()))
912 return Builder(getContext()).getZeroAttr(getType());
913
914 // Don't fold if it would require a division by zero.
915 bool div0 = false;
916 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
917 [&](APInt a, const APInt &b) {
918 if (div0 || b.isZero()) {
919 div0 = true;
920 return a;
921 }
922 return a.urem(b);
923 });
924
925 return div0 ? Attribute() : result;
926}
927
928//===----------------------------------------------------------------------===//
929// RemSIOp
930//===----------------------------------------------------------------------===//
931
932OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
933 // remsi (x, 1) -> 0.
934 if (matchPattern(adaptor.getRhs(), m_One()))
935 return Builder(getContext()).getZeroAttr(getType());
936
937 // Don't fold if it would require a division by zero.
938 bool div0 = false;
939 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
940 [&](APInt a, const APInt &b) {
941 if (div0 || b.isZero()) {
942 div0 = true;
943 return a;
944 }
945 return a.srem(b);
946 });
947
948 return div0 ? Attribute() : result;
949}
950
951//===----------------------------------------------------------------------===//
952// AndIOp
953//===----------------------------------------------------------------------===//
954
955/// Fold `and(a, and(a, b))` to `and(a, b)`
956static Value foldAndIofAndI(arith::AndIOp op) {
957 for (bool reversePrev : {false, true}) {
958 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
959 .getDefiningOp<arith::AndIOp>();
960 if (!prev)
961 continue;
962
963 Value other = (reversePrev ? op.getLhs() : op.getRhs());
964 if (other != prev.getLhs() && other != prev.getRhs())
965 continue;
966
967 return prev.getResult();
968 }
969 return {};
970}
971
// Folds arith.andi:
//   and(x, 0)           -> 0   (returns the zero RHS value itself)
//   and(x, allOnes)     -> x
//   and(x, not(x))      -> 0   (not(x) matched as xori(x, allOnes))
//   and(not(x), x)      -> 0
//   and(a, and(a, b))   -> and(a, b)   (via foldAndIofAndI)
//   and(cst, cst)       -> elementwise bitwise AND
972OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
973 /// and(x, 0) -> 0
974 if (matchPattern(adaptor.getRhs(), m_Zero()))
975 return getRhs();
976 /// and(x, allOnes) -> x
977 APInt intValue;
978 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
979 intValue.isAllOnes())
980 return getLhs();
981 /// and(x, not(x)) -> 0
982 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
983 m_ConstantInt(&intValue))) &&
984 intValue.isAllOnes())
985 return Builder(getContext()).getZeroAttr(getType());
986 /// and(not(x), x) -> 0
987 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
988 m_ConstantInt(&intValue))) &&
989 intValue.isAllOnes())
990 return Builder(getContext()).getZeroAttr(getType());
991
992 /// and(a, and(a, b)) -> and(a, b)
993 if (Value result = foldAndIofAndI(*this))
994 return result;
995
// Constant case: elementwise bitwise AND of two constant operands.
// NOTE(review): the call head of the constant fold is not visible here.
997 adaptor.getOperands(),
998 [](APInt a, const APInt &b) { return std::move(a) & b; });
999}
1000
1001//===----------------------------------------------------------------------===//
1002// OrIOp
1003//===----------------------------------------------------------------------===//
1004
1005OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1006 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1007 /// or(x, 0) -> x
1008 if (rhsVal.isZero())
1009 return getLhs();
1010 /// or(x, <all ones>) -> <all ones>
1011 if (rhsVal.isAllOnes())
1012 return adaptor.getRhs();
1013 }
1014
1015 APInt intValue;
1016 /// or(x, xor(x, 1)) -> 1
1017 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1018 m_ConstantInt(&intValue))) &&
1019 intValue.isAllOnes())
1020 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1021 /// or(xor(x, 1), x) -> 1
1022 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1023 m_ConstantInt(&intValue))) &&
1024 intValue.isAllOnes())
1025 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1026
1028 adaptor.getOperands(),
1029 [](APInt a, const APInt &b) { return std::move(a) | b; });
1030}
1031
1032//===----------------------------------------------------------------------===//
1033// XOrIOp
1034//===----------------------------------------------------------------------===//
1035
1036OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1037 /// xor(x, 0) -> x
1038 if (matchPattern(adaptor.getRhs(), m_Zero()))
1039 return getLhs();
1040 /// xor(x, x) -> 0
1041 if (getLhs() == getRhs())
1042 return Builder(getContext()).getZeroAttr(getType());
1043 /// xor(xor(x, a), a) -> x
1044 /// xor(xor(a, x), a) -> x
1045 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1046 if (prev.getRhs() == getRhs())
1047 return prev.getLhs();
1048 if (prev.getLhs() == getRhs())
1049 return prev.getRhs();
1050 }
1051 /// xor(a, xor(x, a)) -> x
1052 /// xor(a, xor(a, x)) -> x
1053 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1054 if (prev.getRhs() == getLhs())
1055 return prev.getLhs();
1056 if (prev.getLhs() == getLhs())
1057 return prev.getRhs();
1058 }
1059
1061 adaptor.getOperands(),
1062 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1063}
1064
1065void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1066 MLIRContext *context) {
1067 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1068}
1069
1070//===----------------------------------------------------------------------===//
1071// NegFOp
1072//===----------------------------------------------------------------------===//
1073
1074OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1075 /// negf(negf(x)) -> x
1076 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1077 return op.getOperand();
1078 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1079 [](const APFloat &a) { return -a; });
1080}
1081
1082//===----------------------------------------------------------------------===//
1083// AddFOp
1084//===----------------------------------------------------------------------===//
1085
1086OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1087 // addf(x, -0) -> x
1088 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1089 return getLhs();
1090
1092 adaptor.getOperands(),
1093 [](const APFloat &a, const APFloat &b) { return a + b; });
1094}
1095
1096//===----------------------------------------------------------------------===//
1097// SubFOp
1098//===----------------------------------------------------------------------===//
1099
1100OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1101 // subf(x, +0) -> x
1102 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1103 return getLhs();
1104
1106 adaptor.getOperands(),
1107 [](const APFloat &a, const APFloat &b) { return a - b; });
1108}
1109
1110//===----------------------------------------------------------------------===//
1111// MaximumFOp
1112//===----------------------------------------------------------------------===//
1113
1114OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1115 // maximumf(x,x) -> x
1116 if (getLhs() == getRhs())
1117 return getRhs();
1118
1119 // maximumf(x, -inf) -> x
1120 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1121 return getLhs();
1122
1124 adaptor.getOperands(),
1125 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// MaxNumFOp
1130//===----------------------------------------------------------------------===//
1131
1132OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1133 // maxnumf(x,x) -> x
1134 if (getLhs() == getRhs())
1135 return getRhs();
1136
1137 // maxnumf(x, NaN) -> x
1138 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1139 return getLhs();
1140
1141 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1142}
1143
1144//===----------------------------------------------------------------------===//
1145// MaxSIOp
1146//===----------------------------------------------------------------------===//
1147
1148OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1149 // maxsi(x,x) -> x
1150 if (getLhs() == getRhs())
1151 return getRhs();
1152
1153 if (APInt intValue;
1154 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1155 // maxsi(x,MAX_INT) -> MAX_INT
1156 if (intValue.isMaxSignedValue())
1157 return getRhs();
1158 // maxsi(x, MIN_INT) -> x
1159 if (intValue.isMinSignedValue())
1160 return getLhs();
1161 }
1162
1163 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1164 [](const APInt &a, const APInt &b) {
1165 return llvm::APIntOps::smax(a, b);
1166 });
1167}
1168
1169//===----------------------------------------------------------------------===//
1170// MaxUIOp
1171//===----------------------------------------------------------------------===//
1172
1173OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1174 // maxui(x,x) -> x
1175 if (getLhs() == getRhs())
1176 return getRhs();
1177
1178 if (APInt intValue;
1179 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1180 // maxui(x,MAX_INT) -> MAX_INT
1181 if (intValue.isMaxValue())
1182 return getRhs();
1183 // maxui(x, MIN_INT) -> x
1184 if (intValue.isMinValue())
1185 return getLhs();
1186 }
1187
1188 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1189 [](const APInt &a, const APInt &b) {
1190 return llvm::APIntOps::umax(a, b);
1191 });
1192}
1193
1194//===----------------------------------------------------------------------===//
1195// MinimumFOp
1196//===----------------------------------------------------------------------===//
1197
1198OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1199 // minimumf(x,x) -> x
1200 if (getLhs() == getRhs())
1201 return getRhs();
1202
1203 // minimumf(x, +inf) -> x
1204 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1205 return getLhs();
1206
1208 adaptor.getOperands(),
1209 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1210}
1211
1212//===----------------------------------------------------------------------===//
1213// MinNumFOp
1214//===----------------------------------------------------------------------===//
1215
1216OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1217 // minnumf(x,x) -> x
1218 if (getLhs() == getRhs())
1219 return getRhs();
1220
1221 // minnumf(x, NaN) -> x
1222 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1223 return getLhs();
1224
1226 adaptor.getOperands(),
1227 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1228}
1229
1230//===----------------------------------------------------------------------===//
1231// MinSIOp
1232//===----------------------------------------------------------------------===//
1233
1234OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1235 // minsi(x,x) -> x
1236 if (getLhs() == getRhs())
1237 return getRhs();
1238
1239 if (APInt intValue;
1240 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1241 // minsi(x,MIN_INT) -> MIN_INT
1242 if (intValue.isMinSignedValue())
1243 return getRhs();
1244 // minsi(x, MAX_INT) -> x
1245 if (intValue.isMaxSignedValue())
1246 return getLhs();
1247 }
1248
1249 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1250 [](const APInt &a, const APInt &b) {
1251 return llvm::APIntOps::smin(a, b);
1252 });
1253}
1254
1255//===----------------------------------------------------------------------===//
1256// MinUIOp
1257//===----------------------------------------------------------------------===//
1258
1259OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1260 // minui(x,x) -> x
1261 if (getLhs() == getRhs())
1262 return getRhs();
1263
1264 if (APInt intValue;
1265 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1266 // minui(x,MIN_INT) -> MIN_INT
1267 if (intValue.isMinValue())
1268 return getRhs();
1269 // minui(x, MAX_INT) -> x
1270 if (intValue.isMaxValue())
1271 return getLhs();
1272 }
1273
1274 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1275 [](const APInt &a, const APInt &b) {
1276 return llvm::APIntOps::umin(a, b);
1277 });
1278}
1279
1280//===----------------------------------------------------------------------===//
1281// MulFOp
1282//===----------------------------------------------------------------------===//
1283
1284OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1285 // mulf(x, 1) -> x
1286 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1287 return getLhs();
1288
1289 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1290 arith::FastMathFlags::nsz)) {
1291 // mulf(x, 0) -> 0
1292 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1293 return getRhs();
1294 }
1295
1297 adaptor.getOperands(),
1298 [](const APFloat &a, const APFloat &b) { return a * b; });
1299}
1300
1301void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1302 MLIRContext *context) {
1303 patterns.add<MulFOfNegF>(context);
1304}
1305
1306//===----------------------------------------------------------------------===//
1307// DivFOp
1308//===----------------------------------------------------------------------===//
1309
1310OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1311 // divf(x, 1) -> x
1312 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1313 return getLhs();
1314
1316 adaptor.getOperands(),
1317 [](const APFloat &a, const APFloat &b) { return a / b; });
1318}
1319
1320void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1321 MLIRContext *context) {
1322 patterns.add<DivFOfNegF>(context);
1323}
1324
1325//===----------------------------------------------------------------------===//
1326// RemFOp
1327//===----------------------------------------------------------------------===//
1328
1329OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1330 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1331 [](const APFloat &a, const APFloat &b) {
1332 APFloat result(a);
1333 // APFloat::mod() offers the remainder
1334 // behavior we want, i.e. the result has
1335 // the sign of LHS operand.
1336 (void)result.mod(b);
1337 return result;
1338 });
1339}
1340
1341//===----------------------------------------------------------------------===//
1342// Utility functions for verifying cast ops
1343//===----------------------------------------------------------------------===//
1344
1345template <typename... Types>
1346using type_list = std::tuple<Types...> *;
1347
1348/// Returns a non-null type only if the provided type is one of the allowed
1349/// types or one of the allowed shaped types of the allowed types. Returns the
1350/// element type if a valid shaped type is provided.
1351template <typename... ShapedTypes, typename... ElementTypes>
1354 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1355 return {};
1356
1357 auto underlyingType = getElementTypeOrSelf(type);
1358 if (!llvm::isa<ElementTypes...>(underlyingType))
1359 return {};
1360
1361 return underlyingType;
1362}
1363
1364/// Get allowed underlying types for vectors and tensors.
1365template <typename... ElementTypes>
1370
1371/// Get allowed underlying types for vectors, tensors, and memrefs.
1372template <typename... ElementTypes>
1378
1379/// Return false if both types are ranked tensor with mismatching encoding.
1380static bool hasSameEncoding(Type typeA, Type typeB) {
1381 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1382 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1383 if (!rankedTensorA || !rankedTensorB)
1384 return true;
1385 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1386}
1387
1389 if (inputs.size() != 1 || outputs.size() != 1)
1390 return false;
1391 if (!hasSameEncoding(inputs.front(), outputs.front()))
1392 return false;
1393 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1394}
1395
1396//===----------------------------------------------------------------------===//
1397// Verifiers for integer and floating point extension/truncation ops
1398//===----------------------------------------------------------------------===//
1399
1400// Extend ops can only extend to a wider type.
1401template <typename ValType, typename Op>
1402static LogicalResult verifyExtOp(Op op) {
1403 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1404 Type dstType = getElementTypeOrSelf(op.getType());
1405
1406 if (llvm::cast<ValType>(srcType).getWidth() >=
1407 llvm::cast<ValType>(dstType).getWidth())
1408 return op.emitError("result type ")
1409 << dstType << " must be wider than operand type " << srcType;
1410
1411 return success();
1412}
1413
1414// Truncate ops can only truncate to a shorter type.
1415template <typename ValType, typename Op>
1416static LogicalResult verifyTruncateOp(Op op) {
1417 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1418 Type dstType = getElementTypeOrSelf(op.getType());
1419
1420 if (llvm::cast<ValType>(srcType).getWidth() <=
1421 llvm::cast<ValType>(dstType).getWidth())
1422 return op.emitError("result type ")
1423 << dstType << " must be shorter than operand type " << srcType;
1424
1425 return success();
1426}
1427
1428/// Validate a cast that changes the width of a type.
1429template <template <typename> class WidthComparator, typename... ElementTypes>
1430static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1431 if (!areValidCastInputsAndOutputs(inputs, outputs))
1432 return false;
1433
1434 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1435 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1436 if (!srcType || !dstType)
1437 return false;
1438
1439 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1440 srcType.getIntOrFloatBitWidth());
1441}
1442
1443/// Attempts to convert `sourceValue` to an APFloat value with
1444/// `targetSemantics` and `roundingMode`, without any information loss.
1445static FailureOr<APFloat> convertFloatValue(
1446 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1447 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1448 bool losesInfo = false;
1449 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1450 if (losesInfo || status != APFloat::opOK)
1451 return failure();
1452
1453 return sourceValue;
1454}
1455
1456//===----------------------------------------------------------------------===//
1457// ExtUIOp
1458//===----------------------------------------------------------------------===//
1459
1460OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1461 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1462 getInMutable().assign(lhs.getIn());
1463 return getResult();
1464 }
1465
1466 Type resType = getElementTypeOrSelf(getType());
1467 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1469 adaptor.getOperands(), getType(),
1470 [bitWidth](const APInt &a, bool &castStatus) {
1471 return a.zext(bitWidth);
1472 });
1473}
1474
1475bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1477}
1478
1479LogicalResult arith::ExtUIOp::verify() {
1480 return verifyExtOp<IntegerType>(*this);
1481}
1482
1483//===----------------------------------------------------------------------===//
1484// ExtSIOp
1485//===----------------------------------------------------------------------===//
1486
1487OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1488 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1489 getInMutable().assign(lhs.getIn());
1490 return getResult();
1491 }
1492
1493 Type resType = getElementTypeOrSelf(getType());
1494 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1496 adaptor.getOperands(), getType(),
1497 [bitWidth](const APInt &a, bool &castStatus) {
1498 return a.sext(bitWidth);
1499 });
1500}
1501
1502bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1504}
1505
1506void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1507 MLIRContext *context) {
1508 patterns.add<ExtSIOfExtUI>(context);
1509}
1510
1511LogicalResult arith::ExtSIOp::verify() {
1512 return verifyExtOp<IntegerType>(*this);
1513}
1514
1515//===----------------------------------------------------------------------===//
1516// ExtFOp
1517//===----------------------------------------------------------------------===//
1518
1519/// Fold extension of float constants when there is no information loss due the
1520/// difference in fp semantics.
1521OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1522 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1523 if (truncFOp.getOperand().getType() == getType()) {
1524 arith::FastMathFlags truncFMF =
1525 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1526 bool isTruncContract =
1527 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1528 arith::FastMathFlags extFMF =
1529 getFastmath().value_or(arith::FastMathFlags::none);
1530 bool isExtContract =
1531 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1532 if (isTruncContract && isExtContract) {
1533 return truncFOp.getOperand();
1534 }
1535 }
1536 }
1537
1538 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1539 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1541 adaptor.getOperands(), getType(),
1542 [&targetSemantics](const APFloat &a, bool &castStatus) {
1543 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1544 if (failed(result)) {
1545 castStatus = false;
1546 return a;
1547 }
1548 return *result;
1549 });
1550}
1551
1552bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1553 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1554}
1555
1556LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1557
1558//===----------------------------------------------------------------------===//
1559// ScalingExtFOp
1560//===----------------------------------------------------------------------===//
1561
1562bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1563 TypeRange outputs) {
1564 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1565}
1566
1567LogicalResult arith::ScalingExtFOp::verify() {
1568 return verifyExtOp<FloatType>(*this);
1569}
1570
1571//===----------------------------------------------------------------------===//
1572// TruncIOp
1573//===----------------------------------------------------------------------===//
1574
1575OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1576 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1577 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1578 Value src = getOperand().getDefiningOp()->getOperand(0);
1579 Type srcType = getElementTypeOrSelf(src.getType());
1580 Type dstType = getElementTypeOrSelf(getType());
1581 // trunci(zexti(a)) -> trunci(a)
1582 // trunci(sexti(a)) -> trunci(a)
1583 if (llvm::cast<IntegerType>(srcType).getWidth() >
1584 llvm::cast<IntegerType>(dstType).getWidth()) {
1585 setOperand(src);
1586 return getResult();
1587 }
1588
1589 // trunci(zexti(a)) -> a
1590 // trunci(sexti(a)) -> a
1591 if (srcType == dstType)
1592 return src;
1593 }
1594
1595 // trunci(trunci(a)) -> trunci(a))
1596 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1597 setOperand(getOperand().getDefiningOp()->getOperand(0));
1598 return getResult();
1599 }
1600
1601 Type resType = getElementTypeOrSelf(getType());
1602 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1604 adaptor.getOperands(), getType(),
1605 [bitWidth](const APInt &a, bool &castStatus) {
1606 return a.trunc(bitWidth);
1607 });
1608}
1609
1610bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1611 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1612}
1613
1614void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1615 MLIRContext *context) {
1616 patterns
1617 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1618 context);
1619}
1620
1621LogicalResult arith::TruncIOp::verify() {
1622 return verifyTruncateOp<IntegerType>(*this);
1623}
1624
1625//===----------------------------------------------------------------------===//
1626// TruncFOp
1627//===----------------------------------------------------------------------===//
1628
1629/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1630/// can be represented without precision loss.
1631OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1632 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1633 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1634 Value src = extOp.getIn();
1635 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1636 auto intermediateType =
1637 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1638 // Check if the srcType is representable in the intermediateType.
1639 if (llvm::APFloatBase::isRepresentableBy(
1640 srcType.getFloatSemantics(),
1641 intermediateType.getFloatSemantics())) {
1642 // truncf(extf(a)) -> truncf(a)
1643 if (srcType.getWidth() > resElemType.getWidth()) {
1644 setOperand(src);
1645 return getResult();
1646 }
1647
1648 // truncf(extf(a)) -> a
1649 if (srcType == resElemType)
1650 return src;
1651 }
1652 }
1653
1654 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1656 adaptor.getOperands(), getType(),
1657 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1658 RoundingMode roundingMode =
1659 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1660 llvm::RoundingMode llvmRoundingMode =
1662 FailureOr<APFloat> result =
1663 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1664 if (failed(result)) {
1665 castStatus = false;
1666 return a;
1667 }
1668 return *result;
1669 });
1670}
1671
1672void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1673 MLIRContext *context) {
1674 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1675}
1676
1677bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1678 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1679}
1680
1681LogicalResult arith::TruncFOp::verify() {
1682 return verifyTruncateOp<FloatType>(*this);
1683}
1684
1685//===----------------------------------------------------------------------===//
1686// ScalingTruncFOp
1687//===----------------------------------------------------------------------===//
1688
1689bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1690 TypeRange outputs) {
1691 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1692}
1693
1694LogicalResult arith::ScalingTruncFOp::verify() {
1695 return verifyTruncateOp<FloatType>(*this);
1696}
1697
1698//===----------------------------------------------------------------------===//
1699// AndIOp
1700//===----------------------------------------------------------------------===//
1701
1702void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1703 MLIRContext *context) {
1704 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1705}
1706
1707//===----------------------------------------------------------------------===//
1708// OrIOp
1709//===----------------------------------------------------------------------===//
1710
1711void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1712 MLIRContext *context) {
1713 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1714}
1715
1716//===----------------------------------------------------------------------===//
1717// Verifiers for casts between integers and floats.
1718//===----------------------------------------------------------------------===//
1719
1720template <typename From, typename To>
1721static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1722 if (!areValidCastInputsAndOutputs(inputs, outputs))
1723 return false;
1724
1725 auto srcType = getTypeIfLike<From>(inputs.front());
1726 auto dstType = getTypeIfLike<To>(outputs.back());
1727
1728 return srcType && dstType;
1729}
1730
1731//===----------------------------------------------------------------------===//
1732// UIToFPOp
1733//===----------------------------------------------------------------------===//
1734
1735bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1736 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1737}
1738
1739OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1740 Type resEleType = getElementTypeOrSelf(getType());
1742 adaptor.getOperands(), getType(),
1743 [&resEleType](const APInt &a, bool &castStatus) {
1744 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1745 APFloat apf(floatTy.getFloatSemantics(),
1746 APInt::getZero(floatTy.getWidth()));
1747 apf.convertFromAPInt(a, /*IsSigned=*/false,
1748 APFloat::rmNearestTiesToEven);
1749 return apf;
1750 });
1751}
1752
1753//===----------------------------------------------------------------------===//
1754// SIToFPOp
1755//===----------------------------------------------------------------------===//
1756
1757bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1758 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1759}
1760
1761OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1762 Type resEleType = getElementTypeOrSelf(getType());
1764 adaptor.getOperands(), getType(),
1765 [&resEleType](const APInt &a, bool &castStatus) {
1766 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1767 APFloat apf(floatTy.getFloatSemantics(),
1768 APInt::getZero(floatTy.getWidth()));
1769 apf.convertFromAPInt(a, /*IsSigned=*/true,
1770 APFloat::rmNearestTiesToEven);
1771 return apf;
1772 });
1773}
1774
1775//===----------------------------------------------------------------------===//
1776// FPToUIOp
1777//===----------------------------------------------------------------------===//
1778
1779bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1780 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1781}
1782
1783OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1784 Type resType = getElementTypeOrSelf(getType());
1785 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1787 adaptor.getOperands(), getType(),
1788 [&bitWidth](const APFloat &a, bool &castStatus) {
1789 bool ignored;
1790 APSInt api(bitWidth, /*isUnsigned=*/true);
1791 castStatus = APFloat::opInvalidOp !=
1792 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1793 return api;
1794 });
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// FPToSIOp
1799//===----------------------------------------------------------------------===//
1800
1801bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1802 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1803}
1804
1805OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1806 Type resType = getElementTypeOrSelf(getType());
1807 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1809 adaptor.getOperands(), getType(),
1810 [&bitWidth](const APFloat &a, bool &castStatus) {
1811 bool ignored;
1812 APSInt api(bitWidth, /*isUnsigned=*/false);
1813 castStatus = APFloat::opInvalidOp !=
1814 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1815 return api;
1816 });
1817}
1818
1819//===----------------------------------------------------------------------===//
1820// IndexCastOp
1821//===----------------------------------------------------------------------===//
1822
1823static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1824 if (!areValidCastInputsAndOutputs(inputs, outputs))
1825 return false;
1826
1827 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1828 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1829 if (!srcType || !dstType)
1830 return false;
1831
1832 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1833 (srcType.isSignlessInteger() && dstType.isIndex());
1834}
1835
1836bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1837 TypeRange outputs) {
1838 return areIndexCastCompatible(inputs, outputs);
1839}
1840
1841OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1842 // index_cast(constant) -> constant
1843 unsigned resultBitwidth = 64; // Default for index integer attributes.
1844 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1845 resultBitwidth = intTy.getWidth();
1846
1848 adaptor.getOperands(), getType(),
1849 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1850 return a.sextOrTrunc(resultBitwidth);
1851 });
1852}
1853
1854void arith::IndexCastOp::getCanonicalizationPatterns(
1855 RewritePatternSet &patterns, MLIRContext *context) {
1856 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1857}
1858
1859//===----------------------------------------------------------------------===//
1860// IndexCastUIOp
1861//===----------------------------------------------------------------------===//
1862
1863bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1864 TypeRange outputs) {
1865 return areIndexCastCompatible(inputs, outputs);
1866}
1867
1868OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1869 // index_castui(constant) -> constant
1870 unsigned resultBitwidth = 64; // Default for index integer attributes.
1871 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1872 resultBitwidth = intTy.getWidth();
1873
1875 adaptor.getOperands(), getType(),
1876 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1877 return a.zextOrTrunc(resultBitwidth);
1878 });
1879}
1880
1881void arith::IndexCastUIOp::getCanonicalizationPatterns(
1882 RewritePatternSet &patterns, MLIRContext *context) {
1883 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1884}
1885
1886//===----------------------------------------------------------------------===//
1887// BitcastOp
1888//===----------------------------------------------------------------------===//
1889
1890bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1891 if (!areValidCastInputsAndOutputs(inputs, outputs))
1892 return false;
1893
1894 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1895 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1896 if (!srcType || !dstType)
1897 return false;
1898
1899 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1900}
1901
1902OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1903 auto resType = getType();
1904 auto operand = adaptor.getIn();
1905 if (!operand)
1906 return {};
1907
1908 /// Bitcast dense elements.
1909 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1910 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1911 /// Other shaped types unhandled.
1912 if (llvm::isa<ShapedType>(resType))
1913 return {};
1914
1915 /// Bitcast poison.
1916 if (llvm::isa<ub::PoisonAttr>(operand))
1917 return ub::PoisonAttr::get(getContext());
1918
1919 /// Bitcast integer or float to integer or float.
1920 APInt bits = llvm::isa<FloatAttr>(operand)
1921 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1922 : llvm::cast<IntegerAttr>(operand).getValue();
1923 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1924 "trying to fold on broken IR: operands have incompatible types");
1925
1926 if (auto resFloatType = dyn_cast<FloatType>(resType))
1927 return FloatAttr::get(resType,
1928 APFloat(resFloatType.getFloatSemantics(), bits));
1929 return IntegerAttr::get(resType, bits);
1930}
1931
1932void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1933 MLIRContext *context) {
1934 patterns.add<BitcastOfBitcast>(context);
1935}
1936
1937//===----------------------------------------------------------------------===//
1938// CmpIOp
1939//===----------------------------------------------------------------------===//
1940
1941/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1942/// comparison predicates.
1943bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1944 const APInt &lhs, const APInt &rhs) {
1945 switch (predicate) {
1946 case arith::CmpIPredicate::eq:
1947 return lhs.eq(rhs);
1948 case arith::CmpIPredicate::ne:
1949 return lhs.ne(rhs);
1950 case arith::CmpIPredicate::slt:
1951 return lhs.slt(rhs);
1952 case arith::CmpIPredicate::sle:
1953 return lhs.sle(rhs);
1954 case arith::CmpIPredicate::sgt:
1955 return lhs.sgt(rhs);
1956 case arith::CmpIPredicate::sge:
1957 return lhs.sge(rhs);
1958 case arith::CmpIPredicate::ult:
1959 return lhs.ult(rhs);
1960 case arith::CmpIPredicate::ule:
1961 return lhs.ule(rhs);
1962 case arith::CmpIPredicate::ugt:
1963 return lhs.ugt(rhs);
1964 case arith::CmpIPredicate::uge:
1965 return lhs.uge(rhs);
1966 }
1967 llvm_unreachable("unknown cmpi predicate kind");
1968}
1969
1970/// Returns true if the predicate is true for two equal operands.
1971static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1972 switch (predicate) {
1973 case arith::CmpIPredicate::eq:
1974 case arith::CmpIPredicate::sle:
1975 case arith::CmpIPredicate::sge:
1976 case arith::CmpIPredicate::ule:
1977 case arith::CmpIPredicate::uge:
1978 return true;
1979 case arith::CmpIPredicate::ne:
1980 case arith::CmpIPredicate::slt:
1981 case arith::CmpIPredicate::sgt:
1982 case arith::CmpIPredicate::ult:
1983 case arith::CmpIPredicate::ugt:
1984 return false;
1985 }
1986 llvm_unreachable("unknown cmpi predicate kind");
1987}
1988
1989static std::optional<int64_t> getIntegerWidth(Type t) {
1990 if (auto intType = dyn_cast<IntegerType>(t)) {
1991 return intType.getWidth();
1992 }
1993 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
1994 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1995 }
1996 return std::nullopt;
1997}
1998
1999OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2000 // cmpi(pred, x, x)
2001 if (getLhs() == getRhs()) {
2002 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2003 return getBoolAttribute(getType(), val);
2004 }
2005
2006 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2007 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2008 // extsi(%x : i1 -> iN) != 0 -> %x
2009 std::optional<int64_t> integerWidth =
2010 getIntegerWidth(extOp.getOperand().getType());
2011 if (integerWidth && integerWidth.value() == 1 &&
2012 getPredicate() == arith::CmpIPredicate::ne)
2013 return extOp.getOperand();
2014 }
2015 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2016 // extui(%x : i1 -> iN) != 0 -> %x
2017 std::optional<int64_t> integerWidth =
2018 getIntegerWidth(extOp.getOperand().getType());
2019 if (integerWidth && integerWidth.value() == 1 &&
2020 getPredicate() == arith::CmpIPredicate::ne)
2021 return extOp.getOperand();
2022 }
2023
2024 // arith.cmpi ne, %val, %zero : i1 -> %val
2025 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2026 getPredicate() == arith::CmpIPredicate::ne)
2027 return getLhs();
2028 }
2029
2030 if (matchPattern(adaptor.getRhs(), m_One())) {
2031 // arith.cmpi eq, %val, %one : i1 -> %val
2032 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2033 getPredicate() == arith::CmpIPredicate::eq)
2034 return getLhs();
2035 }
2036
2037 // Move constant to the right side.
2038 if (adaptor.getLhs() && !adaptor.getRhs()) {
2039 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2040 using Pred = CmpIPredicate;
2041 const std::pair<Pred, Pred> invPreds[] = {
2042 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2043 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2044 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2045 {Pred::ne, Pred::ne},
2046 };
2047 Pred origPred = getPredicate();
2048 for (auto pred : invPreds) {
2049 if (origPred == pred.first) {
2050 setPredicate(pred.second);
2051 Value lhs = getLhs();
2052 Value rhs = getRhs();
2053 getLhsMutable().assign(rhs);
2054 getRhsMutable().assign(lhs);
2055 return getResult();
2056 }
2057 }
2058 llvm_unreachable("unknown cmpi predicate kind");
2059 }
2060
2061 // We are moving constants to the right side; So if lhs is constant rhs is
2062 // guaranteed to be a constant.
2063 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2065 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2066 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2067 return APInt(1,
2068 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2069 });
2070 }
2071
2072 return {};
2073}
2074
2075void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2076 MLIRContext *context) {
2077 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2078}
2079
2080//===----------------------------------------------------------------------===//
2081// CmpFOp
2082//===----------------------------------------------------------------------===//
2083
2084/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2085/// comparison predicates.
2086bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2087 const APFloat &lhs, const APFloat &rhs) {
2088 auto cmpResult = lhs.compare(rhs);
2089 switch (predicate) {
2090 case arith::CmpFPredicate::AlwaysFalse:
2091 return false;
2092 case arith::CmpFPredicate::OEQ:
2093 return cmpResult == APFloat::cmpEqual;
2094 case arith::CmpFPredicate::OGT:
2095 return cmpResult == APFloat::cmpGreaterThan;
2096 case arith::CmpFPredicate::OGE:
2097 return cmpResult == APFloat::cmpGreaterThan ||
2098 cmpResult == APFloat::cmpEqual;
2099 case arith::CmpFPredicate::OLT:
2100 return cmpResult == APFloat::cmpLessThan;
2101 case arith::CmpFPredicate::OLE:
2102 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2103 case arith::CmpFPredicate::ONE:
2104 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2105 case arith::CmpFPredicate::ORD:
2106 return cmpResult != APFloat::cmpUnordered;
2107 case arith::CmpFPredicate::UEQ:
2108 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2109 case arith::CmpFPredicate::UGT:
2110 return cmpResult == APFloat::cmpUnordered ||
2111 cmpResult == APFloat::cmpGreaterThan;
2112 case arith::CmpFPredicate::UGE:
2113 return cmpResult == APFloat::cmpUnordered ||
2114 cmpResult == APFloat::cmpGreaterThan ||
2115 cmpResult == APFloat::cmpEqual;
2116 case arith::CmpFPredicate::ULT:
2117 return cmpResult == APFloat::cmpUnordered ||
2118 cmpResult == APFloat::cmpLessThan;
2119 case arith::CmpFPredicate::ULE:
2120 return cmpResult == APFloat::cmpUnordered ||
2121 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2122 case arith::CmpFPredicate::UNE:
2123 return cmpResult != APFloat::cmpEqual;
2124 case arith::CmpFPredicate::UNO:
2125 return cmpResult == APFloat::cmpUnordered;
2126 case arith::CmpFPredicate::AlwaysTrue:
2127 return true;
2128 }
2129 llvm_unreachable("unknown cmpf predicate kind");
2130}
2131
2132OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2133 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2134 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2135
2136 // If one operand is NaN, making them both NaN does not change the result.
2137 if (lhs && lhs.getValue().isNaN())
2138 rhs = lhs;
2139 if (rhs && rhs.getValue().isNaN())
2140 lhs = rhs;
2141
2142 if (!lhs || !rhs)
2143 return {};
2144
2145 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2146 return BoolAttr::get(getContext(), val);
2147}
2148
2149class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2150public:
2151 using Base::Base;
2152
2153 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2154 bool isUnsigned) {
2155 using namespace arith;
2156 switch (pred) {
2157 case CmpFPredicate::UEQ:
2158 case CmpFPredicate::OEQ:
2159 return CmpIPredicate::eq;
2160 case CmpFPredicate::UGT:
2161 case CmpFPredicate::OGT:
2162 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2163 case CmpFPredicate::UGE:
2164 case CmpFPredicate::OGE:
2165 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2166 case CmpFPredicate::ULT:
2167 case CmpFPredicate::OLT:
2168 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2169 case CmpFPredicate::ULE:
2170 case CmpFPredicate::OLE:
2171 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2172 case CmpFPredicate::UNE:
2173 case CmpFPredicate::ONE:
2174 return CmpIPredicate::ne;
2175 default:
2176 llvm_unreachable("Unexpected predicate!");
2177 }
2178 }
2179
2180 LogicalResult matchAndRewrite(CmpFOp op,
2181 PatternRewriter &rewriter) const override {
2182 FloatAttr flt;
2183 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2184 return failure();
2185
2186 const APFloat &rhs = flt.getValue();
2187
2188 // Don't attempt to fold a nan.
2189 if (rhs.isNaN())
2190 return failure();
2191
2192 // Get the width of the mantissa. We don't want to hack on conversions that
2193 // might lose information from the integer, e.g. "i64 -> float"
2194 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2195 int mantissaWidth = floatTy.getFPMantissaWidth();
2196 if (mantissaWidth <= 0)
2197 return failure();
2198
2199 bool isUnsigned;
2200 Value intVal;
2201
2202 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2203 isUnsigned = false;
2204 intVal = si.getIn();
2205 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2206 isUnsigned = true;
2207 intVal = ui.getIn();
2208 } else {
2209 return failure();
2210 }
2211
2212 // Check to see that the input is converted from an integer type that is
2213 // small enough that preserves all bits.
2214 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2215 auto intWidth = intTy.getWidth();
2216
2217 // Number of bits representing values, as opposed to the sign
2218 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2219
2220 // Following test does NOT adjust intWidth downwards for signed inputs,
2221 // because the most negative value still requires all the mantissa bits
2222 // to distinguish it from one less than that value.
2223 if ((int)intWidth > mantissaWidth) {
2224 // Conversion would lose accuracy. Check if loss can impact comparison.
2225 int exponent = ilogb(rhs);
2226 if (exponent == APFloat::IEK_Inf) {
2227 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2228 if (maxExponent < (int)valueBits) {
2229 // Conversion could create infinity.
2230 return failure();
2231 }
2232 } else {
2233 // Note that if rhs is zero or NaN, then Exp is negative
2234 // and first condition is trivially false.
2235 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2236 // Conversion could affect comparison.
2237 return failure();
2238 }
2239 }
2240 }
2241
2242 // Convert to equivalent cmpi predicate
2243 CmpIPredicate pred;
2244 switch (op.getPredicate()) {
2245 case CmpFPredicate::ORD:
2246 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2247 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2248 /*width=*/1);
2249 return success();
2250 case CmpFPredicate::UNO:
2251 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2252 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2253 /*width=*/1);
2254 return success();
2255 default:
2256 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2257 break;
2258 }
2259
2260 if (!isUnsigned) {
2261 // If the rhs value is > SignedMax, fold the comparison. This handles
2262 // +INF and large values.
2263 APFloat signedMax(rhs.getSemantics());
2264 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2265 APFloat::rmNearestTiesToEven);
2266 if (signedMax < rhs) { // smax < 13123.0
2267 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2268 pred == CmpIPredicate::sle)
2269 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2270 /*width=*/1);
2271 else
2272 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2273 /*width=*/1);
2274 return success();
2275 }
2276 } else {
2277 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2278 // +INF and large values.
2279 APFloat unsignedMax(rhs.getSemantics());
2280 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2281 APFloat::rmNearestTiesToEven);
2282 if (unsignedMax < rhs) { // umax < 13123.0
2283 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2284 pred == CmpIPredicate::ule)
2285 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2286 /*width=*/1);
2287 else
2288 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2289 /*width=*/1);
2290 return success();
2291 }
2292 }
2293
2294 if (!isUnsigned) {
2295 // See if the rhs value is < SignedMin.
2296 APFloat signedMin(rhs.getSemantics());
2297 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2298 APFloat::rmNearestTiesToEven);
2299 if (signedMin > rhs) { // smin > 12312.0
2300 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2301 pred == CmpIPredicate::sge)
2302 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2303 /*width=*/1);
2304 else
2305 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2306 /*width=*/1);
2307 return success();
2308 }
2309 } else {
2310 // See if the rhs value is < UnsignedMin.
2311 APFloat unsignedMin(rhs.getSemantics());
2312 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2313 APFloat::rmNearestTiesToEven);
2314 if (unsignedMin > rhs) { // umin > 12312.0
2315 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2316 pred == CmpIPredicate::uge)
2317 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2318 /*width=*/1);
2319 else
2320 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2321 /*width=*/1);
2322 return success();
2323 }
2324 }
2325
2326 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2327 // [0, UMAX], but it may still be fractional. See if it is fractional by
2328 // casting the FP value to the integer value and back, checking for
2329 // equality. Don't do this for zero, because -0.0 is not fractional.
2330 bool ignored;
2331 APSInt rhsInt(intWidth, isUnsigned);
2332 if (APFloat::opInvalidOp ==
2333 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2334 // Undefined behavior invoked - the destination type can't represent
2335 // the input constant.
2336 return failure();
2337 }
2338
2339 if (!rhs.isZero()) {
2340 APFloat apf(floatTy.getFloatSemantics(),
2341 APInt::getZero(floatTy.getWidth()));
2342 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2343
2344 bool equal = apf == rhs;
2345 if (!equal) {
2346 // If we had a comparison against a fractional value, we have to adjust
2347 // the compare predicate and sometimes the value. rhsInt is rounded
2348 // towards zero at this point.
2349 switch (pred) {
2350 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2351 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2352 /*width=*/1);
2353 return success();
2354 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2355 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2356 /*width=*/1);
2357 return success();
2358 case CmpIPredicate::ule:
2359 // (float)int <= 4.4 --> int <= 4
2360 // (float)int <= -4.4 --> false
2361 if (rhs.isNegative()) {
2362 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2363 /*width=*/1);
2364 return success();
2365 }
2366 break;
2367 case CmpIPredicate::sle:
2368 // (float)int <= 4.4 --> int <= 4
2369 // (float)int <= -4.4 --> int < -4
2370 if (rhs.isNegative())
2371 pred = CmpIPredicate::slt;
2372 break;
2373 case CmpIPredicate::ult:
2374 // (float)int < -4.4 --> false
2375 // (float)int < 4.4 --> int <= 4
2376 if (rhs.isNegative()) {
2377 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2378 /*width=*/1);
2379 return success();
2380 }
2381 pred = CmpIPredicate::ule;
2382 break;
2383 case CmpIPredicate::slt:
2384 // (float)int < -4.4 --> int < -4
2385 // (float)int < 4.4 --> int <= 4
2386 if (!rhs.isNegative())
2387 pred = CmpIPredicate::sle;
2388 break;
2389 case CmpIPredicate::ugt:
2390 // (float)int > 4.4 --> int > 4
2391 // (float)int > -4.4 --> true
2392 if (rhs.isNegative()) {
2393 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2394 /*width=*/1);
2395 return success();
2396 }
2397 break;
2398 case CmpIPredicate::sgt:
2399 // (float)int > 4.4 --> int > 4
2400 // (float)int > -4.4 --> int >= -4
2401 if (rhs.isNegative())
2402 pred = CmpIPredicate::sge;
2403 break;
2404 case CmpIPredicate::uge:
2405 // (float)int >= -4.4 --> true
2406 // (float)int >= 4.4 --> int > 4
2407 if (rhs.isNegative()) {
2408 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2409 /*width=*/1);
2410 return success();
2411 }
2412 pred = CmpIPredicate::ugt;
2413 break;
2414 case CmpIPredicate::sge:
2415 // (float)int >= -4.4 --> int >= -4
2416 // (float)int >= 4.4 --> int > 4
2417 if (!rhs.isNegative())
2418 pred = CmpIPredicate::sgt;
2419 break;
2420 }
2421 }
2422 }
2423
2424 // Lower this FP comparison into an appropriate integer version of the
2425 // comparison.
2426 rewriter.replaceOpWithNewOp<CmpIOp>(
2427 op, pred, intVal,
2428 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2429 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2430 return success();
2431 }
2432};
2433
2434void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2435 MLIRContext *context) {
2436 patterns.insert<CmpFIntToFPConst>(context);
2437}
2438
2439//===----------------------------------------------------------------------===//
2440// SelectOp
2441//===----------------------------------------------------------------------===//
2442
2443// select %arg, %c1, %c0 => extui %arg
2444struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2445 using Base::Base;
2446
2447 LogicalResult matchAndRewrite(arith::SelectOp op,
2448 PatternRewriter &rewriter) const override {
2449 // Cannot extui i1 to i1, or i1 to f32
2450 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2451 return failure();
2452
2453 // select %x, c1, %c0 => extui %arg
2454 if (matchPattern(op.getTrueValue(), m_One()) &&
2455 matchPattern(op.getFalseValue(), m_Zero())) {
2456 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2457 op.getCondition());
2458 return success();
2459 }
2460
2461 // select %x, c0, %c1 => extui (xor %arg, true)
2462 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2463 matchPattern(op.getFalseValue(), m_One())) {
2464 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2465 op, op.getType(),
2466 arith::XOrIOp::create(
2467 rewriter, op.getLoc(), op.getCondition(),
2468 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2469 op.getCondition().getType(), 1)));
2470 return success();
2471 }
2472
2473 return failure();
2474 }
2475};
2476
2477void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2478 MLIRContext *context) {
2479 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2480 SelectI1ToNot, SelectToExtUI>(context);
2481}
2482
2483OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2484 Value trueVal = getTrueValue();
2485 Value falseVal = getFalseValue();
2486 if (trueVal == falseVal)
2487 return trueVal;
2488
2489 Value condition = getCondition();
2490
2491 // select true, %0, %1 => %0
2492 if (matchPattern(adaptor.getCondition(), m_One()))
2493 return trueVal;
2494
2495 // select false, %0, %1 => %1
2496 if (matchPattern(adaptor.getCondition(), m_Zero()))
2497 return falseVal;
2498
2499 // If either operand is fully poisoned, return the other.
2500 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2501 return falseVal;
2502
2503 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2504 return trueVal;
2505
2506 // select %x, true, false => %x
2507 if (getType().isSignlessInteger(1) &&
2508 matchPattern(adaptor.getTrueValue(), m_One()) &&
2509 matchPattern(adaptor.getFalseValue(), m_Zero()))
2510 return condition;
2511
2512 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2513 auto pred = cmp.getPredicate();
2514 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2515 auto cmpLhs = cmp.getLhs();
2516 auto cmpRhs = cmp.getRhs();
2517
2518 // %0 = arith.cmpi eq, %arg0, %arg1
2519 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2520
2521 // %0 = arith.cmpi ne, %arg0, %arg1
2522 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2523
2524 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2525 (cmpRhs == trueVal && cmpLhs == falseVal))
2526 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2527 }
2528 }
2529
2530 // Constant-fold constant operands over non-splat constant condition.
2531 // select %cst_vec, %cst0, %cst1 => %cst2
2532 if (auto cond =
2533 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2534 // DenseElementsAttr by construction always has a static shape.
2535 assert(cond.getType().hasStaticShape() &&
2536 "DenseElementsAttr must have static shape");
2537 if (auto lhs =
2538 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2539 if (auto rhs =
2540 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2541 SmallVector<Attribute> results;
2542 results.reserve(static_cast<size_t>(cond.getNumElements()));
2543 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2544 cond.value_end<BoolAttr>());
2545 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2546 lhs.value_end<Attribute>());
2547 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2548 rhs.value_end<Attribute>());
2549
2550 for (auto [condVal, lhsVal, rhsVal] :
2551 llvm::zip_equal(condVals, lhsVals, rhsVals))
2552 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2553
2554 return DenseElementsAttr::get(lhs.getType(), results);
2555 }
2556 }
2557 }
2558
2559 return nullptr;
2560}
2561
2562ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2563 Type conditionType, resultType;
2564 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2565 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2566 parser.parseOptionalAttrDict(result.attributes) ||
2567 parser.parseColonType(resultType))
2568 return failure();
2569
2570 // Check for the explicit condition type if this is a masked tensor or vector.
2571 if (succeeded(parser.parseOptionalComma())) {
2572 conditionType = resultType;
2573 if (parser.parseType(resultType))
2574 return failure();
2575 } else {
2576 conditionType = parser.getBuilder().getI1Type();
2577 }
2578
2579 result.addTypes(resultType);
2580 return parser.resolveOperands(operands,
2581 {conditionType, resultType, resultType},
2582 parser.getNameLoc(), result.operands);
2583}
2584
2585void arith::SelectOp::print(OpAsmPrinter &p) {
2586 p << " " << getOperands();
2587 p.printOptionalAttrDict((*this)->getAttrs());
2588 p << " : ";
2589 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2590 p << condType << ", ";
2591 p << getType();
2592}
2593
2594LogicalResult arith::SelectOp::verify() {
2595 Type conditionType = getCondition().getType();
2596 if (conditionType.isSignlessInteger(1))
2597 return success();
2598
2599 // If the result type is a vector or tensor, the type can be a mask with the
2600 // same elements.
2601 Type resultType = getType();
2602 if (!llvm::isa<TensorType, VectorType>(resultType))
2603 return emitOpError() << "expected condition to be a signless i1, but got "
2604 << conditionType;
2605 Type shapedConditionType = getI1SameShape(resultType);
2606 if (conditionType != shapedConditionType) {
2607 return emitOpError() << "expected condition type to have the same shape "
2608 "as the result type, expected "
2609 << shapedConditionType << ", but got "
2610 << conditionType;
2611 }
2612 return success();
2613}
2614//===----------------------------------------------------------------------===//
2615// ShLIOp
2616//===----------------------------------------------------------------------===//
2617
2618OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2619 // shli(x, 0) -> x
2620 if (matchPattern(adaptor.getRhs(), m_Zero()))
2621 return getLhs();
2622 // Don't fold if shifting more or equal than the bit width.
2623 bool bounded = false;
2625 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2626 bounded = b.ult(b.getBitWidth());
2627 return a.shl(b);
2628 });
2629 return bounded ? result : Attribute();
2630}
2631
2632//===----------------------------------------------------------------------===//
2633// ShRUIOp
2634//===----------------------------------------------------------------------===//
2635
2636OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2637 // shrui(x, 0) -> x
2638 if (matchPattern(adaptor.getRhs(), m_Zero()))
2639 return getLhs();
2640 // Don't fold if shifting more or equal than the bit width.
2641 bool bounded = false;
2643 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2644 bounded = b.ult(b.getBitWidth());
2645 return a.lshr(b);
2646 });
2647 return bounded ? result : Attribute();
2648}
2649
2650//===----------------------------------------------------------------------===//
2651// ShRSIOp
2652//===----------------------------------------------------------------------===//
2653
2654OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2655 // shrsi(x, 0) -> x
2656 if (matchPattern(adaptor.getRhs(), m_Zero()))
2657 return getLhs();
2658 // Don't fold if shifting more or equal than the bit width.
2659 bool bounded = false;
2661 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2662 bounded = b.ult(b.getBitWidth());
2663 return a.ashr(b);
2664 });
2665 return bounded ? result : Attribute();
2666}
2667
2668//===----------------------------------------------------------------------===//
2669// Atomic Enum
2670//===----------------------------------------------------------------------===//
2671
2672/// Returns the identity value attribute associated with an AtomicRMWKind op.
2673TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2674 OpBuilder &builder, Location loc,
2675 bool useOnlyFiniteValue) {
2676 switch (kind) {
2677 case AtomicRMWKind::maximumf: {
2678 const llvm::fltSemantics &semantic =
2679 llvm::cast<FloatType>(resultType).getFloatSemantics();
2680 APFloat identity = useOnlyFiniteValue
2681 ? APFloat::getLargest(semantic, /*Negative=*/true)
2682 : APFloat::getInf(semantic, /*Negative=*/true);
2683 return builder.getFloatAttr(resultType, identity);
2684 }
2685 case AtomicRMWKind::maxnumf: {
2686 const llvm::fltSemantics &semantic =
2687 llvm::cast<FloatType>(resultType).getFloatSemantics();
2688 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2689 return builder.getFloatAttr(resultType, identity);
2690 }
2691 case AtomicRMWKind::addf:
2692 case AtomicRMWKind::addi:
2693 case AtomicRMWKind::maxu:
2694 case AtomicRMWKind::ori:
2695 case AtomicRMWKind::xori:
2696 return builder.getZeroAttr(resultType);
2697 case AtomicRMWKind::andi:
2698 return builder.getIntegerAttr(
2699 resultType,
2700 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2701 case AtomicRMWKind::maxs:
2702 return builder.getIntegerAttr(
2703 resultType, APInt::getSignedMinValue(
2704 llvm::cast<IntegerType>(resultType).getWidth()));
2705 case AtomicRMWKind::minimumf: {
2706 const llvm::fltSemantics &semantic =
2707 llvm::cast<FloatType>(resultType).getFloatSemantics();
2708 APFloat identity = useOnlyFiniteValue
2709 ? APFloat::getLargest(semantic, /*Negative=*/false)
2710 : APFloat::getInf(semantic, /*Negative=*/false);
2711
2712 return builder.getFloatAttr(resultType, identity);
2713 }
2714 case AtomicRMWKind::minnumf: {
2715 const llvm::fltSemantics &semantic =
2716 llvm::cast<FloatType>(resultType).getFloatSemantics();
2717 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2718 return builder.getFloatAttr(resultType, identity);
2719 }
2720 case AtomicRMWKind::mins:
2721 return builder.getIntegerAttr(
2722 resultType, APInt::getSignedMaxValue(
2723 llvm::cast<IntegerType>(resultType).getWidth()));
2724 case AtomicRMWKind::minu:
2725 return builder.getIntegerAttr(
2726 resultType,
2727 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2728 case AtomicRMWKind::muli:
2729 return builder.getIntegerAttr(resultType, 1);
2730 case AtomicRMWKind::mulf:
2731 return builder.getFloatAttr(resultType, 1);
2732 // TODO: Add remaining reduction operations.
2733 default:
2734 (void)emitOptionalError(loc, "Reduction operation type not supported");
2735 break;
2736 }
2737 return nullptr;
2738}
2739
2740/// Returns the identity numeric value of the given op.
2741std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2742 std::optional<AtomicRMWKind> maybeKind =
2744 // Floating-point operations.
2745 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2746 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2747 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2748 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2749 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2750 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2751 // Integer operations.
2752 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2753 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2754 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2755 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2756 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2757 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2758 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2759 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2760 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2761 .Default(std::nullopt);
2762 if (!maybeKind) {
2763 return std::nullopt;
2764 }
2765
2766 bool useOnlyFiniteValue = false;
2767 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2768 if (fmfOpInterface) {
2769 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2770 useOnlyFiniteValue =
2771 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2772 }
2773
2774 // Builder only used as helper for attribute creation.
2775 OpBuilder b(op->getContext());
2776 Type resultType = op->getResult(0).getType();
2777
2778 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2779 useOnlyFiniteValue);
2780}
2781
2782/// Returns the identity value associated with an AtomicRMWKind op.
2783Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2784 OpBuilder &builder, Location loc,
2785 bool useOnlyFiniteValue) {
2786 auto attr =
2787 getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2788 return arith::ConstantOp::create(builder, loc, attr);
2789}
2790
2791/// Return the value obtained by applying the reduction operation kind
2792/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2794 Location loc, Value lhs, Value rhs) {
2795 switch (op) {
2796 case AtomicRMWKind::addf:
2797 return arith::AddFOp::create(builder, loc, lhs, rhs);
2798 case AtomicRMWKind::addi:
2799 return arith::AddIOp::create(builder, loc, lhs, rhs);
2800 case AtomicRMWKind::mulf:
2801 return arith::MulFOp::create(builder, loc, lhs, rhs);
2802 case AtomicRMWKind::muli:
2803 return arith::MulIOp::create(builder, loc, lhs, rhs);
2804 case AtomicRMWKind::maximumf:
2805 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2806 case AtomicRMWKind::minimumf:
2807 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2808 case AtomicRMWKind::maxnumf:
2809 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2810 case AtomicRMWKind::minnumf:
2811 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2812 case AtomicRMWKind::maxs:
2813 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2814 case AtomicRMWKind::mins:
2815 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2816 case AtomicRMWKind::maxu:
2817 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2818 case AtomicRMWKind::minu:
2819 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2820 case AtomicRMWKind::ori:
2821 return arith::OrIOp::create(builder, loc, lhs, rhs);
2822 case AtomicRMWKind::andi:
2823 return arith::AndIOp::create(builder, loc, lhs, rhs);
2824 case AtomicRMWKind::xori:
2825 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2826 // TODO: Add remaining reduction operations.
2827 default:
2828 (void)emitOptionalError(loc, "Reduction operation type not supported");
2829 break;
2830 }
2831 return nullptr;
2832}
2833
2834//===----------------------------------------------------------------------===//
2835// TableGen'd op method definitions
2836//===----------------------------------------------------------------------===//
2837
2838#define GET_OP_CLASSES
2839#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2840
2841//===----------------------------------------------------------------------===//
2842// TableGen'd enum attribute definitions
2843//===----------------------------------------------------------------------===//
2844
2845#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:715
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:676
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:775
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:757
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:149
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:129
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:956
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:173
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:437
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:141
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:334
static bool classof(Operation *op)
Definition ArithOps.cpp:351
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:328
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:357
static bool classof(Operation *op)
Definition ArithOps.cpp:378
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:255
static bool classof(Operation *op)
Definition ArithOps.cpp:322
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:75
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:384
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.