MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37//===----------------------------------------------------------------------===//
38// Pattern helpers
39//===----------------------------------------------------------------------===//
40
/// Applies `binFn` to the APInt payloads of the operands `lhs` and `rhs` and
/// returns the result as an IntegerAttr of `res`'s type. Both operands are
/// unwrapped with llvm::cast, so they must be IntegerAttrs.
41static IntegerAttr
44 function_ref<APInt(const APInt &, const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.getType(), value);
49}
50
/// Adds two IntegerAttr constants (APInt `+`, which wraps) via
/// applyToIntegerAttrs; the result has `res`'s type.
51static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54}
55
/// Subtracts two IntegerAttr constants (APInt `-`, which wraps) via
/// applyToIntegerAttrs; the result has `res`'s type.
56static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59}
60
/// Multiplies two IntegerAttr constants (APInt `*`, which wraps) via
/// applyToIntegerAttrs; the result has `res`'s type.
61static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64}
65
66// Merge overflow flags from 2 ops, selecting the most conservative combination.
67static IntegerOverflowFlagsAttr
68mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
72}
73
74/// Invert an integer comparison predicate.
75arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76 switch (pred) {
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
97 }
98 llvm_unreachable("unknown cmpi predicate kind");
99}
100
101/// Equivalent to
102/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103///
104/// Not possible to implement as chain of calls as this would introduce a
105/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106/// on the LLVM dialect and on translation to LLVM.
107static llvm::RoundingMode
108convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
120 }
121 llvm_unreachable("Unhandled rounding mode");
122}
123
124static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
126 invertPredicate(pred.getValue()));
127}
128
// Returns the bit width of `type` (or of its element type, via
// getElementTypeOrSelf) when that is an integer or float type; -1 otherwise.
130 Type elemTy = getElementTypeOrSelf(type);
131 if (elemTy.isIntOrFloat())
132 return elemTy.getIntOrFloatBitWidth();
133
134 return -1;
135}
136
// Value overload: forwards to the Type overload with `value`'s type.
138 return getScalarOrElementWidth(value.getType());
139}
140
141static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142 APInt value;
143 if (matchPattern(attr, m_ConstantInt(&value)))
144 return value;
145
146 return failure();
147}
148
149static Attribute getBoolAttribute(Type type, bool value) {
150 auto boolAttr = BoolAttr::get(type.getContext(), value);
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152 if (!shapedType)
153 return boolAttr;
154 // DenseElementsAttr requires a static shape.
155 if (!shapedType.hasStaticShape())
156 return {};
157 return DenseElementsAttr::get(shapedType, boolAttr);
158}
159
160//===----------------------------------------------------------------------===//
161// TableGen'd canonicalization patterns
162//===----------------------------------------------------------------------===//
163
164namespace {
165#include "ArithCanonicalization.inc"
166} // namespace
167
168//===----------------------------------------------------------------------===//
169// Common helpers
170//===----------------------------------------------------------------------===//
171
172/// Return a type with the same shape as `type` (scalar, vector, or ranked or
/// unranked tensor) whose element type is i1.
/// NOTE(review): UnrankedTensorType also implements the ShapedType interface,
/// so it is presumably already handled by the `dyn_cast<ShapedType>` branch and
/// the explicit UnrankedTensorType check below may be unreachable - confirm.
174 auto i1Type = IntegerType::get(type.getContext(), 1);
175 if (auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
179 return i1Type;
180}
181
182//===----------------------------------------------------------------------===//
183// ConstantOp
184//===----------------------------------------------------------------------===//
185
186void arith::ConstantOp::getAsmResultNames(
187 function_ref<void(Value, StringRef)> setNameFn) {
188 auto type = getType();
189 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
191
192 // Sugar i1 constants with 'true' and 'false'.
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
195
196 // Otherwise, build a complex name with the value and type.
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName << 'c' << intCst.getValue();
200 if (intType)
201 specialName << '_' << type;
202 setNameFn(getResult(), specialName.str());
203 } else {
204 setNameFn(getResult(), "cst");
205 }
206}
207
208/// TODO: disallow arith.constant to return anything other than signless integer
209/// or float like.
210LogicalResult arith::ConstantOp::verify() {
211 auto type = getType();
212 // Integer values must be signless.
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError("integer return type must be signless");
216 // Any float or elements attribute are acceptable.
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
218 return emitOpError(
219 "value must be an integer, float, or elements attribute");
220 }
221
222 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
223 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
224 // However, this would most likely require updating the lowerings to LLVM.
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
226 return emitOpError(
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
229 return success();
230}
231
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
233 // The value's type must be the same as the provided type.
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
236 return false;
237 // Integer values must be signless.
238 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
239 if (!intType.isSignless())
240 return false;
241 }
242 // Integer, float, and element attributes are buildable.
243 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
244}
245
246ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
247 Type type, Location loc) {
248 if (isBuildableWith(value, type))
249 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
250 return nullptr;
251}
252
253OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
254
256 int64_t value, unsigned width) {
 // Build an arith.constant whose type is a `width`-bit integer type (from
 // builder.getIntegerType) and whose attribute holds `value`.
257 auto type = builder.getIntegerType(width);
258 arith::ConstantOp::build(builder, result, type,
259 builder.getIntegerAttr(type, value));
260}
261
263 Location location,
265 unsigned width) {
 // Populate an OperationState through the matching build() overload, create
 // the op generically, and check that the result is a ConstantIntOp.
266 mlir::OperationState state(location, getOperationName());
267 build(builder, state, value, width);
268 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
269 assert(result && "builder didn't return the right type");
270 return result;
271}
272
275 unsigned width) {
 // Convenience overload: uses the location provided by `builder.getLoc()`.
276 return create(builder, builder.getLoc(), value, width);
277}
278
280 Type type, int64_t value) {
 // Build an arith.constant of type `type` holding the integer `value`.
281 arith::ConstantOp::build(builder, result, type,
282 builder.getIntegerAttr(type, value));
283}
284
286 Location location, Type type,
287 int64_t value) {
 // Populate an OperationState through the matching build() overload, create
 // the op generically, and check that the result is a ConstantIntOp.
288 mlir::OperationState state(location, getOperationName());
289 build(builder, state, type, value);
290 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
291 assert(result && "builder didn't return the right type");
292 return result;
293}
294
296 Type type, int64_t value) {
 // Convenience overload: uses the location provided by `builder.getLoc()`.
297 return create(builder, builder.getLoc(), type, value);
298}
299
301 Type type, const APInt &value) {
 // Build an arith.constant of type `type` holding the APInt `value`.
302 arith::ConstantOp::build(builder, result, type,
303 builder.getIntegerAttr(type, value));
304}
305
307 Location location, Type type,
308 const APInt &value) {
 // Populate an OperationState through the matching build() overload, create
 // the op generically, and check that the result is a ConstantIntOp.
309 mlir::OperationState state(location, getOperationName());
310 build(builder, state, type, value);
311 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
312 assert(result && "builder didn't return the right type");
313 return result;
314}
315
317 Type type,
318 const APInt &value) {
 // Convenience overload: uses the location provided by `builder.getLoc()`.
319 return create(builder, builder.getLoc(), type, value);
320}
321
// Matches arith.constant ops whose result type is a signless integer; a null
// `op` is rejected through dyn_cast_or_null.
323 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
324 return constOp.getType().isSignlessInteger();
325 return false;
326}
327
329 FloatType type, const APFloat &value) {
 // Build an arith.constant of float type `type` holding `value`.
330 arith::ConstantOp::build(builder, result, type,
331 builder.getFloatAttr(type, value));
332}
333
335 Location location,
336 FloatType type,
337 const APFloat &value) {
 // Populate an OperationState through the matching build() overload, create
 // the op generically, and check that the result is a ConstantFloatOp.
338 mlir::OperationState state(location, getOperationName());
339 build(builder, state, type, value);
340 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
341 assert(result && "builder didn't return the right type");
342 return result;
343}
344
347 const APFloat &value) {
 // Convenience overload: uses the location provided by `builder.getLoc()`.
348 return create(builder, builder.getLoc(), type, value);
349}
350
// Matches arith.constant ops whose result type is a FloatType; a null `op` is
// rejected through dyn_cast_or_null.
352 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
353 return llvm::isa<FloatType>(constOp.getType());
354 return false;
355}
356
358 int64_t value) {
 // Build an arith.constant of index type holding `value`.
359 arith::ConstantOp::build(builder, result, builder.getIndexType(),
360 builder.getIndexAttr(value));
361}
362
364 Location location,
365 int64_t value) {
 // Populate an OperationState through the matching build() overload, create
 // the op generically, and check that the result is a ConstantIndexOp.
366 mlir::OperationState state(location, getOperationName());
367 build(builder, state, value);
368 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
369 assert(result && "builder didn't return the right type");
370 return result;
371}
372
// Convenience overload: uses the location provided by `builder.getLoc()`.
375 return create(builder, builder.getLoc(), value);
376}
377
// Matches arith.constant ops whose result type is index; a null `op` is
// rejected through dyn_cast_or_null.
379 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
380 return constOp.getType().isIndex();
381 return false;
382}
383
385 Type type) {
 // Create an arith.constant holding the zero value of `type`. Asserts for
 // f8E8M0FNU element types (no zero representation) and for types that
 // builder.getZeroAttr cannot handle.
386 // TODO: Incorporate this check to `FloatAttr::get*`.
387 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
388 "type doesn't have a zero representation");
389 TypedAttr zeroAttr = builder.getZeroAttr(type);
390 assert(zeroAttr && "unsupported type for zero attribute");
391 return arith::ConstantOp::create(builder, loc, zeroAttr);
392}
393
394//===----------------------------------------------------------------------===//
395// AddIOp
396//===----------------------------------------------------------------------===//
397
/// Folds `arith.addi`:
///   * addi(x, 0) -> x
///   * addi(subi(a, b), b) -> a and addi(b, subi(a, b)) -> a
///   * otherwise element-wise constant folding with (wrapping) APInt addition.
398OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
399 // addi(x, 0) -> x
400 if (matchPattern(adaptor.getRhs(), m_Zero()))
401 return getLhs();
402
403 // addi(subi(a, b), b) -> a
404 if (auto sub = getLhs().getDefiningOp<SubIOp>())
405 if (getRhs() == sub.getRhs())
406 return sub.getLhs();
407
408 // addi(b, subi(a, b)) -> a
409 if (auto sub = getRhs().getDefiningOp<SubIOp>())
410 if (getLhs() == sub.getRhs())
411 return sub.getLhs();
412
414 adaptor.getOperands(),
415 [](APInt a, const APInt &b) { return std::move(a) + b; });
416}
417
418void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
419 MLIRContext *context) {
420 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
421 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
422}
423
424//===----------------------------------------------------------------------===//
425// AddUIExtendedOp
426//===----------------------------------------------------------------------===//
427
428std::optional<SmallVector<int64_t, 4>>
429arith::AddUIExtendedOp::getShapeForUnroll() {
430 if (auto vt = dyn_cast<VectorType>(getType(0)))
431 return llvm::to_vector<4>(vt.getShape());
432 return std::nullopt;
433}
434
435// Returns the overflow bit, assuming that `sum` is the result of unsigned
436// addition of `operand` and another number.
437static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
438 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
439}
440
/// Folds `arith.addui_extended` into (sum, overflow):
///   * addui_extended(x, 0) -> x, false
///   * constant operands -> constant sum and constant carry bit, where the
///     carry is derived from the folded sum and the lhs constant;
///   * poison operands -> poison for both results.
441LogicalResult
442arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
443 SmallVectorImpl<OpFoldResult> &results) {
444 Type overflowTy = getOverflow().getType();
445 // addui_extended(x, 0) -> x, false
446 if (matchPattern(getRhs(), m_Zero())) {
447 Builder builder(getContext());
448 auto falseValue = builder.getZeroAttr(overflowTy);
449
450 results.push_back(getLhs());
451 results.push_back(falseValue);
452 return success();
453 }
454
455 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
456 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
457 // operands. If that succeeds, calculate the overflow bit based on the sum
458 // and the first (constant) operand, `lhs`.
459 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
460 adaptor.getOperands(),
461 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
462 // If any operand is poison, propagate poison to both results.
463 if (matchPattern(sumAttr, ub::m_Poison())) {
464 results.push_back(sumAttr);
465 results.push_back(sumAttr);
466 return success();
467 }
468 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
469 ArrayRef({sumAttr, adaptor.getLhs()}),
470 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
472 if (!overflowAttr)
473 return failure();
474
475 results.push_back(sumAttr);
476 results.push_back(overflowAttr);
477 return success();
478 }
479
480 return failure();
481}
482
483void arith::AddUIExtendedOp::getCanonicalizationPatterns(
484 RewritePatternSet &patterns, MLIRContext *context) {
485 patterns.add<AddUIExtendedToAddI>(context);
486}
487
488//===----------------------------------------------------------------------===//
489// SubIOp
490//===----------------------------------------------------------------------===//
491
/// Folds `arith.subi`:
///   * subi(x, x) -> 0 (only when a zero constant can be built, i.e. the type
///     is not a dynamically shaped ShapedType)
///   * subi(x, 0) -> x
///   * subi(addi(a, b), b) -> a and subi(addi(a, b), a) -> b
///   * otherwise element-wise constant folding with (wrapping) APInt
///     subtraction.
492OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
493 // subi(x,x) -> 0
494 if (getOperand(0) == getOperand(1)) {
495 auto shapedType = dyn_cast<ShapedType>(getType());
496 // We can't generate a constant with a dynamic shaped tensor.
497 if (!shapedType || shapedType.hasStaticShape())
498 return Builder(getContext()).getZeroAttr(getType());
499 }
500 // subi(x,0) -> x
501 if (matchPattern(adaptor.getRhs(), m_Zero()))
502 return getLhs();
503
504 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
505 // subi(addi(a, b), b) -> a
506 if (getRhs() == add.getRhs())
507 return add.getLhs();
508 // subi(addi(a, b), a) -> b
509 if (getRhs() == add.getLhs())
510 return add.getRhs();
511 }
512
514 adaptor.getOperands(),
515 [](APInt a, const APInt &b) { return std::move(a) - b; });
516}
517
518void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
519 MLIRContext *context) {
520 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
521 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
522 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
523}
524
525//===----------------------------------------------------------------------===//
526// MulIOp
527//===----------------------------------------------------------------------===//
528
/// Folds `arith.muli`:
///   * muli(x, 0) -> 0 (returns the zero rhs value itself)
///   * muli(x, 1) -> x
///   * otherwise element-wise constant folding with (wrapping) APInt
///     multiplication.
529OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
530 // muli(x, 0) -> 0
531 if (matchPattern(adaptor.getRhs(), m_Zero()))
532 return getRhs();
533 // muli(x, 1) -> x
534 if (matchPattern(adaptor.getRhs(), m_One()))
535 return getLhs();
536 // TODO: Handle the overflow case.
537
538 // default folder
540 adaptor.getOperands(),
541 [](const APInt &a, const APInt &b) { return a * b; });
542}
543
544void arith::MulIOp::getAsmResultNames(
545 function_ref<void(Value, StringRef)> setNameFn) {
546 if (!isa<IndexType>(getType()))
547 return;
548
549 // Match vector.vscale by name to avoid depending on the vector dialect (which
550 // is a circular dependency).
551 auto isVscale = [](Operation *op) {
552 return op && op->getName().getStringRef() == "vector.vscale";
553 };
554
555 IntegerAttr baseValue;
556 auto isVscaleExpr = [&](Value a, Value b) {
557 return matchPattern(a, m_Constant(&baseValue)) &&
558 isVscale(b.getDefiningOp());
559 };
560
561 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
562 return;
563
564 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
565 SmallString<32> specialNameBuffer;
566 llvm::raw_svector_ostream specialName(specialNameBuffer);
567 specialName << 'c' << baseValue.getInt() << "_vscale";
568 setNameFn(getResult(), specialName.str());
569}
570
571void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
572 MLIRContext *context) {
573 patterns.add<MulIMulIConstant>(context);
574}
575
576//===----------------------------------------------------------------------===//
577// MulSIExtendedOp
578//===----------------------------------------------------------------------===//
579
580std::optional<SmallVector<int64_t, 4>>
581arith::MulSIExtendedOp::getShapeForUnroll() {
582 if (auto vt = dyn_cast<VectorType>(getType(0)))
583 return llvm::to_vector<4>(vt.getShape());
584 return std::nullopt;
585}
586
587LogicalResult
588arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
589 SmallVectorImpl<OpFoldResult> &results) {
590 // mulsi_extended(x, 0) -> 0, 0
591 if (matchPattern(adaptor.getRhs(), m_Zero())) {
592 Attribute zero = adaptor.getRhs();
593 results.push_back(zero);
594 results.push_back(zero);
595 return success();
596 }
597
598 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
599 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
600 adaptor.getOperands(),
601 [](const APInt &a, const APInt &b) { return a * b; })) {
602 // Invoke the constant fold helper again to calculate the 'high' result.
603 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
604 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
605 return llvm::APIntOps::mulhs(a, b);
606 });
607 assert(highAttr && "Unexpected constant-folding failure");
608
609 results.push_back(lowAttr);
610 results.push_back(highAttr);
611 return success();
612 }
613
614 return failure();
615}
616
617void arith::MulSIExtendedOp::getCanonicalizationPatterns(
618 RewritePatternSet &patterns, MLIRContext *context) {
619 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
620}
621
622//===----------------------------------------------------------------------===//
623// MulUIExtendedOp
624//===----------------------------------------------------------------------===//
625
626std::optional<SmallVector<int64_t, 4>>
627arith::MulUIExtendedOp::getShapeForUnroll() {
628 if (auto vt = dyn_cast<VectorType>(getType(0)))
629 return llvm::to_vector<4>(vt.getShape());
630 return std::nullopt;
631}
632
633LogicalResult
634arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
635 SmallVectorImpl<OpFoldResult> &results) {
636 // mului_extended(x, 0) -> 0, 0
637 if (matchPattern(adaptor.getRhs(), m_Zero())) {
638 Attribute zero = adaptor.getRhs();
639 results.push_back(zero);
640 results.push_back(zero);
641 return success();
642 }
643
644 // mului_extended(x, 1) -> x, 0
645 if (matchPattern(adaptor.getRhs(), m_One())) {
646 Builder builder(getContext());
647 Attribute zero = builder.getZeroAttr(getLhs().getType());
648 results.push_back(getLhs());
649 results.push_back(zero);
650 return success();
651 }
652
653 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
654 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
655 adaptor.getOperands(),
656 [](const APInt &a, const APInt &b) { return a * b; })) {
657 // Invoke the constant fold helper again to calculate the 'high' result.
658 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
659 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
660 return llvm::APIntOps::mulhu(a, b);
661 });
662 assert(highAttr && "Unexpected constant-folding failure");
663
664 results.push_back(lowAttr);
665 results.push_back(highAttr);
666 return success();
667 }
668
669 return failure();
670}
671
672void arith::MulUIExtendedOp::getCanonicalizationPatterns(
673 RewritePatternSet &patterns, MLIRContext *context) {
674 patterns.add<MulUIExtendedToMulI>(context);
675}
676
677//===----------------------------------------------------------------------===//
678// DivUIOp
679//===----------------------------------------------------------------------===//
680
681/// Fold `(a * b) / b -> a` and `(b * a) / b -> a`, where `lhs` is the dividend
/// and `rhs` the divisor. Only applies when `lhs` is produced by an arith.muli
/// carrying all of `ovfFlags` (the callers here pass nuw for unsigned and nsw
/// for signed division); returns a null Value otherwise.
683 arith::IntegerOverflowFlags ovfFlags) {
684 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
685 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
686 return {};
687
688 if (mul.getLhs() == rhs)
689 return mul.getRhs();
690
691 if (mul.getRhs() == rhs)
692 return mul.getLhs();
693
694 return {};
695}
696
697OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
698 // divui (x, 1) -> x.
699 if (matchPattern(adaptor.getRhs(), m_One()))
700 return getLhs();
701
702 // (a * b) / b -> a
703 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
704 return val;
705
706 // Don't fold if it would require a division by zero.
707 bool div0 = false;
708 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
709 [&](APInt a, const APInt &b) {
710 if (div0 || !b) {
711 div0 = true;
712 return a;
713 }
714 return a.udiv(b);
715 });
716
717 return div0 ? Attribute() : result;
718}
719
720/// Returns whether an unsigned division by `divisor` is speculatable.
/// Division by zero is UB, so the check below asks whether the unsigned range
/// of `divisor` provably excludes zero (`m_IntRangeWithoutZeroU`).
722 // X / 0 => UB
723 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
725
727}
728
729Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
730 return getDivUISpeculatability(getRhs());
731}
732
733//===----------------------------------------------------------------------===//
734// DivSIOp
735//===----------------------------------------------------------------------===//
736
/// Folds `arith.divsi`:
///   * divsi(x, 1) -> x
///   * divsi(muli(a, b) nsw, b) -> a (either multiplication operand order)
///   * element-wise constant folding with signed division, abandoned on
///     division by zero or overflow (INT_MIN / -1).
737OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
738 // divsi (x, 1) -> x.
739 if (matchPattern(adaptor.getRhs(), m_One()))
740 return getLhs();
741
742 // (a * b) / b -> a
743 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
744 return val;
745
746 // Don't fold if it would overflow or if it requires a division by zero.
747 bool overflowOrDiv0 = false;
749 adaptor.getOperands(), [&](APInt a, const APInt &b) {
 // Stop once any element has failed: sdiv_ov assigns (rather than ORs)
 // its overflow flag, so calling it again could clear the failure.
750 if (overflowOrDiv0 || !b) {
751 overflowOrDiv0 = true;
752 return a;
753 }
754 return a.sdiv_ov(b, overflowOrDiv0);
755 });
756
757 return overflowOrDiv0 ? Attribute() : result;
758}
759
760/// Returns whether a signed division by `divisor` is speculatable. This
761/// function conservatively treats every signed division by -1 as not
762/// speculatable, since INT_MIN / -1 overflows.
764 // X / 0 => UB
765 // INT_MIN / -1 => UB
766 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
769
771}
772
773Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
774 return getDivSISpeculatability(getRhs());
775}
776
777//===----------------------------------------------------------------------===//
778// Ceil and floor division folding helpers
779//===----------------------------------------------------------------------===//
780
781static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
782 bool &overflow) {
783 // Returns (a-1)/b + 1
784 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
785 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
786 return val.sadd_ov(one, overflow);
787}
788
789//===----------------------------------------------------------------------===//
790// CeilDivUIOp
791//===----------------------------------------------------------------------===//
792
/// Folds `arith.ceildivui`:
///   * ceildivui(x, 1) -> x
///   * element-wise constant folding: the unsigned quotient, plus one when the
///     remainder is non-zero; abandoned on division by zero or if that +1
///     overflows.
793OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
794 // ceildivui (x, 1) -> x.
795 if (matchPattern(adaptor.getRhs(), m_One()))
796 return getLhs();
797
798 bool overflowOrDiv0 = false;
800 adaptor.getOperands(), [&](APInt a, const APInt &b) {
801 if (overflowOrDiv0 || !b) {
802 overflowOrDiv0 = true;
803 return a;
804 }
805 APInt quotient = a.udiv(b);
806 if (!a.urem(b))
807 return quotient;
808 APInt one(a.getBitWidth(), 1, true);
809 return quotient.uadd_ov(one, overflowOrDiv0);
810 });
811
812 return overflowOrDiv0 ? Attribute() : result;
813}
814
815Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
816 return getDivUISpeculatability(getRhs());
817}
818
819//===----------------------------------------------------------------------===//
820// CeilDivSIOp
821//===----------------------------------------------------------------------===//
822
/// Folds `arith.ceildivsi`:
///   * ceildivsi(x, 1) -> x
///   * element-wise constant folding of ceil(a / b), computed by cases on the
///     signs of a and b. Folding is abandoned on division by zero or if any
///     intermediate negation/division overflows.
823OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
824 // ceildivsi (x, 1) -> x.
825 if (matchPattern(adaptor.getRhs(), m_One()))
826 return getLhs();
827
828 // Don't fold if it would overflow or if it requires a division by zero.
829 // TODO: This hook won't fold operations where a = MININT, because
830 // negating MININT overflows. This can be improved.
831 bool overflowOrDiv0 = false;
833 adaptor.getOperands(), [&](APInt a, const APInt &b) {
834 if (overflowOrDiv0 || !b) {
835 overflowOrDiv0 = true;
836 return a;
837 }
838 if (!a)
839 return a;
840 // After this point we know that neither a nor b is zero.
841 unsigned bits = a.getBitWidth();
842 APInt zero = APInt::getZero(bits);
843 bool aGtZero = a.sgt(zero);
844 bool bGtZero = b.sgt(zero);
845 if (aGtZero && bGtZero) {
846 // Both positive, return ceil(a, b).
847 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
848 }
849
850 // No folding happens if any of the intermediate arithmetic operations
851 // overflows.
852 bool overflowNegA = false;
853 bool overflowNegB = false;
854 bool overflowDiv = false;
855 bool overflowNegRes = false;
856 if (!aGtZero && !bGtZero) {
857 // Both negative, return ceil(-a, -b).
858 APInt posA = zero.ssub_ov(a, overflowNegA);
859 APInt posB = zero.ssub_ov(b, overflowNegB);
860 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
861 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
862 return res;
863 }
864 if (!aGtZero && bGtZero) {
865 // a is negative, b is positive: return -(-a / b).
866 APInt posA = zero.ssub_ov(a, overflowNegA);
867 APInt div = posA.sdiv_ov(b, overflowDiv);
868 APInt res = zero.ssub_ov(div, overflowNegRes);
869 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
870 return res;
871 }
872 // a is positive, b is negative: return -(a / -b).
873 APInt posB = zero.ssub_ov(b, overflowNegB);
874 APInt div = a.sdiv_ov(posB, overflowDiv);
875 APInt res = zero.ssub_ov(div, overflowNegRes);
876
877 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
878 return res;
879 });
880
881 return overflowOrDiv0 ? Attribute() : result;
882}
883
884Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
885 return getDivSISpeculatability(getRhs());
886}
887
888//===----------------------------------------------------------------------===//
889// FloorDivSIOp
890//===----------------------------------------------------------------------===//
891
892OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
893 // floordivsi (x, 1) -> x.
894 if (matchPattern(adaptor.getRhs(), m_One()))
895 return getLhs();
896
897 // Don't fold if it would overflow or if it requires a division by zero.
898 bool overflowOrDiv = false;
900 adaptor.getOperands(), [&](APInt a, const APInt &b) {
901 if (b.isZero()) {
902 overflowOrDiv = true;
903 return a;
904 }
905 return a.sfloordiv_ov(b, overflowOrDiv);
906 });
907
908 return overflowOrDiv ? Attribute() : result;
909}
910
911//===----------------------------------------------------------------------===//
912// RemUIOp
913//===----------------------------------------------------------------------===//
914
915OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
916 // remui (x, 1) -> 0.
917 if (matchPattern(adaptor.getRhs(), m_One()))
918 return Builder(getContext()).getZeroAttr(getType());
919
920 // Don't fold if it would require a division by zero.
921 bool div0 = false;
922 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
923 [&](APInt a, const APInt &b) {
924 if (div0 || b.isZero()) {
925 div0 = true;
926 return a;
927 }
928 return a.urem(b);
929 });
930
931 return div0 ? Attribute() : result;
932}
933
934//===----------------------------------------------------------------------===//
935// RemSIOp
936//===----------------------------------------------------------------------===//
937
938OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
939 // remsi (x, 1) -> 0.
940 if (matchPattern(adaptor.getRhs(), m_One()))
941 return Builder(getContext()).getZeroAttr(getType());
942
943 // Don't fold if it would require a division by zero.
944 bool div0 = false;
945 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
946 [&](APInt a, const APInt &b) {
947 if (div0 || b.isZero()) {
948 div0 = true;
949 return a;
950 }
951 return a.srem(b);
952 });
953
954 return div0 ? Attribute() : result;
955}
956
957//===----------------------------------------------------------------------===//
958// AndIOp
959//===----------------------------------------------------------------------===//
960
961/// Fold `and(a, and(a, b))` to `and(a, b)`
962static Value foldAndIofAndI(arith::AndIOp op) {
963 for (bool reversePrev : {false, true}) {
964 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
965 .getDefiningOp<arith::AndIOp>();
966 if (!prev)
967 continue;
968
969 Value other = (reversePrev ? op.getLhs() : op.getRhs());
970 if (other != prev.getLhs() && other != prev.getRhs())
971 continue;
972
973 return prev.getResult();
974 }
975 return {};
976}
977
/// Folds `arith.andi`:
///   * and(x, 0) -> 0 and and(x, allOnes) -> x
///   * and(x, not(x)) / and(not(x), x) -> 0, where not(x) is xor(x, allOnes)
///   * and(a, and(a, b)) -> and(a, b) (via foldAndIofAndI)
///   * otherwise element-wise constant folding with bitwise AND.
978OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
979 /// and(x, 0) -> 0
980 if (matchPattern(adaptor.getRhs(), m_Zero()))
981 return getRhs();
982 /// and(x, allOnes) -> x
983 APInt intValue;
984 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
985 intValue.isAllOnes())
986 return getLhs();
987 /// and(x, not(x)) -> 0
988 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
989 m_ConstantInt(&intValue))) &&
990 intValue.isAllOnes())
991 return Builder(getContext()).getZeroAttr(getType());
992 /// and(not(x), x) -> 0
993 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
994 m_ConstantInt(&intValue))) &&
995 intValue.isAllOnes())
996 return Builder(getContext()).getZeroAttr(getType());
997
998 /// and(a, and(a, b)) -> and(a, b)
999 if (Value result = foldAndIofAndI(*this))
1000 return result;
1001
1003 adaptor.getOperands(),
1004 [](APInt a, const APInt &b) { return std::move(a) & b; });
1005}
1006
1007//===----------------------------------------------------------------------===//
1008// OrIOp
1009//===----------------------------------------------------------------------===//
1010
/// Folds `arith.ori`:
///  - or(x, 0) -> x
///  - or(x, allOnes) -> the all-ones constant attribute
///  - or(x, xor(x, allOnes)) and or(xor(x, allOnes), x) -> the all-ones value
///    (the xor's constant right-hand operand is reused)
///  - otherwise, bitwise OR of two constant operands.
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
    /// or(x, 0) -> x
    if (rhsVal.isZero())
      return getLhs();
    /// or(x, <all ones>) -> <all ones>
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, xor(x, <all ones>)) -> <all ones>
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, <all ones>), x) -> <all ones>
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getLhs().getDefiningOp<XOrIOp>().getRhs();

  // Fallback: element-wise OR of constant operands.
  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<IntegerAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}
1037
1038//===----------------------------------------------------------------------===//
1039// XOrIOp
1040//===----------------------------------------------------------------------===//
1041
/// Folds `arith.xori`:
///  - xor(x, 0) -> x
///  - xor(x, x) -> 0
///  - xor(xor(x, a), a), xor(xor(a, x), a), xor(a, xor(x, a)) and
///    xor(a, xor(a, x)) -> x, because the repeated operand cancels out
///  - otherwise, bitwise XOR of two constant operands.
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(getType());
  /// xor(xor(x, a), a) -> x
  /// xor(xor(a, x), a) -> x
  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();
  }
  /// xor(a, xor(x, a)) -> x
  /// xor(a, xor(a, x)) -> x
  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();
  }

  // Fallback: element-wise XOR of constant operands.
  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<IntegerAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}
1070
1071void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1072 MLIRContext *context) {
1073 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1074}
1075
1076//===----------------------------------------------------------------------===//
1077// NegFOp
1078//===----------------------------------------------------------------------===//
1079
1080OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1081 /// negf(negf(x)) -> x
1082 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1083 return op.getOperand();
1084 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1085 [](const APFloat &a) { return -a; });
1086}
1087
1088//===----------------------------------------------------------------------===//
1089// AddFOp
1090//===----------------------------------------------------------------------===//
1091
/// Folds `arith.addf`: adding negative zero is the identity, and two constant
/// operands fold to their APFloat sum.
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x
  // (Only -0 qualifies: with +0, x = -0 would produce +0, not x.)
  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
    return getLhs();

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });
}
1101
1102//===----------------------------------------------------------------------===//
1103// SubFOp
1104//===----------------------------------------------------------------------===//
1105
/// Folds `arith.subf`: subtracting positive zero is the identity, and two
/// constant operands fold to their APFloat difference.
OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x
  // (Only +0 qualifies: x - (-0) with x = -0 would produce +0, not x.)
  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
    return getLhs();

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}
1115
1116//===----------------------------------------------------------------------===//
1117// MaximumFOp
1118//===----------------------------------------------------------------------===//
1119
/// Folds `arith.maximumf`: identical operands and a -inf right-hand side fold
/// to the other operand; two constants fold through llvm::maximum.
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
1133
1134//===----------------------------------------------------------------------===//
1135// MaxNumFOp
1136//===----------------------------------------------------------------------===//
1137
1138OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1139 // maxnumf(x,x) -> x
1140 if (getLhs() == getRhs())
1141 return getRhs();
1142
1143 // maxnumf(x, NaN) -> x
1144 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1145 return getLhs();
1146
1147 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1148}
1149
1150//===----------------------------------------------------------------------===//
1151// MaxSIOp
1152//===----------------------------------------------------------------------===//
1153
1154OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1155 // maxsi(x,x) -> x
1156 if (getLhs() == getRhs())
1157 return getRhs();
1158
1159 if (APInt intValue;
1160 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1161 // maxsi(x,MAX_INT) -> MAX_INT
1162 if (intValue.isMaxSignedValue())
1163 return getRhs();
1164 // maxsi(x, MIN_INT) -> x
1165 if (intValue.isMinSignedValue())
1166 return getLhs();
1167 }
1168
1169 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1170 [](const APInt &a, const APInt &b) {
1171 return llvm::APIntOps::smax(a, b);
1172 });
1173}
1174
1175//===----------------------------------------------------------------------===//
1176// MaxUIOp
1177//===----------------------------------------------------------------------===//
1178
1179OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1180 // maxui(x,x) -> x
1181 if (getLhs() == getRhs())
1182 return getRhs();
1183
1184 if (APInt intValue;
1185 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1186 // maxui(x,MAX_INT) -> MAX_INT
1187 if (intValue.isMaxValue())
1188 return getRhs();
1189 // maxui(x, MIN_INT) -> x
1190 if (intValue.isMinValue())
1191 return getLhs();
1192 }
1193
1194 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1195 [](const APInt &a, const APInt &b) {
1196 return llvm::APIntOps::umax(a, b);
1197 });
1198}
1199
1200//===----------------------------------------------------------------------===//
1201// MinimumFOp
1202//===----------------------------------------------------------------------===//
1203
/// Folds `arith.minimumf`: identical operands and a +inf right-hand side fold
/// to the other operand; two constants fold through llvm::minimum.
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  // minimumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minimumf(x, +inf) -> x
  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    return getLhs();

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}
1217
1218//===----------------------------------------------------------------------===//
1219// MinNumFOp
1220//===----------------------------------------------------------------------===//
1221
/// Folds `arith.minnumf`: identical operands fold to that operand, a NaN
/// right-hand side folds to the left-hand side, and two constants fold through
/// llvm::minnum.
OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  // minnumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minnumf(x, NaN) -> x
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
}
1235
1236//===----------------------------------------------------------------------===//
1237// MinSIOp
1238//===----------------------------------------------------------------------===//
1239
1240OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1241 // minsi(x,x) -> x
1242 if (getLhs() == getRhs())
1243 return getRhs();
1244
1245 if (APInt intValue;
1246 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1247 // minsi(x,MIN_INT) -> MIN_INT
1248 if (intValue.isMinSignedValue())
1249 return getRhs();
1250 // minsi(x, MAX_INT) -> x
1251 if (intValue.isMaxSignedValue())
1252 return getLhs();
1253 }
1254
1255 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1256 [](const APInt &a, const APInt &b) {
1257 return llvm::APIntOps::smin(a, b);
1258 });
1259}
1260
1261//===----------------------------------------------------------------------===//
1262// MinUIOp
1263//===----------------------------------------------------------------------===//
1264
1265OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1266 // minui(x,x) -> x
1267 if (getLhs() == getRhs())
1268 return getRhs();
1269
1270 if (APInt intValue;
1271 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1272 // minui(x,MIN_INT) -> MIN_INT
1273 if (intValue.isMinValue())
1274 return getRhs();
1275 // minui(x, MAX_INT) -> x
1276 if (intValue.isMaxValue())
1277 return getLhs();
1278 }
1279
1280 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1281 [](const APInt &a, const APInt &b) {
1282 return llvm::APIntOps::umin(a, b);
1283 });
1284}
1285
1286//===----------------------------------------------------------------------===//
1287// MulFOp
1288//===----------------------------------------------------------------------===//
1289
/// Folds `arith.mulf`:
///  - mulf(x, 1) -> x
///  - mulf(x, +/-0) -> the zero operand, only when the op carries both the
///    `nnan` and `nsz` fast-math flags
///  - otherwise, APFloat product of two constant operands.
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  // Both flags are needed: without nnan, a NaN or infinite x times 0 is NaN;
  // without nsz, the sign of the zero result depends on the sign of x.
  if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
                                                   arith::FastMathFlags::nsz)) {
    // mulf(x, 0) -> 0
    if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
      return getRhs();
  }

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });
}
1306
1307void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1308 MLIRContext *context) {
1309 patterns.add<MulFOfNegF>(context);
1310}
1311
1312//===----------------------------------------------------------------------===//
1313// DivFOp
1314//===----------------------------------------------------------------------===//
1315
/// Folds `arith.divf`: division by one is the identity, and two constant
/// operands fold to their APFloat quotient.
OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  // NOTE(review): the line opening this call is missing here; presumably
  // `return constFoldBinaryOp<FloatAttr>(` — confirm against the source.
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });
}
1325
1326void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1327 MLIRContext *context) {
1328 patterns.add<DivFOfNegF>(context);
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// RemFOp
1333//===----------------------------------------------------------------------===//
1334
1335OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1336 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1337 [](const APFloat &a, const APFloat &b) {
1338 APFloat result(a);
1339 // APFloat::mod() offers the remainder
1340 // behavior we want, i.e. the result has
1341 // the sign of LHS operand.
1342 (void)result.mod(b);
1343 return result;
1344 });
1345}
1346
1347//===----------------------------------------------------------------------===//
1348// Utility functions for verifying cast ops
1349//===----------------------------------------------------------------------===//
1350
/// Pointer-to-tuple alias over a type pack. A (null) value of this type is
/// presumably used as a zero-cost tag argument so a helper can deduce a type
/// pack from an ordinary function parameter — confirm against its call sites.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1353
/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
///
/// Concretely: a shaped type that is not one of `ShapedTypes` is rejected, and
/// the element type (or the type itself, when it is not shaped) must be one of
/// `ElementTypes`; rejection is signalled by a null Type.
/// NOTE(review): the signature line(s) between the template header and the
/// body are missing here — confirm against the source.
template <typename... ShapedTypes, typename... ElementTypes>
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}
1369
/// Get allowed underlying types for vectors and tensors.
/// NOTE(review): the declaration and body that should follow this template
/// header are missing here — confirm against the source.
template <typename... ElementTypes>
1376
/// Get allowed underlying types for vectors, tensors, and memrefs.
/// NOTE(review): the declaration and body that should follow this template
/// header are missing here — confirm against the source.
template <typename... ElementTypes>
1384
1385/// Return false if both types are ranked tensor with mismatching encoding.
1386static bool hasSameEncoding(Type typeA, Type typeB) {
1387 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1388 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1389 if (!rankedTensorA || !rankedTensorB)
1390 return true;
1391 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1392}
1393
  // A cast is structurally valid when it has exactly one input and one
  // output, the two do not have mismatching ranked-tensor encodings (see
  // hasSameEncoding), and their shapes are compatible.
  // NOTE(review): the signature line for this function is missing above —
  // confirm against the source.
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (!hasSameEncoding(inputs.front(), outputs.front()))
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}
1401
1402//===----------------------------------------------------------------------===//
1403// Verifiers for integer and floating point extension/truncation ops
1404//===----------------------------------------------------------------------===//
1405
1406// Extend ops can only extend to a wider type.
1407template <typename ValType, typename Op>
1408static LogicalResult verifyExtOp(Op op) {
1409 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1410 Type dstType = getElementTypeOrSelf(op.getType());
1411
1412 if (llvm::cast<ValType>(srcType).getWidth() >=
1413 llvm::cast<ValType>(dstType).getWidth())
1414 return op.emitError("result type ")
1415 << dstType << " must be wider than operand type " << srcType;
1416
1417 return success();
1418}
1419
1420// Truncate ops can only truncate to a shorter type.
1421template <typename ValType, typename Op>
1422static LogicalResult verifyTruncateOp(Op op) {
1423 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1424 Type dstType = getElementTypeOrSelf(op.getType());
1425
1426 if (llvm::cast<ValType>(srcType).getWidth() <=
1427 llvm::cast<ValType>(dstType).getWidth())
1428 return op.emitError("result type ")
1429 << dstType << " must be shorter than operand type " << srcType;
1430
1431 return success();
1432}
1433
1434/// Validate a cast that changes the width of a type.
1435template <template <typename> class WidthComparator, typename... ElementTypes>
1436static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1437 if (!areValidCastInputsAndOutputs(inputs, outputs))
1438 return false;
1439
1440 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1441 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1442 if (!srcType || !dstType)
1443 return false;
1444
1445 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1446 srcType.getIntOrFloatBitWidth());
1447}
1448
1449/// Attempts to convert `sourceValue` to an APFloat value with
1450/// `targetSemantics` and `roundingMode`, without any information loss.
1451static FailureOr<APFloat> convertFloatValue(
1452 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1453 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1454 bool losesInfo = false;
1455 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1456 if (losesInfo || status != APFloat::opOK)
1457 return failure();
1458
1459 return sourceValue;
1460}
1461
1462//===----------------------------------------------------------------------===//
1463// ExtUIOp
1464//===----------------------------------------------------------------------===//
1465
/// Folds `arith.extui`:
///  - extui(extui(x)) is rewritten in place to extui(x): this op's operand is
///    replaced and its own result is returned;
///  - otherwise constant elements are zero-extended to the result bit width.
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(bitWidth);
      });
}
1480
/// Cast-compatibility hook for arith.extui.
/// NOTE(review): the body of this function is missing here (only the signature
/// and closing brace are present) — confirm against the source.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
1484
1485LogicalResult arith::ExtUIOp::verify() {
1486 return verifyExtOp<IntegerType>(*this);
1487}
1488
1489//===----------------------------------------------------------------------===//
1490// ExtSIOp
1491//===----------------------------------------------------------------------===//
1492
/// Folds `arith.extsi`:
///  - extsi(extsi(x)) is rewritten in place to extsi(x): this op's operand is
///    replaced and its own result is returned;
///  - otherwise constant elements are sign-extended to the result bit width.
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(bitWidth);
      });
}
1507
/// Cast-compatibility hook for arith.extsi.
/// NOTE(review): the body of this function is missing here (only the signature
/// and closing brace are present) — confirm against the source.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
1511
1512void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1513 MLIRContext *context) {
1514 patterns.add<ExtSIOfExtUI>(context);
1515}
1516
1517LogicalResult arith::ExtSIOp::verify() {
1518 return verifyExtOp<IntegerType>(*this);
1519}
1520
1521//===----------------------------------------------------------------------===//
1522// ExtFOp
1523//===----------------------------------------------------------------------===//
1524
/// Folds `arith.extf`:
///  - extf(truncf(x)) -> x when x already has this op's result type and both
///    the truncf and this extf carry the `contract` fast-math flag;
///  - otherwise, extends float constants when there is no information loss
///    due to the difference in fp semantics (see convertFloatValue); an
///    inexact element clears `castStatus` so the fold is abandoned.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      // The round trip drops the truncation's rounding, so it is only
      // removed when both ops opt in via `contract`.
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}
1557
1558bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1559 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1560}
1561
1562LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1563
1564//===----------------------------------------------------------------------===//
1565// ScalingExtFOp
1566//===----------------------------------------------------------------------===//
1567
1568bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1569 TypeRange outputs) {
1570 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1571}
1572
1573LogicalResult arith::ScalingExtFOp::verify() {
1574 return verifyExtOp<FloatType>(*this);
1575}
1576
1577//===----------------------------------------------------------------------===//
1578// TruncIOp
1579//===----------------------------------------------------------------------===//
1580
/// Folds `arith.trunci`:
///  - trunci(ext[us]i(a)): if `a` is still wider than the result, this op's
///    operand is replaced in place by `a`; if `a` already has the result
///    element type, folds to `a`. A narrower `a` falls through to the checks
///    below.
///  - trunci(trunci(a)) is rewritten in place to trunci(a).
///  - otherwise constant elements are truncated to the result bit width.
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}
1615
1616bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1617 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1618}
1619
1620void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1621 MLIRContext *context) {
1622 patterns
1623 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1624 context);
1625}
1626
1627LogicalResult arith::TruncIOp::verify() {
1628 return verifyTruncateOp<IntegerType>(*this);
1629}
1630
1631//===----------------------------------------------------------------------===//
1632// TruncFOp
1633//===----------------------------------------------------------------------===//
1634
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
///
/// Also handles truncf(extf(a)) when a's float semantics are representable by
/// the extf's intermediate type: if `a` is wider than the result, this op's
/// operand is replaced in place by `a`; if `a` has the result element type,
/// folds to `a`. Constant folding uses the op's rounding mode (defaulting to
/// to_nearest_even) and abandons the fold when convertFloatValue fails.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
    Value src = extOp.getIn();
    auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
    auto intermediateType =
        cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
    // Check if the srcType is representable in the intermediateType.
    if (llvm::APFloatBase::isRepresentableBy(
            srcType.getFloatSemantics(),
            intermediateType.getFloatSemantics())) {
      // truncf(extf(a)) -> truncf(a)
      if (srcType.getWidth() > resElemType.getWidth()) {
        setOperand(src);
        return getResult();
      }

      // truncf(extf(a)) -> a
      if (srcType == resElemType)
        return src;
    }
  }

  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        // NOTE(review): the initializer of llvmRoundingMode (presumably a
        // conversion of `roundingMode` to llvm::RoundingMode) is missing
        // below — confirm against the source.
        llvm::RoundingMode llvmRoundingMode =
        FailureOr<APFloat> result =
            convertFloatValue(a, targetSemantics, llvmRoundingMode);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}
1677
1678void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1679 MLIRContext *context) {
1680 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1681}
1682
1683bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1684 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1685}
1686
1687LogicalResult arith::TruncFOp::verify() {
1688 return verifyTruncateOp<FloatType>(*this);
1689}
1690
1691//===----------------------------------------------------------------------===//
1692// ScalingTruncFOp
1693//===----------------------------------------------------------------------===//
1694
1695bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1696 TypeRange outputs) {
1697 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1698}
1699
1700LogicalResult arith::ScalingTruncFOp::verify() {
1701 return verifyTruncateOp<FloatType>(*this);
1702}
1703
1704//===----------------------------------------------------------------------===//
1705// AndIOp
1706//===----------------------------------------------------------------------===//
1707
1708void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1709 MLIRContext *context) {
1710 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1711}
1712
1713//===----------------------------------------------------------------------===//
1714// OrIOp
1715//===----------------------------------------------------------------------===//
1716
1717void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1718 MLIRContext *context) {
1719 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1720}
1721
1722//===----------------------------------------------------------------------===//
1723// Verifiers for casts between integers and floats.
1724//===----------------------------------------------------------------------===//
1725
1726template <typename From, typename To>
1727static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1728 if (!areValidCastInputsAndOutputs(inputs, outputs))
1729 return false;
1730
1731 auto srcType = getTypeIfLike<From>(inputs.front());
1732 auto dstType = getTypeIfLike<To>(outputs.back());
1733
1734 return srcType && dstType;
1735}
1736
1737//===----------------------------------------------------------------------===//
1738// UIToFPOp
1739//===----------------------------------------------------------------------===//
1740
1741bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1742 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1743}
1744
/// Folds `arith.uitofp` on constants: each integer element is interpreted as
/// unsigned and converted to the result float type with round-to-nearest-even.
/// The conversion status is ignored, so inexact conversions still fold.
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}
1758
1759//===----------------------------------------------------------------------===//
1760// SIToFPOp
1761//===----------------------------------------------------------------------===//
1762
1763bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1764 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1765}
1766
/// Folds `arith.sitofp` on constants: each integer element is interpreted as
/// signed and converted to the result float type with round-to-nearest-even.
/// The conversion status is ignored, so inexact conversions still fold.
OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}
1780
1781//===----------------------------------------------------------------------===//
1782// FPToUIOp
1783//===----------------------------------------------------------------------===//
1784
1785bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1786 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1787}
1788
/// Folds `arith.fptoui` on constants: each float element is converted to an
/// unsigned integer of the result width, rounding toward zero. The fold is
/// abandoned (castStatus = false) when APFloat::convertToInteger reports
/// opInvalidOp; other statuses, such as an inexact result, still fold.
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/true);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
1802
1803//===----------------------------------------------------------------------===//
1804// FPToSIOp
1805//===----------------------------------------------------------------------===//
1806
1807bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1808 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1809}
1810
/// Folds `arith.fptosi` on constants: each float element is converted to a
/// signed integer of the result width, rounding toward zero. The fold is
/// abandoned (castStatus = false) when APFloat::convertToInteger reports
/// opInvalidOp; other statuses, such as an inexact result, still fold.
OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/false);
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
1824
1825//===----------------------------------------------------------------------===//
1826// IndexCastOp
1827//===----------------------------------------------------------------------===//
1828
1829static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1830 if (!areValidCastInputsAndOutputs(inputs, outputs))
1831 return false;
1832
1833 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1834 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1835 if (!srcType || !dstType)
1836 return false;
1837
1838 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1839 (srcType.isSignlessInteger() && dstType.isIndex());
1840}
1841
1842bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1843 TypeRange outputs) {
1844 return areIndexCastCompatible(inputs, outputs);
1845}
1846
/// Folds `arith.index_cast` on constants: each element is sign-extended or
/// truncated to the result width. A non-integer (i.e. `index`) result element
/// type is treated as 64 bits wide.
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.sextOrTrunc(resultBitwidth);
      });
}
1859
1860void arith::IndexCastOp::getCanonicalizationPatterns(
1861 RewritePatternSet &patterns, MLIRContext *context) {
1862 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1863}
1864
1865//===----------------------------------------------------------------------===//
1866// IndexCastUIOp
1867//===----------------------------------------------------------------------===//
1868
1869bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1870 TypeRange outputs) {
1871 return areIndexCastCompatible(inputs, outputs);
1872}
1873
/// Folds `arith.index_castui` on constants: each element is zero-extended or
/// truncated to the result width. A non-integer (i.e. `index`) result element
/// type is treated as 64 bits wide.
OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  // NOTE(review): the line opening this call is missing here; presumably a
  // `return constFoldCastOp<...>(` call — confirm against the source.
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.zextOrTrunc(resultBitwidth);
      });
}
1886
1887void arith::IndexCastUIOp::getCanonicalizationPatterns(
1888 RewritePatternSet &patterns, MLIRContext *context) {
1889 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1890}
1891
1892//===----------------------------------------------------------------------===//
1893// BitcastOp
1894//===----------------------------------------------------------------------===//
1895
1896bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1897 if (!areValidCastInputsAndOutputs(inputs, outputs))
1898 return false;
1899
1900 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1901 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1902 if (!srcType || !dstType)
1903 return false;
1904
1905 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1906}
1907
1908OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1909 auto resType = getType();
1910 auto operand = adaptor.getIn();
1911 if (!operand)
1912 return {};
1913
1914 /// Bitcast dense elements.
1915 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1916 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1917 /// Other shaped types unhandled.
1918 if (llvm::isa<ShapedType>(resType))
1919 return {};
1920
1921 /// Bitcast poison.
1922 if (matchPattern(operand, ub::m_Poison()))
1923 return ub::PoisonAttr::get(getContext());
1924
1925 /// Bitcast integer or float to integer or float.
1926 APInt bits = llvm::isa<FloatAttr>(operand)
1927 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1928 : llvm::cast<IntegerAttr>(operand).getValue();
1929 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1930 "trying to fold on broken IR: operands have incompatible types");
1931
1932 if (auto resFloatType = dyn_cast<FloatType>(resType))
1933 return FloatAttr::get(resType,
1934 APFloat(resFloatType.getFloatSemantics(), bits));
1935 return IntegerAttr::get(resType, bits);
1936}
1937
1938void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1939 MLIRContext *context) {
1940 patterns.add<BitcastOfBitcast>(context);
1941}
1942
1943//===----------------------------------------------------------------------===//
1944// CmpIOp
1945//===----------------------------------------------------------------------===//
1946
1947/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1948/// comparison predicates.
1949bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1950 const APInt &lhs, const APInt &rhs) {
1951 switch (predicate) {
1952 case arith::CmpIPredicate::eq:
1953 return lhs.eq(rhs);
1954 case arith::CmpIPredicate::ne:
1955 return lhs.ne(rhs);
1956 case arith::CmpIPredicate::slt:
1957 return lhs.slt(rhs);
1958 case arith::CmpIPredicate::sle:
1959 return lhs.sle(rhs);
1960 case arith::CmpIPredicate::sgt:
1961 return lhs.sgt(rhs);
1962 case arith::CmpIPredicate::sge:
1963 return lhs.sge(rhs);
1964 case arith::CmpIPredicate::ult:
1965 return lhs.ult(rhs);
1966 case arith::CmpIPredicate::ule:
1967 return lhs.ule(rhs);
1968 case arith::CmpIPredicate::ugt:
1969 return lhs.ugt(rhs);
1970 case arith::CmpIPredicate::uge:
1971 return lhs.uge(rhs);
1972 }
1973 llvm_unreachable("unknown cmpi predicate kind");
1974}
1975
1976/// Returns true if the predicate is true for two equal operands.
1977static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1978 switch (predicate) {
1979 case arith::CmpIPredicate::eq:
1980 case arith::CmpIPredicate::sle:
1981 case arith::CmpIPredicate::sge:
1982 case arith::CmpIPredicate::ule:
1983 case arith::CmpIPredicate::uge:
1984 return true;
1985 case arith::CmpIPredicate::ne:
1986 case arith::CmpIPredicate::slt:
1987 case arith::CmpIPredicate::sgt:
1988 case arith::CmpIPredicate::ult:
1989 case arith::CmpIPredicate::ugt:
1990 return false;
1991 }
1992 llvm_unreachable("unknown cmpi predicate kind");
1993}
1994
1995static std::optional<int64_t> getIntegerWidth(Type t) {
1996 if (auto intType = dyn_cast<IntegerType>(t)) {
1997 return intType.getWidth();
1998 }
1999 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
2000 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2001 }
2002 return std::nullopt;
2003}
2004
/// Folds arith.cmpi. In order:
///  - identical operands fold to the predicate's result on equal values;
///  - `ne 0` of an i1 that was extsi/extui'd, and `ne 0` / `eq 1` of an i1
///    (or i1-element) value, fold to that i1 value;
///  - a constant lhs with a non-constant rhs is canonicalized in place by
///    swapping the operands and mirroring the predicate;
///  - two constant operands are evaluated with applyCmpPredicate.
2005OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2006 // cmpi(pred, x, x)
2007 if (getLhs() == getRhs()) {
2008 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2009 return getBoolAttribute(getType(), val);
2010 }
2011
  // Comparisons against a constant zero (scalar or splat, per m_Zero).
2012 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2013 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2014 // extsi(%x : i1 -> iN) != 0 -> %x
2015 std::optional<int64_t> integerWidth =
2016 getIntegerWidth(extOp.getOperand().getType());
2017 if (integerWidth && integerWidth.value() == 1 &&
2018 getPredicate() == arith::CmpIPredicate::ne)
2019 return extOp.getOperand();
2020 }
2021 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2022 // extui(%x : i1 -> iN) != 0 -> %x
2023 std::optional<int64_t> integerWidth =
2024 getIntegerWidth(extOp.getOperand().getType());
2025 if (integerWidth && integerWidth.value() == 1 &&
2026 getPredicate() == arith::CmpIPredicate::ne)
2027 return extOp.getOperand();
2028 }
2029
2030 // arith.cmpi ne, %val, %zero : i1 -> %val
2031 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2032 getPredicate() == arith::CmpIPredicate::ne)
2033 return getLhs();
2034 }
2035
2036 if (matchPattern(adaptor.getRhs(), m_One())) {
2037 // arith.cmpi eq, %val, %one : i1 -> %val
2038 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2039 getPredicate() == arith::CmpIPredicate::eq)
2040 return getLhs();
2041 }
2042
2043 // Move constant to the right side.
  // This mutates the op in place (new predicate, swapped operands) and
  // returns the op's own result to report that in-place update.
2044 if (adaptor.getLhs() && !adaptor.getRhs()) {
2045 // Do not use invertPredicate, as it will change eq to ne and vice versa.
  // Each pair maps a predicate to its operand-swapped equivalent
  // (a < b <=> b > a); eq and ne are symmetric.
2046 using Pred = CmpIPredicate;
2047 const std::pair<Pred, Pred> invPreds[] = {
2048 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2049 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2050 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2051 {Pred::ne, Pred::ne},
2052 };
2053 Pred origPred = getPredicate();
2054 for (auto pred : invPreds) {
2055 if (origPred == pred.first) {
2056 setPredicate(pred.second);
2057 Value lhs = getLhs();
2058 Value rhs = getRhs();
2059 getLhsMutable().assign(rhs);
2060 getRhsMutable().assign(lhs);
2061 return getResult();
2062 }
2063 }
2064 llvm_unreachable("unknown cmpi predicate kind");
2065 }
2066
2067 // We are moving constants to the right side; So if lhs is constant rhs is
2068 // guaranteed to be a constant.
  // Both operands are constant here: evaluate the predicate per value,
  // producing an i1 result shaped like the operands.
2069 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2071 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2072 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2073 return APInt(1,
2074 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2075 });
2076 }
2077
2078 return {};
2079}
2080
2081void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2082 MLIRContext *context) {
2083 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2084}
2085
2086//===----------------------------------------------------------------------===//
2087// CmpFOp
2088//===----------------------------------------------------------------------===//
2089
2090/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2091/// comparison predicates.
2092bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2093 const APFloat &lhs, const APFloat &rhs) {
2094 auto cmpResult = lhs.compare(rhs);
2095 switch (predicate) {
2096 case arith::CmpFPredicate::AlwaysFalse:
2097 return false;
2098 case arith::CmpFPredicate::OEQ:
2099 return cmpResult == APFloat::cmpEqual;
2100 case arith::CmpFPredicate::OGT:
2101 return cmpResult == APFloat::cmpGreaterThan;
2102 case arith::CmpFPredicate::OGE:
2103 return cmpResult == APFloat::cmpGreaterThan ||
2104 cmpResult == APFloat::cmpEqual;
2105 case arith::CmpFPredicate::OLT:
2106 return cmpResult == APFloat::cmpLessThan;
2107 case arith::CmpFPredicate::OLE:
2108 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2109 case arith::CmpFPredicate::ONE:
2110 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2111 case arith::CmpFPredicate::ORD:
2112 return cmpResult != APFloat::cmpUnordered;
2113 case arith::CmpFPredicate::UEQ:
2114 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2115 case arith::CmpFPredicate::UGT:
2116 return cmpResult == APFloat::cmpUnordered ||
2117 cmpResult == APFloat::cmpGreaterThan;
2118 case arith::CmpFPredicate::UGE:
2119 return cmpResult == APFloat::cmpUnordered ||
2120 cmpResult == APFloat::cmpGreaterThan ||
2121 cmpResult == APFloat::cmpEqual;
2122 case arith::CmpFPredicate::ULT:
2123 return cmpResult == APFloat::cmpUnordered ||
2124 cmpResult == APFloat::cmpLessThan;
2125 case arith::CmpFPredicate::ULE:
2126 return cmpResult == APFloat::cmpUnordered ||
2127 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2128 case arith::CmpFPredicate::UNE:
2129 return cmpResult != APFloat::cmpEqual;
2130 case arith::CmpFPredicate::UNO:
2131 return cmpResult == APFloat::cmpUnordered;
2132 case arith::CmpFPredicate::AlwaysTrue:
2133 return true;
2134 }
2135 llvm_unreachable("unknown cmpf predicate kind");
2136}
2137
2138OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2139 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2140 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2141
2142 // If one operand is NaN, making them both NaN does not change the result.
2143 if (lhs && lhs.getValue().isNaN())
2144 rhs = lhs;
2145 if (rhs && rhs.getValue().isNaN())
2146 lhs = rhs;
2147
2148 if (!lhs || !rhs)
2149 return {};
2150
2151 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2152 return BoolAttr::get(getContext(), val);
2153}
2154
2155class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2156public:
2157 using Base::Base;
2158
2159 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2160 bool isUnsigned) {
2161 using namespace arith;
2162 switch (pred) {
2163 case CmpFPredicate::UEQ:
2164 case CmpFPredicate::OEQ:
2165 return CmpIPredicate::eq;
2166 case CmpFPredicate::UGT:
2167 case CmpFPredicate::OGT:
2168 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2169 case CmpFPredicate::UGE:
2170 case CmpFPredicate::OGE:
2171 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2172 case CmpFPredicate::ULT:
2173 case CmpFPredicate::OLT:
2174 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2175 case CmpFPredicate::ULE:
2176 case CmpFPredicate::OLE:
2177 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2178 case CmpFPredicate::UNE:
2179 case CmpFPredicate::ONE:
2180 return CmpIPredicate::ne;
2181 default:
2182 llvm_unreachable("Unexpected predicate!");
2183 }
2184 }
2185
2186 LogicalResult matchAndRewrite(CmpFOp op,
2187 PatternRewriter &rewriter) const override {
2188 FloatAttr flt;
2189 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2190 return failure();
2191
2192 const APFloat &rhs = flt.getValue();
2193
2194 // Don't attempt to fold a nan.
2195 if (rhs.isNaN())
2196 return failure();
2197
2198 // Get the width of the mantissa. We don't want to hack on conversions that
2199 // might lose information from the integer, e.g. "i64 -> float"
2200 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2201 int mantissaWidth = floatTy.getFPMantissaWidth();
2202 if (mantissaWidth <= 0)
2203 return failure();
2204
2205 bool isUnsigned;
2206 Value intVal;
2207
2208 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2209 isUnsigned = false;
2210 intVal = si.getIn();
2211 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2212 isUnsigned = true;
2213 intVal = ui.getIn();
2214 } else {
2215 return failure();
2216 }
2217
2218 // Check to see that the input is converted from an integer type that is
2219 // small enough that preserves all bits.
2220 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2221 auto intWidth = intTy.getWidth();
2222
2223 // Number of bits representing values, as opposed to the sign
2224 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2225
2226 // Following test does NOT adjust intWidth downwards for signed inputs,
2227 // because the most negative value still requires all the mantissa bits
2228 // to distinguish it from one less than that value.
2229 if ((int)intWidth > mantissaWidth) {
2230 // Conversion would lose accuracy. Check if loss can impact comparison.
2231 int exponent = ilogb(rhs);
2232 if (exponent == APFloat::IEK_Inf) {
2233 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2234 if (maxExponent < (int)valueBits) {
2235 // Conversion could create infinity.
2236 return failure();
2237 }
2238 } else {
2239 // Note that if rhs is zero or NaN, then Exp is negative
2240 // and first condition is trivially false.
2241 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2242 // Conversion could affect comparison.
2243 return failure();
2244 }
2245 }
2246 }
2247
2248 // Convert to equivalent cmpi predicate
2249 CmpIPredicate pred;
2250 switch (op.getPredicate()) {
2251 case CmpFPredicate::ORD:
2252 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2253 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2254 /*width=*/1);
2255 return success();
2256 case CmpFPredicate::UNO:
2257 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2258 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2259 /*width=*/1);
2260 return success();
2261 default:
2262 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2263 break;
2264 }
2265
2266 if (!isUnsigned) {
2267 // If the rhs value is > SignedMax, fold the comparison. This handles
2268 // +INF and large values.
2269 APFloat signedMax(rhs.getSemantics());
2270 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2271 APFloat::rmNearestTiesToEven);
2272 if (signedMax < rhs) { // smax < 13123.0
2273 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2274 pred == CmpIPredicate::sle)
2275 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2276 /*width=*/1);
2277 else
2278 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2279 /*width=*/1);
2280 return success();
2281 }
2282 } else {
2283 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2284 // +INF and large values.
2285 APFloat unsignedMax(rhs.getSemantics());
2286 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2287 APFloat::rmNearestTiesToEven);
2288 if (unsignedMax < rhs) { // umax < 13123.0
2289 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2290 pred == CmpIPredicate::ule)
2291 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2292 /*width=*/1);
2293 else
2294 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2295 /*width=*/1);
2296 return success();
2297 }
2298 }
2299
2300 if (!isUnsigned) {
2301 // See if the rhs value is < SignedMin.
2302 APFloat signedMin(rhs.getSemantics());
2303 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2304 APFloat::rmNearestTiesToEven);
2305 if (signedMin > rhs) { // smin > 12312.0
2306 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2307 pred == CmpIPredicate::sge)
2308 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2309 /*width=*/1);
2310 else
2311 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2312 /*width=*/1);
2313 return success();
2314 }
2315 } else {
2316 // See if the rhs value is < UnsignedMin.
2317 APFloat unsignedMin(rhs.getSemantics());
2318 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2319 APFloat::rmNearestTiesToEven);
2320 if (unsignedMin > rhs) { // umin > 12312.0
2321 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2322 pred == CmpIPredicate::uge)
2323 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2324 /*width=*/1);
2325 else
2326 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2327 /*width=*/1);
2328 return success();
2329 }
2330 }
2331
2332 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2333 // [0, UMAX], but it may still be fractional. See if it is fractional by
2334 // casting the FP value to the integer value and back, checking for
2335 // equality. Don't do this for zero, because -0.0 is not fractional.
2336 bool ignored;
2337 APSInt rhsInt(intWidth, isUnsigned);
2338 if (APFloat::opInvalidOp ==
2339 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2340 // Undefined behavior invoked - the destination type can't represent
2341 // the input constant.
2342 return failure();
2343 }
2344
2345 if (!rhs.isZero()) {
2346 APFloat apf(floatTy.getFloatSemantics(),
2347 APInt::getZero(floatTy.getWidth()));
2348 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2349
2350 bool equal = apf == rhs;
2351 if (!equal) {
2352 // If we had a comparison against a fractional value, we have to adjust
2353 // the compare predicate and sometimes the value. rhsInt is rounded
2354 // towards zero at this point.
2355 switch (pred) {
2356 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2357 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2358 /*width=*/1);
2359 return success();
2360 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2361 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2362 /*width=*/1);
2363 return success();
2364 case CmpIPredicate::ule:
2365 // (float)int <= 4.4 --> int <= 4
2366 // (float)int <= -4.4 --> false
2367 if (rhs.isNegative()) {
2368 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2369 /*width=*/1);
2370 return success();
2371 }
2372 break;
2373 case CmpIPredicate::sle:
2374 // (float)int <= 4.4 --> int <= 4
2375 // (float)int <= -4.4 --> int < -4
2376 if (rhs.isNegative())
2377 pred = CmpIPredicate::slt;
2378 break;
2379 case CmpIPredicate::ult:
2380 // (float)int < -4.4 --> false
2381 // (float)int < 4.4 --> int <= 4
2382 if (rhs.isNegative()) {
2383 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2384 /*width=*/1);
2385 return success();
2386 }
2387 pred = CmpIPredicate::ule;
2388 break;
2389 case CmpIPredicate::slt:
2390 // (float)int < -4.4 --> int < -4
2391 // (float)int < 4.4 --> int <= 4
2392 if (!rhs.isNegative())
2393 pred = CmpIPredicate::sle;
2394 break;
2395 case CmpIPredicate::ugt:
2396 // (float)int > 4.4 --> int > 4
2397 // (float)int > -4.4 --> true
2398 if (rhs.isNegative()) {
2399 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2400 /*width=*/1);
2401 return success();
2402 }
2403 break;
2404 case CmpIPredicate::sgt:
2405 // (float)int > 4.4 --> int > 4
2406 // (float)int > -4.4 --> int >= -4
2407 if (rhs.isNegative())
2408 pred = CmpIPredicate::sge;
2409 break;
2410 case CmpIPredicate::uge:
2411 // (float)int >= -4.4 --> true
2412 // (float)int >= 4.4 --> int > 4
2413 if (rhs.isNegative()) {
2414 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2415 /*width=*/1);
2416 return success();
2417 }
2418 pred = CmpIPredicate::ugt;
2419 break;
2420 case CmpIPredicate::sge:
2421 // (float)int >= -4.4 --> int >= -4
2422 // (float)int >= 4.4 --> int > 4
2423 if (!rhs.isNegative())
2424 pred = CmpIPredicate::sgt;
2425 break;
2426 }
2427 }
2428 }
2429
2430 // Lower this FP comparison into an appropriate integer version of the
2431 // comparison.
2432 rewriter.replaceOpWithNewOp<CmpIOp>(
2433 op, pred, intVal,
2434 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2435 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2436 return success();
2437 }
2438};
2439
2440void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2441 MLIRContext *context) {
2442 patterns.insert<CmpFIntToFPConst>(context);
2443}
2444
2445//===----------------------------------------------------------------------===//
2446// SelectOp
2447//===----------------------------------------------------------------------===//
2448
2449// select %arg, %c1, %c0 => extui %arg
2450struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2451 using Base::Base;
2452
2453 LogicalResult matchAndRewrite(arith::SelectOp op,
2454 PatternRewriter &rewriter) const override {
2455 // Cannot extui i1 to i1, or i1 to f32
2456 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2457 return failure();
2458
2459 // select %x, c1, %c0 => extui %arg
2460 if (matchPattern(op.getTrueValue(), m_One()) &&
2461 matchPattern(op.getFalseValue(), m_Zero())) {
2462 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2463 op.getCondition());
2464 return success();
2465 }
2466
2467 // select %x, c0, %c1 => extui (xor %arg, true)
2468 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2469 matchPattern(op.getFalseValue(), m_One())) {
2470 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2471 op, op.getType(),
2472 arith::XOrIOp::create(
2473 rewriter, op.getLoc(), op.getCondition(),
2474 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2475 op.getCondition().getType(), 1)));
2476 return success();
2477 }
2478
2479 return failure();
2480 }
2481};
2482
2483void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2484 MLIRContext *context) {
2485 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2486 SelectI1ToNot, SelectToExtUI>(context);
2487}
2488
2489OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2490 Value trueVal = getTrueValue();
2491 Value falseVal = getFalseValue();
2492 if (trueVal == falseVal)
2493 return trueVal;
2494
2495 Value condition = getCondition();
2496
2497 // select true, %0, %1 => %0
2498 if (matchPattern(adaptor.getCondition(), m_One()))
2499 return trueVal;
2500
2501 // select false, %0, %1 => %1
2502 if (matchPattern(adaptor.getCondition(), m_Zero()))
2503 return falseVal;
2504
2505 // If either operand is fully poisoned, return the other.
2506 if (matchPattern(adaptor.getTrueValue(), ub::m_Poison()))
2507 return falseVal;
2508
2509 if (matchPattern(adaptor.getFalseValue(), ub::m_Poison()))
2510 return trueVal;
2511
2512 // select %x, true, false => %x
2513 if (getType().isSignlessInteger(1) &&
2514 matchPattern(adaptor.getTrueValue(), m_One()) &&
2515 matchPattern(adaptor.getFalseValue(), m_Zero()))
2516 return condition;
2517
2518 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2519 auto pred = cmp.getPredicate();
2520 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2521 auto cmpLhs = cmp.getLhs();
2522 auto cmpRhs = cmp.getRhs();
2523
2524 // %0 = arith.cmpi eq, %arg0, %arg1
2525 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2526
2527 // %0 = arith.cmpi ne, %arg0, %arg1
2528 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2529
2530 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2531 (cmpRhs == trueVal && cmpLhs == falseVal))
2532 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2533 }
2534 }
2535
2536 // Constant-fold constant operands over non-splat constant condition.
2537 // select %cst_vec, %cst0, %cst1 => %cst2
2538 if (auto cond =
2539 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2540 // DenseElementsAttr by construction always has a static shape.
2541 assert(cond.getType().hasStaticShape() &&
2542 "DenseElementsAttr must have static shape");
2543 if (auto lhs =
2544 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2545 if (auto rhs =
2546 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2547 SmallVector<Attribute> results;
2548 results.reserve(static_cast<size_t>(cond.getNumElements()));
2549 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2550 cond.value_end<BoolAttr>());
2551 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2552 lhs.value_end<Attribute>());
2553 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2554 rhs.value_end<Attribute>());
2555
2556 for (auto [condVal, lhsVal, rhsVal] :
2557 llvm::zip_equal(condVals, lhsVals, rhsVals))
2558 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2559
2560 return DenseElementsAttr::get(lhs.getType(), results);
2561 }
2562 }
2563 }
2564
2565 return nullptr;
2566}
2567
2568ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2569 Type conditionType, resultType;
2570 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2571 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2572 parser.parseOptionalAttrDict(result.attributes) ||
2573 parser.parseColonType(resultType))
2574 return failure();
2575
2576 // Check for the explicit condition type if this is a masked tensor or vector.
2577 if (succeeded(parser.parseOptionalComma())) {
2578 conditionType = resultType;
2579 if (parser.parseType(resultType))
2580 return failure();
2581 } else {
2582 conditionType = parser.getBuilder().getI1Type();
2583 }
2584
2585 result.addTypes(resultType);
2586 return parser.resolveOperands(operands,
2587 {conditionType, resultType, resultType},
2588 parser.getNameLoc(), result.operands);
2589}
2590
2591void arith::SelectOp::print(OpAsmPrinter &p) {
2592 p << " " << getOperands();
2593 p.printOptionalAttrDict((*this)->getAttrs());
2594 p << " : ";
2595 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2596 p << condType << ", ";
2597 p << getType();
2598}
2599
2600LogicalResult arith::SelectOp::verify() {
2601 Type conditionType = getCondition().getType();
2602 if (conditionType.isSignlessInteger(1))
2603 return success();
2604
2605 // If the result type is a vector or tensor, the type can be a mask with the
2606 // same elements.
2607 Type resultType = getType();
2608 if (!llvm::isa<TensorType, VectorType>(resultType))
2609 return emitOpError() << "expected condition to be a signless i1, but got "
2610 << conditionType;
2611 Type shapedConditionType = getI1SameShape(resultType);
2612 if (conditionType != shapedConditionType) {
2613 return emitOpError() << "expected condition type to have the same shape "
2614 "as the result type, expected "
2615 << shapedConditionType << ", but got "
2616 << conditionType;
2617 }
2618 return success();
2619}
2620//===----------------------------------------------------------------------===//
2621// ShLIOp
2622//===----------------------------------------------------------------------===//
2623
/// Folds shli: a zero shift amount returns the lhs unchanged; constant
/// operands are folded with APInt::shl unless the shift amount is not
/// smaller than the bit width.
2624OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2625 // shli(x, 0) -> x
2626 if (matchPattern(adaptor.getRhs(), m_Zero()))
2627 return getLhs();
2628 // Don't fold if shifting more or equal than the bit width.
  // NOTE(review): `bounded` is overwritten on each lambda invocation. If the
  // constant folder invokes the lambda once per element (non-splat vectors),
  // only the last element's shift amount decides whether the fold is kept.
  // Confirm whether every element should be checked.
2629 bool bounded = false;
2631 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2632 bounded = b.ult(b.getBitWidth());
2633 return a.shl(b);
2634 });
2635 return bounded ? result : Attribute();
2636}
2637
2638//===----------------------------------------------------------------------===//
2639// ShRUIOp
2640//===----------------------------------------------------------------------===//
2641
/// Folds shrui: a zero shift amount returns the lhs unchanged; constant
/// operands are folded with APInt::lshr (logical shift) unless the shift
/// amount is not smaller than the bit width.
2642OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2643 // shrui(x, 0) -> x
2644 if (matchPattern(adaptor.getRhs(), m_Zero()))
2645 return getLhs();
2646 // Don't fold if shifting more or equal than the bit width.
  // NOTE(review): `bounded` only reflects the last lambda invocation; for
  // per-element folding of non-splat vectors, earlier out-of-range shift
  // amounts are not considered. Confirm intended behavior.
2647 bool bounded = false;
2649 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2650 bounded = b.ult(b.getBitWidth());
2651 return a.lshr(b);
2652 });
2653 return bounded ? result : Attribute();
2654}
2655
2656//===----------------------------------------------------------------------===//
2657// ShRSIOp
2658//===----------------------------------------------------------------------===//
2659
/// Folds shrsi: a zero shift amount returns the lhs unchanged; constant
/// operands are folded with APInt::ashr (arithmetic shift) unless the shift
/// amount is not smaller than the bit width.
2660OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2661 // shrsi(x, 0) -> x
2662 if (matchPattern(adaptor.getRhs(), m_Zero()))
2663 return getLhs();
2664 // Don't fold if shifting more or equal than the bit width.
  // NOTE(review): `bounded` only reflects the last lambda invocation; for
  // per-element folding of non-splat vectors, earlier out-of-range shift
  // amounts are not considered. Confirm intended behavior.
2665 bool bounded = false;
2667 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2668 bounded = b.ult(b.getBitWidth());
2669 return a.ashr(b);
2670 });
2671 return bounded ? result : Attribute();
2672}
2673
2674//===----------------------------------------------------------------------===//
2675// Atomic Enum
2676//===----------------------------------------------------------------------===//
2677
2678/// Returns the identity value attribute associated with an AtomicRMWKind op.
2679TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2680 OpBuilder &builder, Location loc,
2681 bool useOnlyFiniteValue) {
2682 switch (kind) {
2683 case AtomicRMWKind::maximumf: {
2684 const llvm::fltSemantics &semantic =
2685 llvm::cast<FloatType>(resultType).getFloatSemantics();
2686 APFloat identity = useOnlyFiniteValue
2687 ? APFloat::getLargest(semantic, /*Negative=*/true)
2688 : APFloat::getInf(semantic, /*Negative=*/true);
2689 return builder.getFloatAttr(resultType, identity);
2690 }
2691 case AtomicRMWKind::maxnumf: {
2692 const llvm::fltSemantics &semantic =
2693 llvm::cast<FloatType>(resultType).getFloatSemantics();
2694 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2695 return builder.getFloatAttr(resultType, identity);
2696 }
2697 case AtomicRMWKind::addf:
2698 case AtomicRMWKind::addi:
2699 case AtomicRMWKind::maxu:
2700 case AtomicRMWKind::ori:
2701 case AtomicRMWKind::xori:
2702 return builder.getZeroAttr(resultType);
2703 case AtomicRMWKind::andi:
2704 return builder.getIntegerAttr(
2705 resultType,
2706 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2707 case AtomicRMWKind::maxs:
2708 return builder.getIntegerAttr(
2709 resultType, APInt::getSignedMinValue(
2710 llvm::cast<IntegerType>(resultType).getWidth()));
2711 case AtomicRMWKind::minimumf: {
2712 const llvm::fltSemantics &semantic =
2713 llvm::cast<FloatType>(resultType).getFloatSemantics();
2714 APFloat identity = useOnlyFiniteValue
2715 ? APFloat::getLargest(semantic, /*Negative=*/false)
2716 : APFloat::getInf(semantic, /*Negative=*/false);
2717
2718 return builder.getFloatAttr(resultType, identity);
2719 }
2720 case AtomicRMWKind::minnumf: {
2721 const llvm::fltSemantics &semantic =
2722 llvm::cast<FloatType>(resultType).getFloatSemantics();
2723 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2724 return builder.getFloatAttr(resultType, identity);
2725 }
2726 case AtomicRMWKind::mins:
2727 return builder.getIntegerAttr(
2728 resultType, APInt::getSignedMaxValue(
2729 llvm::cast<IntegerType>(resultType).getWidth()));
2730 case AtomicRMWKind::minu:
2731 return builder.getIntegerAttr(
2732 resultType,
2733 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2734 case AtomicRMWKind::muli:
2735 return builder.getIntegerAttr(resultType, 1);
2736 case AtomicRMWKind::mulf:
2737 return builder.getFloatAttr(resultType, 1);
2738 // TODO: Add remaining reduction operations.
2739 default:
2740 (void)emitOptionalError(loc, "Reduction operation type not supported");
2741 break;
2742 }
2743 return nullptr;
2744}
2745
/// Returns the identity (neutral) value attribute for `op` if it is one of the
/// arith reduction-like ops listed below, and std::nullopt for any other op.
/// When `op` implements ArithFastMathInterface with the `ninf` flag, float
/// max/min identities use the largest finite value instead of infinity.
/// NOTE(review): the result type is passed unchanged to getIdentityValueAttr,
/// which casts it to IntegerType/FloatType for several kinds. Confirm that
/// index- or vector-typed ops are not expected here.
2747std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2748 std::optional<AtomicRMWKind> maybeKind =
2750 // Floating-point operations.
2751 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2752 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2753 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2754 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2755 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2756 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2757 // Integer operations.
2758 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2759 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2760 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2761 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2762 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2763 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2764 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2765 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2766 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2767 .Default(std::nullopt);
2768 if (!maybeKind) {
2769 return std::nullopt;
2770 }
2771
  // `ninf` promises no infinities, so finite extremes are used as identities.
  // NOTE(review): assumes getFastMathFlagsAttr() is non-null for ops that
  // implement the interface. Verify.
2772 bool useOnlyFiniteValue = false;
2773 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2774 if (fmfOpInterface) {
2775 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2776 useOnlyFiniteValue =
2777 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2778 }
2779
2780 // Builder only used as helper for attribute creation.
2781 OpBuilder b(op->getContext());
2782 Type resultType = op->getResult(0).getType();
2783
2784 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2785 useOnlyFiniteValue);
2786}
2787
2788/// Returns the identity value associated with an AtomicRMWKind op.
2789Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2790 OpBuilder &builder, Location loc,
2791 bool useOnlyFiniteValue) {
2792 auto attr =
2793 getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2794 return arith::ConstantOp::create(builder, loc, attr);
2795}
2796
2797/// Return the value obtained by applying the reduction operation kind
2798/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
/// The combining arith op is created with `builder` at `loc`. Kinds without a
/// case below report "Reduction operation type not supported" at `loc` and
/// yield a null Value.
2800 Location loc, Value lhs, Value rhs) {
2801 switch (op) {
2802 case AtomicRMWKind::addf:
2803 return arith::AddFOp::create(builder, loc, lhs, rhs);
2804 case AtomicRMWKind::addi:
2805 return arith::AddIOp::create(builder, loc, lhs, rhs);
2806 case AtomicRMWKind::mulf:
2807 return arith::MulFOp::create(builder, loc, lhs, rhs);
2808 case AtomicRMWKind::muli:
2809 return arith::MulIOp::create(builder, loc, lhs, rhs);
2810 case AtomicRMWKind::maximumf:
2811 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2812 case AtomicRMWKind::minimumf:
2813 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2814 case AtomicRMWKind::maxnumf:
2815 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2816 case AtomicRMWKind::minnumf:
2817 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2818 case AtomicRMWKind::maxs:
2819 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2820 case AtomicRMWKind::mins:
2821 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2822 case AtomicRMWKind::maxu:
2823 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2824 case AtomicRMWKind::minu:
2825 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2826 case AtomicRMWKind::ori:
2827 return arith::OrIOp::create(builder, loc, lhs, rhs);
2828 case AtomicRMWKind::andi:
2829 return arith::AndIOp::create(builder, loc, lhs, rhs);
2830 case AtomicRMWKind::xori:
2831 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2832 // TODO: Add remaining reduction operations.
2833 default:
2834 (void)emitOptionalError(loc, "Reduction operation type not supported");
2835 break;
2836 }
2837 return nullptr;
2838}
2839
2840//===----------------------------------------------------------------------===//
2841// TableGen'd op method definitions
2842//===----------------------------------------------------------------------===//
2843
2844#define GET_OP_CLASSES
2845#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2846
2847//===----------------------------------------------------------------------===//
2848// TableGen'd enum attribute definitions
2849//===----------------------------------------------------------------------===//
2850
2851#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:721
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:682
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:781
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:763
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:149
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:129
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:962
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:173
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:437
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:141
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:334
static bool classof(Operation *op)
Definition ArithOps.cpp:351
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:328
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:357
static bool classof(Operation *op)
Definition ArithOps.cpp:378
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:255
static bool classof(Operation *op)
Definition ArithOps.cpp:322
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:75
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:384
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.